# Source code for indica.operators.spline_fit

"""Fit 1-D splines to data."""

from itertools import accumulate
from itertools import chain
from itertools import repeat
from itertools import tee
from typing import cast
from typing import Dict
from typing import Hashable
from typing import List
from typing import Tuple
from typing import Union
import warnings

import numpy as np
from scipy.interpolate import CubicSpline
from scipy.optimize import least_squares
from scipy.sparse import lil_matrix
from xarray import DataArray

from .abstractoperator import EllipsisType
from .abstractoperator import Operator
from .. import session
from ..converters import bin_to_time_labels
from ..converters import CoordinateTransform
from ..converters import FluxSurfaceCoordinates
from ..datatypes import DataType
from ..numpy_typing import ArrayLike
from ..utilities import broadcast_spline
from ..utilities import coord_array

# Boundary condition for one end of a spline: either a string or an
# ``(int, ArrayLike)`` tuple.
SingleBoundaryType = Union[str, Tuple[int, ArrayLike]]
# Boundary condition for a whole spline: either a single string or a pair of
# per-end conditions. Used to annotate the ``bounds`` argument of ``Spline``.
BoundaryType = Union[str, Tuple[SingleBoundaryType, SingleBoundaryType]]

class Spline:
    """Callable wrapper that lets a :class:`scipy.interpolate.CubicSpline`
    be evaluated with DataArrays.

    Interpolation is carried out along a single dimension, although the
    points it is evaluated on may form a multidimensional grid.

    Parameters
    ----------
    values : DataArray
        The values to interpolate.
    dim : Hashable
        The axis along which to interpolate.
    coord_transform : CoordinateTransform
        The transform describing the coordinate system used by `values`.
    bounds : BoundaryType
        The boundary condition to pass to
        :class:`scipy.interpolate.CubicSpline`.

    """

    def __init__(
        self,
        values: DataArray,
        dim: Hashable,
        coord_transform: CoordinateTransform,
        bounds: BoundaryType = "clamped",
    ):
        self.dim = dim
        # Every dimension other than the interpolation axis is carried along.
        self.spline_dims = tuple(name for name in values.dims if name != dim)

        # Keep plain arrays for all coordinates except the interpolation one.
        other_coords: Dict[Hashable, np.ndarray] = {}
        for name, coord in values.coords.items():
            if name != self.dim:
                other_coords[name] = np.asarray(coord)
        self.spline_coords = other_coords

        # Put the interpolation axis first so it lines up with ``axis=0``.
        reordered = values.transpose(self.dim, *self.spline_dims)
        self.spline = CubicSpline(
            values.coords[dim],
            reordered,
            axis=0,
            bc_type=bounds,
            extrapolate=False,
        )
        self.transform = coord_transform

    def __call__(
        self,
        coord_system: CoordinateTransform,
        x1: DataArray,
        x2: DataArray,
        t: DataArray,
    ) -> DataArray:
        """Evaluate the spline at the locations given by the coordinates.

        Several coordinates are accepted, but interpolation only happens
        along the `dim` chosen when the object was created.

        Parameters
        ----------
        coord_system
            The transform describing the system used by the provided
            coordinates
        x1
            The first spatial coordinate
        x2
            The second spatial coordinate
        t
            The time coordinate

        Returns
        -------
        :
            The interpolated values, with the ``"transform"`` attribute set
            to `coord_system`.

        """
        converted = coord_system.convert_to(self.transform, x1, x2, t)
        own_x1, own_x2 = cast(Tuple[DataArray, DataArray], converted)
        # Pick whichever converted coordinate matches the interpolation axis.
        if self.dim == self.transform.x1_name:
            coord = own_x1
        else:
            coord = own_x2
        result = broadcast_spline(
            self.spline, self.spline_dims, self.spline_coords, coord
        )
        result.attrs["transform"] = coord_system
        return result
