from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from scipy import interpolate
from ..visualization import auto_grid, auto_range_iqr
from .data import get_current_through_paths
from .solution import Solution
def setup_color_limits(
dict_of_arrays: Dict[str, np.ndarray],
vmin: Union[float, None] = None,
vmax: Union[float, None] = None,
share_color_scale: bool = False,
symmetric_color_scale: bool = False,
auto_range_cutoff: Optional[Union[float, Tuple[float, float]]] = None,
) -> Dict[str, Tuple[float, float]]:
"""Set up color limits (vmin, vmax) for a dictionary of numpy arrays.
Args:
dict_of_arrays: Dict of ``{name: array}`` for which to compute color limits.
vmin: If provided, this vmin will be used for all arrays. If vmin is not None,
then vmax must also not be None.
vmax: If provided, this vmax will be used for all arrays. If vmax is not None,
then vmin must also not be None.
share_color_scale: Whether to force all arrays to share the same color scale.
This option is ignored if vmin and vmax are provided.
symmetric_color_scale: Whether to use a symmetric color scale (vmin = -vmax).
This option is ignored if vmin and vmax are provided.
auto_range_cutoff: Cutoff percentile for :func:`tdgl.solution.plot_solution.auto_range_iqr`.
Returns:
A dict of ``{name: (vmin, vmax)}``
"""
if (vmin is not None and vmax is None) or (vmax is not None and vmin is None):
raise ValueError("If either vmin or max is provided, both must be provided.")
if vmin is not None:
return {name: (vmin, vmax) for name in dict_of_arrays}
if auto_range_cutoff is None:
clims = {
name: (np.nanmin(array), np.nanmax(array))
for name, array in dict_of_arrays.items()
}
else:
clims = {
name: auto_range_iqr(array, cutoff_percentile=auto_range_cutoff)
for name, array in dict_of_arrays.items()
}
if share_color_scale:
# All subplots share the same color scale
global_vmin = np.inf
global_vmax = -np.inf
for vmin, vmax in clims.values():
global_vmin = min(vmin, global_vmin)
global_vmax = max(vmax, global_vmax)
clims = {name: (global_vmin, global_vmax) for name in dict_of_arrays}
if symmetric_color_scale:
# Set vmin = -vmax
new_clims = {}
for name, (vmin, vmax) in clims.items():
new_vmax = max(vmax, -vmin)
new_clims[name] = (-new_vmax, new_vmax)
clims = new_clims
return clims
def cross_section(
dataset_coords: np.ndarray,
dataset_values: np.ndarray,
cross_section_coords: Union[np.ndarray, Sequence[np.ndarray]],
interp_method: Literal["linear", "cubic"] = "linear",
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Takes a cross-section of the specified dataset values along
a path given by the given dataset coordinates.
Args:
dataset_coords: A shape (n, 2) array of (x, y) coordinates for the dataset.
dataset_values: A shape (n, ) array of dataset values of which
to take a cross-section.
cross_section_coords: A shape (m, 2) array of (x, y) coordinates specifying
the cross-section path (or a list of such arrays for multiple
cross sections).
interp_method: The interpolation method to use: "linear" or "cubic".
Returns:
A list of coordinate arrays, a list of curvilinear coordinate (path) arrays,
and a list of cross section values.
"""
valid_methods = ("linear", "cubic")
if interp_method not in valid_methods:
raise ValueError(
f"Interpolation method must be one of {valid_methods} "
f"(got {interp_method})."
)
interpolator = {
"linear": interpolate.LinearNDInterpolator,
"cubic": interpolate.CloughTocher2DInterpolator,
}[interp_method]
if not (isinstance(cross_section_coords, Sequence)):
cross_section_coords = [cross_section_coords]
cross_section_coords = [np.asarray(c) for c in cross_section_coords]
for i, arr in enumerate(cross_section_coords):
if arr.ndim != 2 or arr.shape[1] != 2:
raise ValueError(
f"Invalid shape for coordinate array {i}: {arr.shape}. "
f"Coordinate arrays must have shape (n, 2)."
)
# Calculcate curvilinear cross section coordinates
paths = []
for c in cross_section_coords:
path = np.cumsum(np.sqrt(np.sum(np.diff(c, axis=0) ** 2, axis=1)))
paths.append(np.concatenate([[0], path], axis=0))
# Calculate cross sections.
cross_sections = []
mask = np.isfinite(dataset_values)
z_interp = interpolator(dataset_coords[mask], dataset_values[mask])
for c in cross_section_coords:
cross_sections.append(z_interp(c[:, 0], c[:, 1]))
return cross_section_coords, paths, cross_sections
[docs]def plot_currents(
solution: Solution,
ax: Union[plt.Axes, None] = None,
dataset: Union[str, None] = None,
units: Union[str, None] = None,
cmap: str = "inferno",
colorbar: bool = True,
auto_range_cutoff: Optional[Union[float, Tuple[float, float]]] = None,
symmetric_color_scale: bool = False,
vmin: Union[float, None] = None,
vmax: Union[float, None] = None,
streamplot: bool = True,
min_stream_amp: float = 0.025,
cross_section_coords: Union[np.ndarray, Sequence[np.ndarray], None] = None,
**kwargs,
) -> Tuple[plt.Figure, Sequence[plt.Axes]]:
"""Plots the sheet current density for a given :class:`tdgl.Solution`.
Additional keyword arguments are passed to ``plt.subplots()``.
.. seealso:
:meth:`tdgl.Solution.plot_currents`
Args:
solution: The Solution from which to extract sheet current.
dataset: The dataset to plot, either ``"supercurrent"`` or
``"normal_current"``. ``None`` indicates the total current density.
ax: Matplotlib axes on which to plot.
units: Units in which to plot the current density. Defaults to
``solution.current_units / solution.device.length_units``.
cmap: Name of the matplotlib colormap to use.
colorbar: Whether to add a colorbar to each subplot.
auto_range_cutoff: Cutoff percentile for :func:`tdgl.solution.plot_solution.auto_range_iqr`.
symmetric_color_scale: Whether to use a symmetric color scale (vmin = -vmax).
vmin: Color scale minimum to use for all layers
vmax: Color scale maximum to use for all layers
streamplot: Whether to overlay current streamlines on the plot.
min_stream_amp: Streamlines will not be drawn anywhere the
current density is less than min_stream_amp * max(current_density).
This avoids streamlines being drawn where there is no current flowing.
cross_section_coords: Shape (m, 2) array of (x, y) coordinates for a
cross-section (or a list of such arrays).
Returns:
matplotlib figure and axes
"""
device = solution.device
length_units = device.ureg(device.length_units).units
old_units = device.ureg(f"{solution.current_units} / {device.length_units}").units
units = units or old_units
if isinstance(units, str):
units = device.ureg(units).units
if dataset is None:
J = solution.current_density
elif dataset in ["supercurrent"]:
J = solution.supercurrent_density
elif dataset in ["normal_current"]:
J = solution.normal_current_density
else:
raise ValueError(f"Unexpected dataset: {dataset}.")
if ax is None:
fig, ax = plt.subplots(**kwargs)
else:
fig = ax.get_figure()
J = J.to(units).magnitude
Jx = J[:, 0]
Jy = J[:, 1]
Jnorm = np.sqrt(Jx**2 + Jy**2)
x = solution.device.points[:, 0]
y = solution.device.points[:, 1]
t = solution.device.triangles
clabel = "$|\\,\\vec{K}\\,|$" + f" [${units:~L}$]"
clim = setup_color_limits(
{"J": Jnorm},
vmin=vmin,
vmax=vmax,
symmetric_color_scale=symmetric_color_scale,
auto_range_cutoff=auto_range_cutoff,
)["J"]
vmin, vmax = clim
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
im = ax.tripcolor(x, y, t, Jnorm, shading="gouraud", cmap=cmap, norm=norm)
ax.set_title(f"{clabel.split('[')[0].strip()}")
ax.set_aspect("equal")
ax.set_xlabel(f"$x$ [${length_units:~L}$]")
ax.set_ylabel(f"$y$ [${length_units:~L}$]")
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())
if cross_section_coords is not None:
ax_divider = make_axes_locatable(ax)
cax = ax_divider.append_axes("bottom", size="40%", pad="30%")
coords, paths, cross_sections = cross_section(
np.array([x, y]).T,
Jnorm,
cross_section_coords,
)
for i, (coord, path, cross) in enumerate(zip(coords, paths, cross_sections)):
color = f"C{i % 10}"
cross[~device.contains_points(coord)] = np.nan
ax.plot(*coord.T, "--", color=color, lw=2)
ax.plot(*coord[0], "o", color=color)
ax.plot(*coord[-1], "s", color=color)
cax.plot(path, cross, color=color, lw=2)
cax.plot(path[0], cross[0], "o", color=color)
cax.plot(path[-1], cross[-1], "s", color=color)
cax.grid(True)
cax.set_xlabel(f"Distance along cut [${length_units:~L}$]")
cax.set_ylabel(clabel)
if streamplot:
xgrid, ygrid, Jgrid = solution.grid_current_density(
dataset=dataset,
grid_shape=200,
method="cubic",
units=str(units),
with_units=False,
)
Jx, Jy = Jgrid
J = np.sqrt(Jx**2 + Jy**2)
xy = np.array([xgrid.ravel(), ygrid.ravel()]).T
ix = np.where(~solution.device.contains_points(xy))[0]
ix = np.unravel_index(ix, J.shape)
Jx[ix] = np.nan
Jy[ix] = np.nan
if min_stream_amp is not None:
cutoff = np.nanmax(J) * min_stream_amp
Jx[J < cutoff] = np.nan
Jy[J < cutoff] = np.nan
ax.streamplot(xgrid, ygrid, Jx, Jy, color="w", density=1, linewidth=0.75)
if colorbar:
cbar = fig.colorbar(im, ax=ax, orientation="vertical")
cbar.set_label(clabel)
return fig, ax
[docs]def plot_field_at_positions(
solution: Solution,
positions: np.ndarray,
zs: Optional[Union[float, np.ndarray]] = None,
vector: bool = False,
units: Union[str, None] = None,
grid_shape: Union[int, Tuple[int, int]] = (200, 200),
grid_method: str = "cubic",
cmap: str = "cividis",
colorbar: bool = True,
auto_range_cutoff: Optional[Union[float, Tuple[float, float]]] = None,
share_color_scale: bool = False,
symmetric_color_scale: bool = False,
vmin: Union[float, None] = None,
vmax: Union[float, None] = None,
cross_section_coords: Optional[Union[float, List[float]]] = None,
**kwargs,
) -> Tuple[plt.Figure, Sequence[plt.Axes]]:
"""Plots the Biot-Savart field (either all three components or just the
z component) at a given set of positions (x, y, z) outside of the device.
.. note::
This function plots only the field due to currents flowing in the device.
It does not include the applied field.
.. seealso:
:meth:`tdgl.Solution.plot_field_at_positions`
Additional keyword arguments are passed to ``plt.subplots()``. This function first
evaluates the field at ``positions``, then interpolates the resulting fields to a
rectangular grid for plotting.
Args:
solution: The Solution from which to extract fields.
positions: Shape (m, 2) array of (x, y) coordinates, or (m, 3) array of (x, y, z)
coordinates at which to calculate the magnetic field.
zs: z coordinates at which to calculate the field. If positions has shape (m, 3), then
this argument is not allowed. If zs is a scalar, then the fields are calculated in
a plane parallel to the x-y plane. If zs is an array, then it must be same length
as positions.
vector: Whether to plot the full vector magnetic field or just the z component.
units: Units in which to plot the fields. Defaults to ``solution.field_units``.
grid_shape: Shape of the desired rectangular grid. If a single integer ``n``
is given, then the grid will be square, shape ``(n, n)``.
grid_method: Interpolation method to use (see :func:`scipy.interpolate.griddata`).
max_cols: Maximum number of columns in the grid of subplots.
cmap: Name of the matplotlib colormap to use.
colorbar: Whether to add a colorbar to each subplot.
auto_range_cutoff: Cutoff percentile for :func:`tdgl.solution.plot_solution.auto_range_iqr`.
share_color_scale: Whether to force all layers to use the same color scale.
symmetric_color_scale: Whether to use a symmetric color scale (vmin = -vmax).
vmin: Color scale minimum to use for all layers
vmax: Color scale maximum to use for all layers
cross_section_coords: Shape (m, 2) array of (x, y) coordinates for a
cross-section (or a list of such arrays).
Returns:
matplotlib figure and axes
"""
device = solution.device
# Length units from the Device
length_units = device.ureg(device.length_units).units
# The units the fields are currently in
old_units = device.ureg(solution.field_units).units
# The units we want to convert to
if units is None:
units = old_units
if isinstance(units, str):
units = device.ureg(units).units
fields = solution.field_at_position(
positions,
zs=zs,
vector=vector,
units=units,
with_units=False,
)
if fields.ndim == 1:
fields = fields[:, np.newaxis]
if vector:
num_subplots = 3
else:
num_subplots = 1
fig, axes = auto_grid(num_subplots, **kwargs)
if not isinstance(axes, (list, np.ndarray)):
axes = [axes]
x, y, *_ = positions.T
xs = np.linspace(x.min(), x.max(), grid_shape[1])
ys = np.linspace(y.min(), y.max(), grid_shape[0])
xgrid, ygrid = np.meshgrid(xs, ys)
# Shape grid_shape or (grid_shape + (3, ))
fields = interpolate.griddata(
positions[:, :2],
fields,
(xgrid, ygrid),
method=grid_method,
)
clabels = [f"{label} [${units:~L}$]" for label in ["$H_x$ ", "$H_y$ ", "$H_z$ "]]
if "[mass]" in units.dimensionality:
# We want flux density, B = mu0 * H
clabels = ["$\\mu_0$" + clabel for clabel in clabels]
if not vector:
clabels = clabels[-1:]
fields_dict = {label: fields[:, :, i] for i, label in enumerate(clabels)}
clim_dict = setup_color_limits(
fields_dict,
vmin=vmin,
vmax=vmax,
share_color_scale=share_color_scale,
symmetric_color_scale=symmetric_color_scale,
auto_range_cutoff=auto_range_cutoff,
)
for ax, label in zip(fig.axes, clabels):
field = fields_dict[label]
layer_vmin, layer_vmax = clim_dict[label]
norm = mpl.colors.Normalize(vmin=layer_vmin, vmax=layer_vmax)
im = ax.pcolormesh(xgrid, ygrid, field, shading="auto", cmap=cmap, norm=norm)
ax.set_title(f"{label.split('[')[0].strip()}")
ax.set_aspect("equal")
ax.set_xlabel(f"$x$ [${length_units:~L}$]")
ax.set_ylabel(f"$y$ [${length_units:~L}$]")
ax.set_xlim(xgrid.min(), xgrid.max())
ax.set_ylim(ygrid.min(), ygrid.max())
if cross_section_coords is not None:
ax_divider = make_axes_locatable(ax)
cax = ax_divider.append_axes("bottom", size="40%", pad="30%")
coords, paths, cross_sections = cross_section(
np.array([xgrid.ravel(), ygrid.ravel()]).T,
field.ravel(),
cross_section_coords=cross_section_coords,
)
for i, (coord, path, cross) in enumerate(
zip(coords, paths, cross_sections)
):
color = f"C{i % 10}"
ax.plot(*coord.T, "--", color=color, lw=2)
ax.plot(*coord[0], "o", color=color)
ax.plot(*coord[-1], "s", color=color)
cax.plot(path, cross, color=color, lw=2)
cax.plot(path[0], cross[0], "o", color=color)
cax.plot(path[-1], cross[-1], "s", color=color)
cax.grid(True)
cax.set_xlabel(f"Distance along cut [${length_units:~L}$]")
cax.set_ylabel(label)
if colorbar:
cbar = fig.colorbar(im, ax=ax, orientation="vertical")
cbar.set_label(label)
return fig, axes
[docs]def plot_order_parameter(
solution: Solution,
squared: bool = False,
mag_cmap: str = "viridis",
phase_cmap: str = "twilight_shifted",
shading: str = "gouraud",
**kwargs,
) -> Tuple[plt.Figure, Sequence[plt.Axes]]:
"""Plots the magnitude (or the magnitude squared) and
phase of the complex order parameter, :math:`\\psi=|\\psi|e^{i\\theta}`.
.. seealso:
:meth:`tdgl.Solution.plot_order_parameter`
Args:
solution: The solution for which to plot the order parameter.
squared: Whether to plot the magnitude squared, :math:`|\\psi|^2`.
mag_cmap: Name of the colormap to use for the magnitude.
phase_cmap: Name of the colormap to use for the phase.
shading: May be ``"flat"`` or ``"gouraud"``. The latter does some interpolation.
Returns:
matplotlib Figure and an array of two Axes objects.
"""
kwargs.setdefault("figsize", (8, 3))
kwargs.setdefault("constrained_layout", True)
device = solution.device
psi = solution.tdgl_data.psi
mag = np.abs(psi)
psi_label = "$|\\psi|$"
if squared:
mag = mag**2
psi_label = "$|\\psi|^2$"
phase = np.angle(psi) / np.pi
points = device.points
triangles = device.triangles
fig, axes = plt.subplots(1, 2, **kwargs)
im = axes[0].tripcolor(
points[:, 0],
points[:, 1],
mag,
triangles=triangles,
vmin=0,
vmax=1,
cmap=mag_cmap,
shading=shading,
)
cbar = fig.colorbar(im, ax=axes[0])
cbar.set_label(psi_label)
im = axes[1].tripcolor(
points[:, 0],
points[:, 1],
phase,
triangles=triangles,
vmin=-1,
vmax=1,
cmap=phase_cmap,
shading=shading,
)
cbar = fig.colorbar(im, ax=axes[1])
cbar.set_label("$\\theta / \\pi$")
length_units = device.ureg(device.length_units).units
for ax in axes:
ax.set_aspect("equal")
ax.set_xlabel(f"$x$ [${length_units:~L}$]")
ax.set_ylabel(f"$y$ [${length_units:~L}$]")
return fig, axes
[docs]def plot_vorticity(
solution: Solution,
ax: Union[plt.Axes, None] = None,
cmap: str = "coolwarm",
units: Union[str, None] = None,
auto_range_cutoff: Optional[Union[float, Tuple[float, float]]] = None,
symmetric_color_scale: bool = True,
vmin: Union[float, None] = None,
vmax: Union[float, None] = None,
shading: str = "gouraud",
**kwargs,
):
"""Plots the vorticity in the film:
:math:`\\mathbf{\\omega}=\\mathbf{\\nabla}\\times\\mathbf{K}`.
.. seealso:
:meth:`tdgl.Solution.plot_vorticity`
Args:
solution: The solution for which to plot the vorticity.
ax: Matplotlib axes on which to plot.
cmap: Name of the matplotlib colormap to use.
units: The units in which to plot the vorticity. Must have dimensions of
[current] / [length]^2.
auto_range_cutoff: Cutoff percentile for :func:`tdgl.solution.plot_solution.auto_range_iqr`.
symmetric_color_scale: Whether to use a symmetric color scale (vmin = -vmax).
vmin: Color scale minimum.
vmax: Color scale maximum.
shading: May be ``"flat"`` or ``"gouraud"``. The latter does some interpolation.
Returns:
matplotlib Figure and and Axes.
"""
if ax is None:
kwargs.setdefault("constrained_layout", True)
fig, ax = plt.subplots(**kwargs)
else:
fig = ax.get_figure()
ax.set_aspect("equal")
device = solution.device
points = device.points
triangles = device.triangles
length_units = device.ureg(device.length_units).units
if units is None:
units = solution.vorticity.units
else:
units = device.ureg(units)
v = solution.vorticity.to(units).m
clim = setup_color_limits(
{"v": v},
vmin=vmin,
vmax=vmax,
symmetric_color_scale=symmetric_color_scale,
auto_range_cutoff=auto_range_cutoff,
)["v"]
vmin, vmax = clim
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
x, y = points[:, 0], points[:, 1]
im = ax.tripcolor(
x,
y,
v,
triangles=triangles,
cmap=cmap,
norm=norm,
shading=shading,
)
cbar = fig.colorbar(im, ax=ax)
ax.set_title("$\\vec{\\omega}=\\vec{\\nabla}\\times\\vec{K}$")
ax.set_aspect("equal")
ax.set_xlabel(f"$x$ [${length_units:~L}$]")
ax.set_ylabel(f"$y$ [${length_units:~L}$]")
cbar.set_label(f"$\\vec{{\\omega}}\\cdot\\hat{{z}}$ [${units:~L}$]")
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())
return fig, ax
[docs]def plot_scalar_potential(
solution: Solution,
ax: Union[plt.Axes, None] = None,
cmap: str = "magma",
auto_range_cutoff: Optional[Union[float, Tuple[float, float]]] = None,
vmin: Union[float, None] = None,
vmax: Union[float, None] = None,
shading: str = "gouraud",
**kwargs,
):
"""Plots the scalar potential :math:`\\mu(\\mathbf{r})` in the film.
.. seealso:
:meth:`tdgl.Solution.plot_scalar_potential`
Args:
solution: The solution for which to plot the scalar potential.
ax: Matplotlib axes on which to plot.
cmap: Name of the matplotlib colormap to use.
auto_range_cutoff: Cutoff percentile for :func:`tdgl.solution.plot_solution.auto_range_iqr`.
vmin: Color scale minimum.
vmax: Color scale maximum.
shading: May be ``"flat"`` or ``"gouraud"``. The latter does some interpolation.
Returns:
matplotlib Figure and and Axes.
"""
if ax is None:
kwargs.setdefault("constrained_layout", True)
fig, ax = plt.subplots(**kwargs)
else:
fig = ax.get_figure()
ax.set_aspect("equal")
device = solution.device
points = device.points
triangles = device.triangles
length_units = device.ureg(device.length_units).units
mu = solution.tdgl_data.mu
mu = mu - np.nanmin(mu)
clim = setup_color_limits(
{"mu": mu},
vmin=vmin,
vmax=vmax,
auto_range_cutoff=auto_range_cutoff,
)["mu"]
vmin, vmax = clim
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
x, y = points[:, 0], points[:, 1]
im = ax.tripcolor(
x,
y,
mu,
triangles=triangles,
cmap=cmap,
norm=norm,
shading=shading,
)
cbar = fig.colorbar(im, ax=ax)
ax.set_title("$\\mu/v_0$")
ax.set_aspect("equal")
ax.set_xlabel(f"$x$ [${length_units:~L}$]")
ax.set_ylabel(f"$y$ [${length_units:~L}$]")
cbar.set_label("$\\mu/v_0$")
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())
return fig, ax
[docs]def plot_current_through_paths(
solution_path: str,
paths: Union[np.ndarray, List[np.ndarray]],
dataset: Optional[str] = None,
interp_method: Literal["linear", "cubic"] = "linear",
units: Optional[str] = None,
progress_bar: bool = True,
grid: bool = True,
labels: bool = True,
legend: bool = True,
**figure_kwargs,
) -> Tuple[
Tuple[plt.Figure, plt.Axes], Tuple[np.ndarray, Union[np.ndarray, List[np.ndarray]]]
]:
"""Plots the current through one or more paths for each saved time step.
Args:
solution_path: Path to the solution HDF5 file.
paths: A list of ``(n, 2)`` arrays of ``(x, y)`` coordinates defining
the paths. A single ``(n, 2)`` array is also allowed.
dataset: ``None``, ``"supercurrent"``, or ``"normal_current"``.
``None`` indicates the total current.
interp_method: Interpolation method: either "linear" or "cubic".
units: The current units to return.
with_units: Whether to return a :class:`pint.Quantity` with units attached.
progress_bar: Whether to display a progress bar.
grid: Whether to add grid lines to the plot.
labels: Whether to include axis labels.
legend: Whether to include a legend.
Returns:
``(fig, ax), (times, currents)``, where ``currents`` is a list of arrays of
the time-dependent current through each path. If ``paths`` is given as a
single array, ``currents`` will be returned as a single array.
"""
times, currents = get_current_through_paths(
solution_path,
paths,
dataset=dataset,
interp_method=interp_method,
units=units,
with_units=True,
progress_bar=progress_bar,
)
if isinstance(paths, np.ndarray):
currents = [currents]
current_units = currents[0].units
label = {
"supercurrent": "Supercurrent",
"normal_current": "Normal current",
None: "Total current",
}[dataset]
fig, ax = plt.subplots(**figure_kwargs)
ax.grid(grid)
for i, current in enumerate(currents):
ax.plot(times, current.magnitude, label=f"Path {i}")
if labels:
ax.set_ylabel(f"{label} [${current_units:~L}$]")
ax.set_xlabel("Time, $t$ [$\\tau_0$]")
if legend:
ax.legend(loc=0)
return (fig, ax), (times, currents)
def _patch_docstring(func):
other_func = getattr(Solution, func.__name__)
other_func.__doc__ = (
other_func.__doc__
+ "\n\n"
+ "\n".join(
[line for line in func.__doc__.split("\n ") if "solution:" not in line]
)
)
annotations = func.__annotations__.copy()
_ = annotations.pop("solution", None)
other_func.__annotations__.update(annotations)
for func in (
plot_currents,
plot_field_at_positions,
plot_order_parameter,
plot_scalar_potential,
plot_vorticity,
):
_patch_docstring(func)