Source code for tdgl.visualization.common

import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Sequence, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


class Quantity(Enum):
    ORDER_PARAMETER = "Order parameter"
    PHASE = "Phase"
    SUPERCURRENT = "Supercurrent density"
    NORMAL_CURRENT = "Normal current density"
    VORTICITY = "Vorticity"
    SCALAR_POTENTIAL = "Scalar potential"
    APPLIED_VECTOR_POTENTIAL = "Applied vector potential"
    INDUCED_VECTOR_POTENTIAL = "Induced vector potential"
    EPSILON = "Epsilon"

    @classmethod
    def get_keys(cls) -> Sequence[str]:
        return list(item.name for item in Quantity)

    @classmethod
    def from_key(cls, key: str) -> "Quantity":
        return Quantity[key.upper()]


colormaps = {
    Quantity.ORDER_PARAMETER: "viridis",
    Quantity.PHASE: "twilight_shifted",
    Quantity.SUPERCURRENT: "inferno",
    Quantity.NORMAL_CURRENT: "inferno",
    Quantity.SCALAR_POTENTIAL: "magma",
    Quantity.APPLIED_VECTOR_POTENTIAL: "cividis",
    Quantity.INDUCED_VECTOR_POTENTIAL: "cividis",
    Quantity.EPSILON: "viridis",
    Quantity.VORTICITY: "coolwarm",
}


@dataclass
class PlotDefault:
    cmap: str
    clabel: str
    xlabel: str = "$x/\\xi$"
    ylabel: str = "$y/\\xi$"
    vmin: Union[float, None] = None
    vmax: Union[float, None] = None
    symmetric: bool = False


PLOT_DEFAULTS = {
    Quantity.ORDER_PARAMETER: PlotDefault(
        cmap="viridis", clabel="$|\\psi|$", vmin=0, vmax=1
    ),
    Quantity.PHASE: PlotDefault(
        cmap="twilight_shifted", clabel="$\\arg(\\psi)/\\pi$", vmin=-1, vmax=1
    ),
    Quantity.SUPERCURRENT: PlotDefault(cmap="inferno", clabel="$|\\vec{{J}}_s|/J_0$"),
    Quantity.NORMAL_CURRENT: PlotDefault(cmap="inferno", clabel="$|\\vec{{J}}_n|/J_0$"),
    Quantity.SCALAR_POTENTIAL: PlotDefault(cmap="magma", clabel="$\\mu/v_0$"),
    Quantity.APPLIED_VECTOR_POTENTIAL: PlotDefault(
        cmap="cividis", clabel="$a_\\mathrm{{applied}}/(\\xi B_{{c2}})$"
    ),
    Quantity.INDUCED_VECTOR_POTENTIAL: PlotDefault(
        cmap="cividis", clabel="$a_\\mathrm{{induced}}/(\\xi B_{{c2}})$"
    ),
    Quantity.EPSILON: PlotDefault(
        cmap="viridis", clabel="$\\epsilon$", vmin=-1, vmax=1
    ),
    Quantity.VORTICITY: PlotDefault(
        cmap="coolwarm",
        clabel="$(\\vec{{\\nabla}}\\times\\vec{{J}})\\cdot\\hat{{z}}$",
        symmetric=True,
    ),
}

DEFAULT_QUANTITIES = (
    "order_parameter",
    "phase",
    "supercurrent",
    "normal_current",
)


[docs]def auto_grid( num_plots: int, max_cols: int = 3, delaxes: bool = True, **kwargs, ) -> Tuple[plt.Figure, Sequence[plt.Axes]]: """Creates a grid of at least ``num_plots`` subplots with at most ``max_cols`` columns. Additional keyword arguments are passed to ``plt.subplots()``. Args: num_plots: Total number of plots that will be populated. max_cols: Maximum number of columns in the grid. delaxes: Whether to remove unused axes. Returns: matplotlib figure and axes """ ncols = min(max_cols, num_plots) nrows = int(np.ceil(num_plots / ncols)) fig, axes = plt.subplots(nrows, ncols, **kwargs) if not isinstance(axes, (list, np.ndarray)): axes = np.array([axes]) axes = np.asarray(axes) if delaxes: flat_axes = list(axes.flat) for ax in flat_axes[num_plots:]: fig.delaxes(ax) return fig, axes
[docs]@contextmanager def non_gui_backend(): """A contextmanager that temporarily uses a non-GUI backend for matplotlib.""" with warnings.catch_warnings(): ignore_messages = [ "Matplotlib is currently using agg", "FigureCanvasAgg is non-interactive", ] for msg in ignore_messages: warnings.filterwarnings("ignore", category=UserWarning, message=msg) try: old_backend = mpl.get_backend() mpl.use("Agg") yield finally: mpl.use(old_backend)
[docs]def auto_range_iqr( data_array: np.ndarray, cutoff_percentile: Union[float, Tuple[float, float]] = 1, ) -> Tuple[float, float]: """Get the min and max range of the provided array that excludes outliers following the IQR rule. This function computes the inter-quartile-range (IQR), defined by Q3-Q1, i.e. the percentiles for 75 and 25 percent of the distribution. The region without outliers is defined by [Q1-1.5*IQR, Q3+1.5*IQR]. Taken from `qcodes <https://github.com/QCoDeS/Qcodes/blob/ 6c8f7202f6b6fca4884bfc0f6e1e9a6564628d75/qcodes/utils/plotting.py#L28-L76>`_. Args: data_array: Array of arbitrary dimension containing the statistical data. cutoff_percentile: Percentile of data that may maximally be clipped on both sides of the distribution. If given a tuple (a, b) the percentile limits will be a and 100-b. Returns: vmin, vmax """ if isinstance(cutoff_percentile, tuple): bottom, top = cutoff_percentile else: bottom = cutoff_percentile top = 100 - bottom z = data_array.flatten() zmax = np.nanmax(z) zmin = np.nanmin(z) zrange = zmax - zmin pmin, q3, q1, pmax = np.nanpercentile(z, [bottom, 75, 25, top]) iqr = q3 - q1 # handle corner case of all data zero, such that IQR is zero # to counter numerical artifacts do not test IQR == 0, but IQR on its # natural scale (zrange) to be smaller than some very small number. # also test for zrange to be 0.0 to avoid division by 0. if zrange == 0.0 or iqr / zrange < 1e-8: vmin = zmin vmax = zmax else: vmin = max(q1 - 1.5 * iqr, zmin) vmax = min(q3 + 1.5 * iqr, zmax) vmin = min(vmin, pmin) vmax = max(vmax, pmax) return vmin, vmax