from collections.abc import Callable
from mpi4py import MPI
import dolfinx
import numpy as np
import ufl
from dolfinx import fem
from packaging import version
[docs]
def as_fenics_constant(
    value: float | int | fem.Constant, mesh: dolfinx.mesh.Mesh
) -> fem.Constant:
    """Converts a value to a dolfinx.Constant.

    Args:
        value: the value to convert. Python ``float``/``int`` values and NumPy
            numeric scalars (``np.floating``/``np.integer``) are wrapped in a new
            constant on ``mesh``; an existing ``fem.Constant`` is returned as is.
        mesh: the mesh of the domain

    Returns:
        The converted value

    Raises:
        TypeError: if the value is not a float, an int, a NumPy numeric scalar or
            a dolfinx.Constant
    """
    # NumPy scalars such as np.int64 or np.float32 are not subclasses of the
    # builtin int/float types, so they have to be listed explicitly.
    if isinstance(value, float | int | np.floating | np.integer):
        return fem.Constant(mesh, dolfinx.default_scalar_type(float(value)))
    elif isinstance(value, fem.Constant):
        return value
    else:
        raise TypeError(
            f"Value must be a float, an int or a dolfinx.Constant, not {type(value)}"
        )
# TODO change this to accept species dependent values
[docs]
def as_mapped_function(
    value: Callable,
    function_space: fem.FunctionSpace | None = None,
    t: fem.Constant | None = None,
    temperature: fem.Function | fem.Constant | ufl.core.expr.Expr | None = None,
) -> ufl.core.expr.Expr:
    """Maps a user given callable function to the mesh, time or temperature within
    festim as needed.

    The names of the callable's parameters select what it receives:

    - ``x``: the spatial coordinate of ``function_space.mesh``
    - ``t``: ``t``
    - ``T``: ``temperature``

    Args:
        value: the callable to convert (must expose ``__code__``, e.g. a
            function or lambda)
        function_space: the function space of the domain, optional (required
            when ``value`` takes an ``x`` parameter)
        t: the time, optional
        temperature: the temperature, optional

    Returns:
        The mapped function
    """
    # Only look at the declared parameters. ``co_varnames`` lists the
    # parameters first and then the callable's local variables, so using the
    # whole tuple would, for a function with a local variable named ``t``,
    # pass an unexpected ``t=`` keyword and raise a TypeError.
    code = value.__code__
    arguments = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]

    kwargs = {}
    if "t" in arguments:
        kwargs["t"] = t
    if "x" in arguments:
        kwargs["x"] = ufl.SpatialCoordinate(function_space.mesh)
    if "T" in arguments:
        kwargs["T"] = temperature
    return value(**kwargs)
# TODO change this to accept species dependent values
[docs]
def as_fenics_interp_expr_and_function(
    value: Callable,
    function_space: dolfinx.fem.function.FunctionSpace,
    t: fem.Constant | None = None,
    temperature: fem.Function | fem.Constant | ufl.core.expr.Expr | None = None,
) -> tuple[fem.Expression, fem.Function]:
    """Builds a fenics interpolation expression and function from a user callable.

    The callable is first mapped onto the mesh coordinate, time and/or
    temperature with ``as_mapped_function``. The resulting ufl expression is
    compiled into a ``fem.Expression`` at the interpolation points of
    ``function_space.element``. That expression is then interpolated into a
    new ``fem.Function`` on ``function_space``.

    Args:
        value: the callable to convert
        function_space: The function space to interpolate function over
        t: the time, optional
        temperature: the temperature, optional

    Returns:
        fenics interpolation expression, fenics function
    """
    ufl_expr = as_mapped_function(
        value=value, function_space=function_space, t=t, temperature=temperature
    )
    points = get_interpolation_points(function_space.element)
    expression = fem.Expression(ufl_expr, points)

    function = fem.Function(function_space)
    function.interpolate(expression)
    return expression, function