class SplineFit(Operator):
    """Fit a 1-D spline to data.

    The spline will be given on poloidal flux surface coordinates, as
    specified by the user. It can derive a single spline fit for multiple
    DataArray arguments simultaneously.

    Parameters
    ----------
    knots : ArrayLike
        A 1-D array containing the location of spline knots to use when
        fitting the data.
    lower_bound : ArrayLike
        The lower bounds to use for values at each knot. May be either a
        scalar or an array of the same shape as ``knots``.
    upper_bound : ArrayLike
        The upper bounds to use for values at each knot. May be either a
        scalar or an array of the same shape as ``knots``.
    sess : session.Session
        An object representing the session being run. Contains information
        such as provenance data.

    Raises
    ------
    ValueError
        If ``lower_bound`` or ``upper_bound`` is a NumPy array whose size
        differs from the number of knots.

    """

    ARGUMENT_TYPES: List[Union[DataType, EllipsisType]] = [
        ("norm_flux_pol", "plasma"),
        ("time", "plasma"),
        ("temperature", "electrons"),
        ...,
    ]

    def __init__(
        self,
        knots: ArrayLike = [0.0, 0.3, 0.6, 0.85, 0.95, 1.05],
        lower_bound: ArrayLike = -np.inf,
        upper_bound: ArrayLike = np.inf,
        sess: session.Session = session.global_session,
    ):
        # The default ``knots`` list is only read in this method (passed to
        # ``coord_array`` and ``str``), never modified here.
        self.knots = coord_array(knots, "rho_poloidal")
        # NOTE(review): array bounds are validated against ``knots.size``, but
        # ``__call__`` hands ``least_squares`` ``(n_knots - 1) * nt`` free
        # parameters; confirm array-valued bounds actually work as intended.
        self.lower_bound = lower_bound
        if isinstance(lower_bound, np.ndarray) and lower_bound.size != self.knots.size:
            raise ValueError(
                "lower_bound must be either a scalar or array of same size as knots"
            )
        self.upper_bound = upper_bound
        if isinstance(upper_bound, np.ndarray) and upper_bound.size != self.knots.size:
            # Previously this message wrongly named ``lower_bound``.
            raise ValueError(
                "upper_bound must be either a scalar or array of same size as knots"
            )
        # Populated on every residual evaluation inside ``__call__``.
        self.spline: Spline
        self.spline_vals: DataArray
        super().__init__(
            sess,
            knots=str(knots),
            lower_bound=str(lower_bound),
            upper_bound=str(upper_bound),
        )

    def return_types(self, *args: DataType) -> Tuple[DataType, ...]:
        """Indicates the datatypes of the results when calling the operator
        with arguments of the given types. It is assumed that the
        argument types are valid.

        Parameters
        ----------
        args
            The datatypes of the parameters which the operator is to be called
            with.

        Returns
        -------
        :
            The datatype of each result that will be returned if the operator
            is called with these arguments.

        """
        # One result per argument, all reported with the type of the last
        # argument.
        input_type = args[-1]
        return (input_type,) * len(args)

    def __call__(  # type: ignore[override]
        self,
        rho: DataArray,
        times: DataArray,
        *data: DataArray,
    ) -> Tuple[DataArray, ...]:
        """Fit a spline to the provided data.

        Parameters
        ----------
        rho
            The poloidal flux values on which to return the result.
        times
            The times at which to bin the data and return the result.
        data
            The data to fit the spline to.

        Returns
        -------
        :
            The results of the fit on the given \\rho and time values. It
            contains the attribute `splines` which can be used to
            interpolate results onto arbitrary coordinates. This is followed
            by the fitted knot values and then the time-binned version of
            each item in ``data``.

        Raises
        ------
        RuntimeError
            If ``least_squares`` reports improper input (status ``-1``).

        Warns
        -----
        RuntimeWarning
            If the fit stops because the maximum number of function
            evaluations was reached (status ``0``).

        """
        self.validate_arguments(rho, times, *data)
        n_knots = len(self.knots)
        flux_surfaces = FluxSurfaceCoordinates("poloidal")
        flux_surfaces.set_equilibrium(data[0].indica.equilibrium)
        # NOTE(review): the first argument was lost from the extracted source
        # (``bin_to_time_labels(, d)``); ``times.data`` is a reconstruction,
        # confirm against the upstream file.
        binned_data = [bin_to_time_labels(times.data, d) for d in data]
        # For each dataset, every dimension except its first spatial one.
        droppable_dims = [
            [dim for dim in d.dims if dim != d.attrs["transform"].x1_name]
            for d in data
        ]
        # Boolean mask per dataset: channels that are not NaN at index 0 of
        # all the other dimensions.
        good_channels: List[np.ndarray] = [
            np.ravel(
                cast(
                    DataArray,
                    np.logical_not(np.isnan(d.isel({dim: 0 for dim in droppable}))),
                ).drop_vars(droppable)
            )
            for d, droppable in zip(data, droppable_dims)
        ]
        # Number of residual entries each dataset contributes per time slice.
        for d, g in zip(binned_data, good_channels):
            d.attrs["nchannels"] = (
                d.size
                * int(np.sum(g))
                // (d.coords[d.attrs["transform"].x1_name].size * times.size)
            )
        nt = len(times)
        rows = sum(d.attrs["nchannels"] for d in binned_data) * nt
        # The final knot is pinned to zero, so only n_knots - 1 values per
        # time slice are free parameters.
        cols = (n_knots - 1) * nt
        # Jacobian sparsity: residuals at a given time only depend on the
        # knot values belonging to that same time.
        sparsity = lil_matrix((rows, cols), dtype=int)
        nc1, nc2 = tee(d.attrs["nchannels"] for d in binned_data)
        for nc, data_row_start in zip(
            nc1, accumulate(map(lambda x: x * nt, chain(repeat(0, 1), nc2)))
        ):
            for i in range(nt):
                rstart = data_row_start + i * nc
                rend = rstart + nc
                cstart = i * (n_knots - 1)
                cend = cstart + (n_knots - 1)
                sparsity[rstart:rend, cstart:cend] = 1

        def knotvals_to_xarray(knotvals):
            """Reshape the flat parameter vector into a (t, rho_poloidal)
            DataArray, appending the fixed zero value of the last knot."""
            all_knots = np.empty((nt, n_knots))
            all_knots[:, :-1] = knotvals.reshape(nt, n_knots - 1)
            all_knots[:, -1] = 0.0
            # NOTE(review): both coordinate values were lost from the
            # extracted source (``coords=[("t",, ("rho_poloidal",]``);
            # ``times.data`` and ``self.knots.data`` are reconstructions,
            # confirm against the upstream file.
            return DataArray(
                all_knots,
                coords=[("t", times.data), ("rho_poloidal", self.knots.data)],
            )

        # TODO: Consider how to handle locations outside of interpolation range.
        # For now just setting the interpolated values to 0.0
        def residuals(knotvals):
            """Return the flat vector of (spline - data) over all datasets,
            updating ``self.spline`` and ``self.spline_vals`` as it goes."""
            self.spline_vals = knotvals_to_xarray(knotvals)
            self.spline = Spline(self.spline_vals, "rho_poloidal", flux_surfaces)
            start = 0
            resid = np.empty(rows)
            for d, g in zip(binned_data, good_channels):
                end = start + d.attrs["nchannels"] * nt
                # Local names chosen so the ``rho`` parameter of ``__call__``
                # is not shadowed.
                d_rho, d_theta = d.indica.convert_coords(flux_surfaces)
                temp_resid = (
                    self.spline(flux_surfaces, d_rho, d_theta, times).fillna(0.0) - d
                ).isel({d.attrs["transform"].x1_name: g})
                if d.ndim == 2:
                    resid[start:end] = np.ravel(
                        temp_resid.transpose("t", d.attrs["transform"].x1_name)
                    )
                elif d.ndim == 3:
                    resid[start:end] = np.ravel(
                        temp_resid.transpose(
                            "t",
                            d.attrs["transform"].x1_name,
                            d.attrs["transform"].x2_name,
                        )
                    )
                start = end
            # assert np.all(np.isfinite(resid))
            return resid

        # Initial guess: for each time, the mean over datasets of the
        # dataset means, repeated for every free knot.
        guess = np.concatenate(
            tuple(
                np.mean([d.sel(t=t).mean() for d in binned_data])
                * np.ones(n_knots - 1)
                for t in times
            )
        )
        fit = least_squares(
            residuals,
            guess,
            bounds=(self.lower_bound, self.upper_bound),
            jac_sparsity=sparsity,
            verbose=2,
        )
        if fit.status == -1:
            raise RuntimeError(
                "Improper input to `least_squares` function when trying to "
                "fit emissivity to radiation data."
            )
        elif fit.status == 0:
            warnings.warn(
                "Attempt to fit splines reached maximum number of function "
                "evaluations.",
                RuntimeWarning,
            )
        # ``self.spline`` is whatever the last residual evaluation left behind.
        result = self.spline(flux_surfaces, rho, DataArray(0.0), times)
        result.attrs["splines"] = self.spline
        # NOTE(review): this chained assignment also overwrites the
        # "datatype" attribute of the caller's ``data[0]``; confirm that
        # mutating the input is intended.
        self.spline_vals.attrs["datatype"] = result.attrs["datatype"] = data[0].attrs[
            "datatype"
        ] = "knot_values"
        self.assign_provenance(result)
        self.assign_provenance(self.spline_vals)
        for d in binned_data:
            self.assign_provenance(d)
        return result, self.spline_vals, *binned_data