Source code for tdgl.parameter

import hashlib
import inspect
import operator
from numbers import Number
from typing import Callable, Optional, Union

import cloudpickle
import numpy as np


class _FakeArgSpec:
    """Minimal stand-in for :class:`inspect.FullArgSpec`.

    Exposes the same seven attributes so it can be handed to
    ``function_repr``. Every field defaults to ``None``; callers typically
    fill in only ``args`` and ``defaults``.
    """

    def __init__(
        self,
        args=None,
        varargs=None,
        varkw=None,
        defaults=None,
        kwonlyargs=None,
        kwonlydefaults=None,
        annotations=None,
    ):
        # Positional names, plus their (trailing) default values.
        self.args, self.defaults = args, defaults
        # Names of the *args / **kwargs catch-alls, if any.
        self.varargs, self.varkw = varargs, varkw
        # Keyword-only names and their defaults.
        self.kwonlyargs, self.kwonlydefaults = kwonlyargs, kwonlydefaults
        self.annotations = annotations


def function_repr(
    func: Callable,
    argspec: Optional[Union["_FakeArgSpec", inspect.FullArgSpec]] = None,
) -> str:
    """Returns a human-readable string representation for a function.

    The output mimics a call signature, e.g. ``f(x, y, a=1, *, b: 'int'=2)``.
    Annotations, when present, are rendered as ``name: 'TypeName'`` and are
    placed before the default value.

    Args:
        func: The function to describe. When ``argspec`` is given, only
            ``func.__name__`` is used.
        argspec: The argument specification to render: an
            :class:`inspect.FullArgSpec` or any object with the same
            attributes (such as ``_FakeArgSpec``), any of which may be
            ``None``. If omitted, it is computed with
            :func:`inspect.getfullargspec`.

    Returns:
        A string like ``name(arg, arg=default, *args, kwarg=default, **kwargs)``.
    """
    if argspec is None:
        argspec = inspect.getfullargspec(func)
    annotations = argspec.annotations or {}
    kwonlydefaults = argspec.kwonlydefaults or {}

    def _annotated(name) -> str:
        """Returns ``name`` followed by its annotation, if it has one."""
        name = str(name)
        if name not in annotations:
            return name
        annotation = annotations[name]
        # Classes expose ``__name__``. String annotations (e.g. under
        # ``from __future__ import annotations``) and some typing constructs
        # do not, so fall back to their textual form.
        label = getattr(annotation, "__name__", None)
        if not isinstance(label, str):
            label = annotation if isinstance(annotation, str) else str(annotation)
        return f"{name}: {label!r}"

    positional = list(argspec.args or [])
    defaults = list(argspec.defaults or [])
    # Defaults line up with the *last* len(defaults) positional arguments.
    first_default = len(positional) - len(defaults)

    parts = []
    for index, name in enumerate(positional):
        part = _annotated(name)
        if index >= first_default:
            part += f"={defaults[index - first_default]!r}"
        parts.append(part)

    if argspec.varargs:
        parts.append("*" + _annotated(argspec.varargs))

    if argspec.kwonlyargs:
        # A bare "*" marks the start of keyword-only arguments when there is
        # no *args to do so.
        if not argspec.varargs:
            parts.append("*")
        for name in argspec.kwonlyargs:
            part = _annotated(name)
            if name in kwonlydefaults:
                part += f"={kwonlydefaults[name]!r}"
            parts.append(part)

    if argspec.varkw:
        parts.append("**" + _annotated(argspec.varkw))

    return func.__name__ + "(" + ", ".join(parts) + ")"