[docs]
class Value:
    """A class to handle input values from users and convert them to a relevant
    fenics object.

    Args:
        input_value: The value of the user input

    Attributes:
        input_value : The value of the user input
        ufl_expression : A ufl expression of the user input; initialised to
            ``None`` and not assigned anywhere else in this class
        fenics_interpolation_expression : The expression of the user input that is
            used to update the `fenics_object`
        fenics_object : The value of the user input in fenics format
        explicit_time_dependent : True if the user input value is explicitly time
            dependent
        temperature_dependent : True if the user input value is temperature
            dependent
    """

    input_value: (
        float
        | int
        | fem.Constant
        | np.ndarray
        | fem.Expression
        | ufl.core.expr.Expr
        | fem.Function
    )
    ufl_expression: ufl.core.expr.Expr
    fenics_interpolation_expression: fem.Expression
    fenics_object: fem.Function | fem.Constant | ufl.core.expr.Expr
    explicit_time_dependent: bool
    temperature_dependent: bool

    def __init__(self, input_value):
        self.input_value = input_value
        self.ufl_expression = None
        self.fenics_interpolation_expression = None
        self.fenics_object = None

    def __repr__(self) -> str:
        return str(self.input_value)

    @property
    def input_value(self):
        return self._input_value

    @input_value.setter
    def input_value(self, value):
        if value is None:
            self._input_value = value
        elif isinstance(
            value,
            float
            | int
            | fem.Constant
            | np.ndarray
            | fem.Expression
            | ufl.core.expr.Expr
            | fem.Function,
        ):
            self._input_value = value
        elif callable(value):
            self._input_value = value
        else:
            raise TypeError(
                "Value must be a float, int, fem.Constant, np.ndarray, fem.Expression,"
                f" ufl.core.expr.Expr, fem.Function, or callable not {value}"
            )

    @staticmethod
    def _parameter_names(func) -> tuple[str, ...]:
        """Returns the names of the parameters declared by ``func``.

        ``__code__.co_varnames`` lists the positional and keyword-only
        parameters first, followed by ``*args``/``**kwargs`` names and then the
        local variables. It is sliced to the declared parameters so that a local
        variable named ``t`` or ``T`` is not mistaken for a parameter.
        """
        code = func.__code__
        return code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]

    def _is_user_function(self) -> bool:
        """Returns True if ``input_value`` is a user callable to be inspected.

        ``fem.Constant`` and ufl expressions are excluded before the
        ``callable`` check because such objects may be callable without
        exposing a ``__code__`` attribute.
        """
        value = self.input_value
        if isinstance(value, fem.Constant | ufl.core.expr.Expr):
            return False
        # None and all other non-callable inputs fall through to False here
        return callable(value)

    @property
    def explicit_time_dependent(self) -> bool:
        """Returns true if the value given is time dependent."""
        return self._is_user_function() and "t" in self._parameter_names(
            self.input_value
        )

    @property
    def temperature_dependent(self) -> bool:
        """Returns true if the value given is temperature dependent."""
        return self._is_user_function() and "T" in self._parameter_names(
            self.input_value
        )

    def update(self, t: float):
        """Updates the value.

        This only acts when ``input_value`` is a user callable:

        - If ``fenics_object`` is a ``fem.Constant`` and the callable takes
          ``t``, the constant is set to ``float(input_value(t=t))``.
        - If ``fenics_object`` is a ``fem.Function`` and an interpolation
          expression is set, the function is re-interpolated from it.

        ``fem.Constant`` and ufl-expression inputs are skipped, using the same
        check as the dependency properties. Previously they reached
        ``__code__`` and raised ``AttributeError``.

        Args:
            t: the time
        """
        if not self._is_user_function():
            return
        arguments = self._parameter_names(self.input_value)
        if isinstance(self.fenics_object, fem.Constant) and "t" in arguments:
            self.fenics_object.value = float(self.input_value(t=t))
        elif isinstance(self.fenics_object, fem.Function):
            if self.fenics_interpolation_expression is not None:
                self.fenics_object.interpolate(self.fenics_interpolation_expression)
