Source code for indica.operators.abstractoperator

"""Experimental design for performing mathematical operations on data.
"""

from abc import ABC
from abc import abstractmethod
import datetime
from itertools import zip_longest
from typing import Any
from typing import cast
from typing import List
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from warnings import warn

import prov.model as prov
from xarray import DataArray
from xarray import Dataset

from .. import session
from ..datatypes import ArrayType
from ..datatypes import DatasetType
from ..datatypes import DataType
from ..datatypes import DatatypeWarning
from ..datatypes import GENERAL_DATATYPES
from ..datatypes import SPECIFIC_DATATYPES

Data = Union[DataArray, Dataset]

if TYPE_CHECKING:
    from builtins import ellipsis as EllipsisType
else:
    EllipsisType = type(Ellipsis)


[docs]class OperatorError(Exception): """An Exception class raised by :py:class:`operator.Operator` when receiving erroneous arguments. """
[docs]class Operator(ABC): """Abstract base class for performing calculations with data. Note that the "Parameters" section below describes the paramters used when calling an object of this class and *not* when constructing a new object as would normally be the case. Parameters ---------- sess: session.Session An object representing the session being run. Contains information such as provenance data. kwargs: Any Any other arguments which should be recorded in the PROV entity for the reader. Attributes ---------- ARGUMENT_TYPES: ClassVar[List[DataType]] Ordered list of the types of data expected for each argument of the operator. If there are variadic positional arguments then their type is given by the final element of the list. RETURN_TYPES: ClassVar[List[DataType]] Ordered list of the types of data returned by the operator. prov_id: str The hash used to identify this object in provenance documents. agent: prov.model.ProvAgent An agent representing this object in provenance documents. DataArray objects can be attributed to it. entity: prov.model.ProvEntity An entity representing this object in provenance documents. It is used to provide information on the object's own provenance. """ ARGUMENT_TYPES: List[Union[DataType, EllipsisType]] = [] def __init__(self, sess: session.Session = session.global_session, **kwargs: Any): """Creates a provenance entity/agent for the operator object. Also checks arguments and results are of valid datatypes. Should be called by initialisers in subclasses. """ self._session = sess # TODO: also include library version and, ideally, version of # relevent dependency in the hash self.prov_id = session.hash_vals( operator_type=self.__class__.__name__, **kwargs ) self.agent = self._session.prov.agent(self.prov_id) self._session.prov.actedOnBehalfOf(self.agent, self._session.agent) self.entity = self._session.prov.entity(self.prov_id, kwargs) self._session.prov.generation( self.entity, self._session.session, time=datetime.datetime.now() ) self._session.prov.attribution(self.entity, self._session.agent) self._input_provenance: List[prov.ProvEntity] = [] self._prov_count = 0 self._end_time: datetime.datetime self.activity: prov.ProvActivity for i, datatype in enumerate(self.ARGUMENT_TYPES): if isinstance(datatype, EllipsisType): if i + 1 != len(self.ARGUMENT_TYPES): raise TypeError( ( "Operator class {} uses ellipsis dots as a type for" " argument {}. Only supported in final position." ).format(self.__class__.__name__, i + 1) ) else: continue if datatype[0] and datatype[0] not in GENERAL_DATATYPES: warn( "Operator class {} expects argument {} to have " "unrecognised general datatype '{}'".format( self.__class__.__name__, i + 1, datatype[0] ), DatatypeWarning, ) if datatype[1] and datatype[1] not in SPECIFIC_DATATYPES: warn( "Operator class {} expects argument {} to have " "unrecognised specific datatype '{}'".format( self.__class__.__name__, i + 1, datatype[1] ), DatatypeWarning, ) def _ellipsis_type(self, arg: Data) -> DataType: """Given the argument corresponding to the penultimate argument type, return the type required for all further variadic arguments. """ if isinstance(arg, DataArray): return arg.attrs["datatype"] else: dtype = arg.attrs["datatype"] return dtype[0], { k: dtype[1][k] for k in cast(DatasetType, self.ARGUMENT_TYPES[-2])[1] }
[docs] def validate_arguments(self, *args: Data): """Checks that arguments to the operator are of the expected types. Also gathers provenance information for use later. Parameters ---------- args All of the arguments to be used in the operation. """ self._start_time = datetime.datetime.now() self._input_provenance = [ arg.attrs["provenance"] for arg in args if "provenance" in arg.attrs ] arg_len = len(args) expected_len = len(self.ARGUMENT_TYPES) if expected_len > 0 and self.ARGUMENT_TYPES[-1] == Ellipsis: iterator = zip_longest( args, self.ARGUMENT_TYPES[:-1], fillvalue=self._ellipsis_type(args[expected_len - 2]), ) elif arg_len != expected_len: message = ( "Operator of class {} received {} arguments but " "expected {}".format(self.__class__.__name__, arg_len, expected_len) ) raise OperatorError(message) else: # MyPy complaining since iterator is set to type zip_longest earlier in the # code, and is set to type zip here even though the two assignments are in # two mutually exclusive branches!(if-else branches not git branches) # Ignoring for now. iterator = zip(args, self.ARGUMENT_TYPES) # type: ignore for i, (arg, expected) in enumerate(iterator): if isinstance(arg, DataArray): datatype = arg.attrs["datatype"] expected = cast(ArrayType, expected) if expected[0] and datatype[0] != expected[0]: message = ( "Argument {} of wrong general data type for operator {}: " "expected {}, received {}.".format( i + 1, self.__class__.__name__, expected[0], datatype[0], ) ) raise OperatorError(message) if expected[1] and datatype[1] != expected[1]: message = ( "Argument {} of wrong specific data type for operator {}: " "expected to be for {}, received {}.".format( i + 1, self.__class__.__name__, expected[1], datatype[1], ) ) raise OperatorError(message) elif isinstance(arg, Dataset): datatype = arg.attrs["datatype"] expected = cast(DatasetType, expected) if expected[0] and datatype[0] != expected[0]: message = ( "Argument {} of wrong specific data type for operator {}: " "expected {}, received {}.".format( i + 1, self.__class__.__name__, expected[0], datatype[0], ) ) raise OperatorError(message) for key, general_type in expected[1].items(): if key not in datatype[1]: message = ( "Variable {} required by operator {} is missing from " "argument {}.".format( key, self.__class__.__name__, i + 1, ) ) raise OperatorError(message) if datatype[1][key] != general_type: message = ( "Variable {} of argument {} of wrong general data type for " "operator {}: expected {}, received {}.".format( key, i + 1, self.__class__.__name__, general_type, datatype[1][key], ) ) raise OperatorError(message) else: raise OperatorError( "Argument {} is not a DataArray or Dataset".format(arg) )
[docs] @abstractmethod def return_types(self, *args: DataType) -> Tuple[DataType, ...]: """Indicates the datatypes of the results when calling the operator with arguments of the given types. It is assumed that the argument types are valid. Parameters ---------- args The datatypes of the parameters which the operator is to be called with. Returns ------- : The datatype of each result that will be returned if the operator is called with these arguments. """ raise NotImplementedError( "{} does not implement a " "'return_types' method.".format(self.__class__.__name__) )
[docs] def assign_provenance(self, data: Union[DataArray, Dataset]) -> prov.ProvEntity: """Create and assign a provenance entity to the argument. This argument should be one of the results of the operator. This should only be called after :py:meth:`validate_arguments`, as it relies on that routine to collect information about the inputs to the operator. It should not be called until after all calculations are finished, as the first call will be used to determine the end-time of the calculation. Returns ------- : A provenance entity for the newly calculated data. """ # TODO: Generate multiple pieces of PROV data for multiple return values if self._prov_count == 0: self.end_time = datetime.datetime.now() activity_id = session.hash_vals(agent=self.prov_id, date=self.end_time) self.activity = self._session.prov.activity( activity_id, self._start_time, self.end_time, {prov.PROV_TYPE: "Calculation"}, ) self.activity.wasAssociatedWith(self._session.agent) self.activity.wasAssociatedWith(self.agent) self.activity.wasInformedBy(self._session.session) for arg in self._input_provenance: self.activity.used(arg) entity_id = session.hash_vals( creator=self.prov_id, date=self.end_time, result_number=self._prov_count, **{ "arg" + str(i): p.identifier for i, p in enumerate(self._input_provenance) } ) self._prov_count += 1 if isinstance(data, Dataset): entity = self._session.prov.collection( entity_id, {prov.PROV_TYPE: "Dataset"} ) for array in data.data_vars.values(): if "provenance" not in array.attrs: print("Creating provenenace") self.assign_provenance(array) entity.hadMember(array.attrs["provenance"]) else: entity = self._session.prov.entity( entity_id, { prov.PROV_TYPE: "DataArray", prov.PROV_VALUE: ",".join(data.attrs["datatype"]), }, ) entity.wasGeneratedBy(self.activity, self.end_time) entity.wasAttributedTo(self._session.agent) entity.wasAttributedTo(self.agent) for arg in self._input_provenance: entity.wasDerivedFrom(arg) if isinstance(data, Dataset): data.attrs["provenance"] = entity else: data.attrs["partial_provenance"] = entity if data.indica.equilibrium: data.indica._update_prov_for_equilibrium( data.attrs["transform"].equilibrium ) else: data.attrs["provenance"] = entity
[docs] @abstractmethod def __call__(self, *args: DataArray) -> Union[DataArray, Dataset]: """The invocation of the operator. The exact number of arguments should be determined by the subclass. However, it is anticipated that these would all be :py:class:`xarray.DataArray` objects. Unfortunately, we can not use Mypy static type-checking for this routine or its overriding implementations, as the number of arguments will vary. """ raise NotImplementedError( "{} does not implement a " "'__call__' method.".format(self.__class__.__name__) )