class Parameter:
    """A callable object that computes a scalar or vector quantity
    as a function of position coordinates x, y (and optionally z and time t).

    Addition, subtraction, multiplication, division, and exponentiation
    between multiple Parameters and/or numbers is supported.
    The result of any of these operations is a ``CompositeParameter`` object.

    Args:
        func: A callable/function that actually calculates the parameter's value.
            The function must take x, y (and optionally z) as the first and only
            positional arguments, and all other arguments must be keyword arguments.
            Therefore func should have a signature like
            ``func(x, y, z, a=1, b=2, c=True)``, ``func(x, y, *, a, b, c)``,
            ``func(x, y, z, *, a, b, c)``, or ``func(x, y, z, *, a, b=None, c=3)``.
            For time-dependent Parameters, ``func`` must also take time ``t``
            as a keyword-only argument.
        time_dependent: Specifies that ``func`` is a function of time ``t``.
        kwargs: Keyword arguments for func. The special keyword ``use_cache``
            is consumed by the Parameter itself: when truthy, results of
            ``__call__`` are memoized, keyed on the keyword arguments and the
            (x, y, z, t) inputs.
    """

    __slots__ = ("func", "kwargs", "time_dependent", "_cache", "_use_cache")

    def __init__(self, func: Callable, time_dependent: bool = False, **kwargs):
        # ``use_cache`` configures the Parameter, not ``func``. ``None`` means
        # "unset", which lets a CompositeParameter enable caching later.
        self._use_cache = kwargs.pop("use_cache", None)
        argspec = inspect.getfullargspec(func)
        args = argspec.args
        num_args = 2
        if args[:num_args] != ["x", "y"]:
            raise ValueError(
                "The first function arguments must be x and y, "
                f"not {', '.join(args[:num_args])!r}."
            )
        if "z" in args:
            if args.index("z") != num_args:
                raise ValueError(
                    "If the function takes an argument z, "
                    "it must be the third argument (x, y, z)."
                )
            num_args = 3
        defaults = argspec.defaults or []
        if len(defaults) != len(args) - num_args:
            raise ValueError(
                "All arguments other than x, y, z must be keyword arguments."
            )
        self.time_dependent = time_dependent
        defaults_dict = dict(zip(args[num_args:], defaults))
        # Any provided kwarg that is not a positional-with-default must be a
        # keyword-only argument of func.
        kwonlyargs = set(kwargs) - set(argspec.args[num_args:])
        if not kwonlyargs.issubset(set(argspec.kwonlyargs or [])):
            raise ValueError(
                f"Provided keyword-only arguments ({kwonlyargs!r}) "
                f"do not match the function signature: {function_repr(func)}."
            )
        defaults_dict.update(argspec.kwonlydefaults or {})
        self.func = func
        self.kwargs = defaults_dict
        self.kwargs.update(kwargs)
        self._cache = {}
        if self.time_dependent and "t" not in argspec.kwonlyargs:
            raise ValueError(
                "A time-dependent Parameter must take time t as a keyword argument."
            )

    def _hash_args(self, x, y, z, t) -> str:
        """Returns a cache key for a call with the given inputs.

        The key combines a hash of the current keyword arguments (so mutating
        ``self.kwargs`` produces new keys), SHA-1 digests of x, y and z
        (including dtype and shape), and a hash of t.
        """

        def _freeze(value):
            """Recursively converts containers into hashable equivalents."""
            if isinstance(value, dict):
                return tuple((key, _freeze(val)) for key, val in value.items())
            if isinstance(value, np.ndarray):
                # tolist() also handles 0-d arrays, which cannot be iterated.
                value = value.tolist()
            if isinstance(value, (list, tuple)):
                return tuple(_freeze(item) for item in value)
            if isinstance(value, (set, frozenset)):
                return frozenset(_freeze(item) for item in value)
            # Scalars, strings, and other objects are used as-is. Strings are
            # deliberately not iterated: a 1-character string yields itself,
            # which would recurse forever.
            return value

        def _digest(array) -> str:
            """Returns a SHA-1 hex digest of an array's dtype, shape, and data."""
            if array is None:
                return "none"
            array = np.ascontiguousarray(array)
            # Equal raw bytes do not imply equal inputs (e.g. shape (4,) vs
            # (2, 2)), so the dtype and shape are part of the digest.
            digest = hashlib.sha1(f"{array.dtype.str}{array.shape}".encode())
            digest.update(array)
            return digest.hexdigest()

        return (
            hex(hash(_freeze(self.kwargs)))
            + _digest(x)
            + _digest(y)
            + _digest(z)
            + hex(hash(t))
        )

    def _evaluate(
        self,
        x: Union[Number, np.ndarray],
        y: Union[Number, np.ndarray],
        z: Optional[Union[Number, np.ndarray]] = None,
        t: Optional[float] = None,
    ) -> Union[Number, np.ndarray]:
        """Calls ``func`` with at-least-1d inputs and squeezes the result.

        A 0-d result is converted to a Python scalar via ``.item()``.
        """
        kwargs = self.kwargs.copy()
        if t is not None:
            kwargs["t"] = t
        x, y = np.atleast_1d(x, y)
        if z is not None:
            kwargs["z"] = np.atleast_1d(z)
        result = np.asarray(self.func(x, y, **kwargs)).squeeze()
        if result.ndim == 0:
            result = result.item()
        return result

    def __call__(
        self,
        x: Union[Number, np.ndarray],
        y: Union[Number, np.ndarray],
        z: Optional[Union[Number, np.ndarray]] = None,
        t: Optional[float] = None,
    ) -> Union[Number, np.ndarray]:
        """Evaluates the Parameter, using the cache if ``_use_cache`` is truthy."""
        if self._use_cache:
            cache_key = self._hash_args(x, y, z, t)
            if cache_key not in self._cache:
                self._cache[cache_key] = self._evaluate(x, y, z, t)
            return self._cache[cache_key]
        return self._evaluate(x, y, z, t)

    def _clear_cache(self) -> None:
        """Empties the memoization cache."""
        self._cache.clear()

    def _get_argspec(self) -> "_FakeArgSpec":
        """Returns a pseudo-argspec listing this Parameter's keyword arguments.

        ``time_dependent=True`` is prepended for time-dependent Parameters.
        """
        kwargs = list(self.kwargs)
        kwarg_values = list(self.kwargs.values())
        if self.time_dependent:
            kwargs.insert(0, "time_dependent")
            kwarg_values.insert(0, True)
        return _FakeArgSpec(args=kwargs, defaults=kwarg_values)

    def __repr__(self) -> str:
        func_repr = function_repr(self.func, argspec=self._get_argspec())
        return f"{self.__class__.__name__}<{func_repr}>"

    def __add__(self, other) -> "CompositeParameter":
        """self + other"""
        return CompositeParameter(self, other, operator.add)

    def __radd__(self, other) -> "CompositeParameter":
        """other + self"""
        return CompositeParameter(other, self, operator.add)

    def __sub__(self, other) -> "CompositeParameter":
        """self - other"""
        return CompositeParameter(self, other, operator.sub)

    def __rsub__(self, other) -> "CompositeParameter":
        """other - self"""
        return CompositeParameter(other, self, operator.sub)

    def __mul__(self, other) -> "CompositeParameter":
        """self * other"""
        return CompositeParameter(self, other, operator.mul)

    def __rmul__(self, other) -> "CompositeParameter":
        """other * self"""
        return CompositeParameter(other, self, operator.mul)

    def __truediv__(self, other) -> "CompositeParameter":
        """self / other"""
        return CompositeParameter(self, other, operator.truediv)

    def __rtruediv__(self, other) -> "CompositeParameter":
        """other / self"""
        return CompositeParameter(other, self, operator.truediv)

    def __pow__(self, other) -> "CompositeParameter":
        """self ** other"""
        return CompositeParameter(self, other, operator.pow)

    def __rpow__(self, other) -> "CompositeParameter":
        """other ** self"""
        return CompositeParameter(other, self, operator.pow)

    def __eq__(self, other) -> bool:
        """Two Parameters are equal if their functions have the same bytecode
        and their keyword arguments compare equal (arrays via ``np.allclose``).
        """
        if other is self:
            return True
        if not isinstance(other, Parameter):
            return False
        # Guard against instances whose ``func`` slot was never assigned.
        other_func = getattr(other, "func", None)
        if other_func is None:
            return False
        # Compare bytecode so equivalent functions defined separately (e.g.
        # the inner function created by each Constant) compare equal.
        # Callables without ``__code__`` fall back to ordinary equality.
        self_code = getattr(self.func, "__code__", None)
        other_code = getattr(other_func, "__code__", None)
        if self_code is None or other_code is None:
            if self.func != other_func:
                return False
        elif self_code != other_code:
            return False
        if set(self.kwargs) != set(other.kwargs):
            return False

        def array_safe_equals(a, b) -> bool:
            """Checks if a and b are equal, even if they are numpy arrays."""
            if a is b:
                return True
            if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
                a, b = np.asarray(a), np.asarray(b)
                if a.shape != b.shape:
                    return False
                try:
                    return bool(np.allclose(a, b))
                except TypeError:
                    # Non-numeric arrays (strings, objects): exact comparison.
                    return bool(np.array_equal(a, b))
            try:
                return bool(a == b)
            except (TypeError, ValueError):
                # TypeError: comparison unsupported. ValueError: the comparison
                # produced an array with an ambiguous truth value (e.g. lists
                # holding numpy arrays). Either way, not provably equal.
                return False

        return all(
            array_safe_equals(self.kwargs[key], other.kwargs[key])
            for key in self.kwargs
        )