# ``interpolation_points`` is read as an attribute for dolfinx versions newer
# than 0.9.0 and called as a method for 0.9.0 and older; the installed version
# is checked once at import time to pick the matching accessor.
dolfinx_version = dolfinx.__version__

if version.parse(dolfinx_version) <= version.parse("0.9.0"):

    def get_interpolation_points(element):
        """Returns the interpolation points of ``element`` (called as a method)."""
        return element.interpolation_points()

else:

    def get_interpolation_points(element):
        """Returns the interpolation points of ``element`` (read as an attribute)."""
        return element.interpolation_points
[docs]
def nmm_interpolate(
    f_out: fem.Function,
    f_in: fem.Function,
    cells: dolfinx.mesh.MeshTags | None = None,
    padding: float | None = 1e-11,
):
    """Non Matching Mesh Interpolate: interpolate ``f_in``, defined on one mesh,
    into ``f_out``, defined on a different (non-matching) mesh.

    Args:
        f_out: function to interpolate into
        f_in: function to interpolate from
        cells: the cells of ``f_out``'s mesh to interpolate on. They are passed
            to both ``fem.create_interpolation_data`` and
            ``f_out.interpolate_nonmatching``. When ``None``, every locally
            owned and ghost cell of ``f_out``'s mesh is used, as an int32 index
            array. NOTE(review): annotated as ``MeshTags``, but the default is a
            plain index array; confirm which form callers pass.
        padding: forwarded unchanged to ``fem.create_interpolation_data``

    Notes:
        https://fenicsproject.discourse.group/t/gjk-error-in-interpolation-between-non-matching-second-ordered-3d-meshes/16086/6
    """
    if cells is None:
        topology = f_out.function_space.mesh.topology
        cell_map = topology.index_map(topology.dim)
        n_cells = cell_map.size_local + cell_map.num_ghosts
        cells = np.arange(n_cells, dtype=np.int32)

    data = fem.create_interpolation_data(
        f_out.function_space, f_in.function_space, cells, padding=padding
    )
    f_out.interpolate_nonmatching(f_in, cells, interpolation_data=data)
[docs]
def is_it_time_to_export(
times: list | None, current_time: float, atol=0, rtol=1.0e-5
) -> bool:
"""Checks if the exported field should be written to a file or not based on the
current time and the times in `export.times`
After a successful match, the corresponding time is removed from the list to
prevent multiple exports for the same target time.
Args:
current_time: the current simulation time
atol: absolute tolerance for time comparison
rtol: relative tolerance for time comparison
times: the times at which the field should be exported, if None, returns True
Returns:
bool: True if the exported field should be written to a file, else False
"""
if times is None:
return True
for i, time in enumerate(times):
if np.isclose(time, current_time, atol=atol, rtol=rtol):
times.pop(i) # consume the time so it is not exported again
return True
return False
# Residual norm ||F(x_0)|| recorded by ``convergenceTest`` at iteration 0 and
# read by the SNES/KSP monitors; it stays 0 until ``convergenceTest`` runs.
_residual0 = 0
# Solution norm seen at the previous SNES iteration, updated by ``SnesMonitor``.
_prev_xnorm = 0