class CompositeParameter(Parameter):
    """A callable object that behaves like a Parameter
    (i.e. it computes a scalar or vector quantity as a function of
    position coordinates x, y, z). A CompositeParameter object is created as
    a result of mathematical operations between Parameters, CompositeParameters,
    and/or numbers.

    Addition, subtraction, multiplication, division, and exponentiation
    between ``Parameters``, ``CompositeParameters`` and numbers are supported.
    The result of any of these operations is a new ``CompositeParameter`` object.

    Args:
        left: The object on the left-hand side of the operator.
        right: The object on the right-hand side of the operator.
        operator_: The operator acting on left and right
            (or its string representation).
    """

    VALID_OPERATORS = {
        operator.add: "+",
        operator.sub: "-",
        operator.mul: "*",
        operator.truediv: "/",
        operator.pow: "**",
    }

    # Attributes declared in ``Parameter.__slots__`` that this class uses and
    # that must survive pickling. Slot values are stored outside ``__dict__``,
    # so they are not captured by ``self.__dict__.copy()``.
    _PICKLED_SLOTS = ("time_dependent", "_use_cache")

    def __init__(
        self,
        left: Union[Number, Parameter, "CompositeParameter"],
        right: Union[Number, Parameter, "CompositeParameter"],
        operator_: Union[Callable, str],
    ):
        valid_types = (Number, Parameter, CompositeParameter)
        if not isinstance(left, valid_types):
            raise TypeError(
                f"Left must be a number, Parameter, or CompositeParameter, "
                f"not {type(left)!r}."
            )
        if not isinstance(right, valid_types):
            raise TypeError(
                f"Right must be a number, Parameter, or CompositeParameter, "
                f"not {type(right)!r}."
            )
        if isinstance(left, Number) and isinstance(right, Number):
            raise TypeError(
                "Either left or right must be a Parameter or CompositeParameter."
            )
        # Keep what the caller passed so the error message can report it.
        requested_operator = operator_
        if isinstance(operator_, str):
            operators = {v: k for k, v in self.VALID_OPERATORS.items()}
            operator_ = operators.get(operator_.strip(), None)
        if operator_ not in self.VALID_OPERATORS:
            raise ValueError(
                f"Unknown operator, {requested_operator!r}. "
                f"Valid operators are {list(self.VALID_OPERATORS)!r}."
            )
        self._cache = {}
        # __call__ below never consults this, but the slot must be assigned:
        # an enclosing time-dependent CompositeParameter reads it, and an
        # unassigned slot raises AttributeError.
        self._use_cache = None
        self.left = left
        self.right = right
        self.operator = operator_
        self.time_dependent = False
        for operand in (self.left, self.right):
            if isinstance(operand, Parameter) and operand.time_dependent:
                self.time_dependent = True
                # Enable caching on time-dependent operands unless the user
                # explicitly chose otherwise.
                if operand._use_cache is None:
                    operand._use_cache = True

    def _clear_cache(self) -> None:
        """Empties this object's cache and those of its Parameter operands."""
        self._cache.clear()
        for operand in (self.left, self.right):
            if isinstance(operand, Parameter):
                operand._clear_cache()

    def __call__(
        self,
        x: Union[Number, np.ndarray],
        y: Union[Number, np.ndarray],
        z: Union[Number, np.ndarray, None] = None,
        t: Optional[float] = None,
    ) -> Union[Number, np.ndarray]:
        """Evaluates both operands and combines them with ``self.operator``.

        ``t`` is forwarded only to time-dependent operands; numeric operands
        are used as-is.
        """
        kwargs = dict() if t is None else dict(t=t)
        values = []
        for operand in (self.left, self.right):
            if isinstance(operand, Parameter):
                if operand.time_dependent:
                    value = operand(x, y, z, **kwargs)
                else:
                    value = operand(x, y, z)
            else:
                value = operand
            values.append(value)
        return self.operator(*values)

    def _bare_repr(self) -> str:
        """Returns ``(left op right)`` without the class-name wrapper."""

        def _operand_repr(operand) -> str:
            if isinstance(operand, CompositeParameter):
                return operand._bare_repr()
            if isinstance(operand, Parameter):
                return function_repr(operand.func, operand._get_argspec())
            return str(operand)

        op_str = self.VALID_OPERATORS[self.operator]
        left_repr = _operand_repr(self.left)
        right_repr = _operand_repr(self.right)
        return f"({left_repr} {op_str} {right_repr})"

    def __eq__(self, other) -> bool:
        if other is self:
            return True
        if not isinstance(other, type(self)):
            return False
        return (
            self.left == other.left
            and self.right == other.right
            and self.operator is other.operator
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}<{self._bare_repr()}>"

    def __getstate__(self):
        # Operands may wrap lambdas/closures, so they are serialized with
        # cloudpickle. Slot attributes are added explicitly because they are
        # not part of __dict__. The cache is not pickled.
        state = self.__dict__.copy()
        state["left"] = cloudpickle.dumps(state["left"])
        state["right"] = cloudpickle.dumps(state["right"])
        for name in self._PICKLED_SLOTS:
            if hasattr(self, name):
                state[name] = getattr(self, name)
        return state

    def __setstate__(self, state):
        state = dict(state)
        state["left"] = cloudpickle.loads(state["left"])
        state["right"] = cloudpickle.loads(state["right"])
        slot_values = {
            name: state.pop(name) for name in self._PICKLED_SLOTS if name in state
        }
        self.__dict__.update(state)
        self._cache = {}
        self._use_cache = slot_values.get("_use_cache")
        if "time_dependent" in slot_values:
            self.time_dependent = slot_values["time_dependent"]
        else:
            # State produced without slot values: recompute from operands.
            self.time_dependent = any(
                isinstance(operand, Parameter) and operand.time_dependent
                for operand in (self.left, self.right)
            )


class Constant(Parameter):
    """A Parameter whose value doesn't depend on position or time.

    Args:
        value: The value multiplied by ``np.ones_like(x)`` at evaluation time.
        dimensions: 2 if the underlying function takes (x, y), 3 if it takes
            (x, y, z).
    """

    def __init__(self, value: Number, dimensions: int = 2):
        if dimensions not in (2, 3):
            raise ValueError(f"Dimensions must be 2 or 3, got {dimensions}.")
        if dimensions == 2:

            def constant(x, y, value=0):
                return value * np.ones_like(x)

        else:

            def constant(x, y, z, value=0):
                return value * np.ones_like(x)

        super().__init__(constant, value=value)