def convergenceTest(snes, it, norms):
    """Custom SNES convergence test.

    The checks are applied in this order:

    1. ``it > max_its`` -> ``DIVERGED_MAX_IT``
    2. ``||F|| < atol`` -> ``CONVERGED_FNORM_ABS``
    3. ``||F|| / ||F_0|| < rtol`` -> ``CONVERGED_FNORM_RELATIVE`` (skipped when
       ``||F_0|| == 0``)
    4. ``gnorm < stol`` (for ``it > 0``) -> ``CONVERGED_SNORM_RELATIVE``
    5. otherwise -> ``ITERATING``

    Args:
        snes: the SNES solver; only ``getTolerances()`` and ``ConvergedReason``
            are used
        it: the current nonlinear iteration number
        norms: ``(xnorm, gnorm, f)`` as unpacked below

    Returns:
        A member of ``snes.ConvergedReason``.
    """
    global _residual0
    _xnorm, gnorm, f = norms  # ||x_k||, ||x_k-x_k-1||, ||F(x_k)||
    rtol, atol, stol, max_its = snes.getTolerances()
    if it == 0:
        _residual0 = f
    if it > max_its:
        return snes.ConvergedReason.DIVERGED_MAX_IT
    elif f < atol:
        return snes.ConvergedReason.CONVERGED_FNORM_ABS
    # If the initial residual is exactly zero (and was not caught by the atol
    # test, e.g. atol == 0) the relative residual is undefined; skip this test
    # instead of raising ZeroDivisionError inside the PETSc callback.
    elif _residual0 != 0 and f / _residual0 < rtol:
        return snes.ConvergedReason.CONVERGED_FNORM_RELATIVE
    # NOTE(review): gnorm is compared to stol directly, while PETSc's default
    # test uses ||dx|| < stol * ||x||; confirm the absolute comparison is intended.
    elif gnorm < stol and it > 0:
        return snes.ConvergedReason.CONVERGED_SNORM_RELATIVE
    else:
        return snes.ConvergedReason.ITERATING
def SnesMonitor(snes, iter, rnorm):
    """SNES monitor that logs the nonlinear convergence history at INFO level.

    Rank 0 logs the following, each next to the tolerance it is compared with:

    - the residual norm ``rnorm``
    - the relative residual ``rnorm / _residual0``
    - a relative solution-norm change ``|‖x_k‖ - ‖x_{k-1}‖| / ‖x_k‖``

    Quantities that are undefined (first iteration, zero reference norm) are
    reported as ``inf``.

    Args:
        snes: the SNES solver being monitored
        iter: the current nonlinear iteration number
        rnorm: the current residual norm
    """
    global _prev_xnorm

    # The solution norm is computed on every rank: a vector norm is a collective
    # reduction in PETSc, so computing it on rank 0 only would leave that rank
    # waiting on (or mismatched with) collectives the other ranks never enter.
    xnorm = snes.getSolution().norm()

    if MPI.COMM_WORLD.rank == 0:
        rtol, atol, stol, _max_its = snes.getTolerances()

        if iter > 0 and xnorm != 0:
            stepsize_rel = abs(xnorm - _prev_xnorm) / xnorm
        else:
            stepsize_rel = float("inf")

        # _residual0 is set by convergenceTest at iteration 0 and stays 0 if
        # that test is not installed, so guard the division.
        if iter > 0 and _residual0 != 0:
            relative_residual = rnorm / _residual0
        else:
            relative_residual = float("inf")

        dolfinx.log.log(
            dolfinx.log.LogLevel.INFO,
            f"SNES {iter=} ; {rnorm=:.5e} ({atol=}) ; {relative_residual=:.5e} ({rtol=}) ; {stepsize_rel=:.5e} ({stol=:.5e})",  # noqa: E501
        )

    # Update previous xnorm (only read on rank 0, harmless elsewhere)
    _prev_xnorm = xnorm
def KSPMonitor(ksp, iter, rnorm):
    """KSP monitor that logs linear-solver iterations at DEBUG level.

    The first message (iteration number and the module-level ``_residual0``) is
    emitted on every rank. The second one (iteration number and ``rnorm``) is
    emitted on rank 0 only.
    NOTE(review): the first message is not rank-guarded; confirm that
    per-rank output is intended.

    Args:
        ksp: the KSP solver being monitored (unused)
        iter: the current linear iteration number
        rnorm: the current residual norm
    """
    debug_level = dolfinx.log.LogLevel.DEBUG
    dolfinx.log.log(debug_level, f"KSP {iter=}, {_residual0=:.5e}")
    if MPI.COMM_WORLD.rank != 0:
        return
    dolfinx.log.log(debug_level, f"KSP {iter=} {rnorm=:.5e}")