"""Experimental design for handling provenance using W3C PROV.
"""
from contextlib import contextmanager
from contextlib import redirect_stderr
from contextlib import redirect_stdout
import datetime
from functools import wraps
import hashlib
import importlib
import io
import os
from pathlib import Path
import platform
import re
import subprocess
import typing

import pkg_resources
import prov.model as prov
from xarray import DataArray
from xarray import Dataset

from .utilities import positional_parameters

if typing.TYPE_CHECKING:
    from .equilibrium import Equilibrium
    from .operators import Operator
    from .readers import DataReader
__author__ = "Marco Sertoli"
__credits__ = ["Chris MacMackin", "Marco Sertoli"]

# An ORCiD iD is four hyphen-separated groups of four characters. The last
# character is an ISO 7064 MOD 11-2 check digit which may be "X" (standing
# for the value 10), so it must be accepted alongside the ASCII digits.
ORCID_RE = re.compile(r"^[0-9]{4}-[0-9]{4}-[0-9]{4}-[0-9]{3}[0-9X]$")

# The session used when no explicit ``sess`` is supplied. It is rebound by
# ``Session.begin`` and while a ``Session`` is used as a context manager,
# and is first assigned once the ``Session`` class has been defined.
global_session: "Session"
def get_dependency_data():
    """Intended to provide provenance data on dependencies, as a generator.

    This is a placeholder: every call raises ``NotImplementedError``.
    """
    message = "TODO: write this function"
    raise NotImplementedError(message)
def hash_vals(**kwargs: typing.Any) -> str:
    """Compute a SHA-256 digest of the supplied keyword arguments.

    The arguments are serialised, in the order they were passed, as
    ``"<key>:<str(value)>,"`` for each pair. The concatenation is
    UTF-8 encoded and hashed.

    Parameters
    ----------
    kwargs
        The key-value pairs to hash. Values are converted with ``str``.

    Returns
    -------
    str
        The digest as a hexadecimal string.

    Notes
    -----
    ``":"`` and ``","`` are not escaped. Distinct arguments whose string
    forms contain those characters can therefore produce the same digest.
    """
    # TODO: include date/time in hash
    serialised = "".join(f"{key}:{value!s}," for key, value in kwargs.items())
    return hashlib.sha256(serialised.encode("utf-8")).hexdigest()
def _git_version(path: Path) -> str:
    """Describe the git revision of the source tree containing ``path``.

    Tries three sources in order:

    1. Ask git, if ``path`` or its parent contains ``.git``. A ``-dirty``
       suffix is appended when files under ``path`` differ from ``HEAD``.
    2. Read a ``git_version`` file inside ``path``.
    3. Fall back to ``"UNKNOWN"``.
    """
    # TODO: Check all parent directories, but only if the child directory
    # is not ignored.
    if (path / ".git").exists() or (path.parent / ".git").exists():
        try:
            git_hash = subprocess.check_output(
                ["git", "describe", "--always"], cwd=path, text=True
            ).strip()
            # Run relative to the package directory, not the process's
            # current directory, so only this package's changes count.
            git_diff = subprocess.check_output(
                ["git", "diff", "HEAD", "--", "."], cwd=path, text=True
            ).strip()
        except (OSError, subprocess.CalledProcessError):
            # git missing or repository unusable. Provenance is best-effort,
            # so record that the revision is unknown rather than failing.
            return "UNKNOWN"
        return git_hash + "-dirty" if git_diff else git_hash
    version_file = path / "git_version"
    if version_file.exists():
        # Strip so a trailing newline doesn't end up in the PROV attribute.
        return version_file.read_text(encoding="utf-8").strip()
    return "UNKNOWN"


def package_provenance(
    doc: prov.ProvDocument, package_name: str
) -> typing.Tuple[prov.ProvEntity, prov.ProvEntity]:
    """Record provenance for an installed package and, recursively, all of
    its dependencies in ``doc``.

    Three entities are created for the package, each a specialisation of
    the one before:

    - a general entity for the project (``pypi`` namespace);
    - an entity for the installed version;
    - an entity for the local installation, annotated with the host name
      and git revision.

    The version entity is derived from each dependency's general entity.
    The installation entity is derived from each dependency's installation
    entity.

    Parameters
    ----------
    doc
        The provenance document to add the records (and the ``pypi`` and
        ``local`` namespaces) to.
    package_name
        A requirement string identifying an installed distribution.

    Returns
    -------
    typing.Tuple[prov.ProvEntity, prov.ProvEntity]
        The general entity for the package and the entity for this
        particular installation.

    Raises
    ------
    AssertionError
        If ``pkg_resources`` does not return an installed distribution for
        ``package_name``.

    Notes
    -----
    A dependency shared by several packages is processed once per
    dependent. The recursion has no guard against dependency cycles.
    """
    doc.add_namespace("pypi", "https://pypi.org/project/")
    doc.add_namespace("local", "file://")
    package = pkg_resources.working_set.find(
        pkg_resources.Requirement.parse(package_name)
    )
    assert isinstance(package, pkg_resources.Distribution)
    general_entity = doc.entity(
        f"pypi:{package.project_name}",
        {"pypi:package": package.project_name},
    )
    version_entity = doc.entity(
        f"pypi:{package.project_name}/{package.version}",
        {"pypi:version": package.version},
    )
    version_entity.specializationOf(general_entity)
    # Some modules print things when imported, so capture both stdout and
    # stderr. The captured text is only shown if something goes wrong.
    tmp_output = io.StringIO()
    try:
        with redirect_stdout(tmp_output), redirect_stderr(tmp_output):
            path = Path(
                importlib.import_module(package.project_name).__file__  # type: ignore
            ).parent
        git_hash = _git_version(path)
    except ModuleNotFoundError:
        # The importable module name differs from the project name, so fall
        # back to a guess at the installation directory.
        path = Path(package.location) / package.project_name
        git_hash = "UNKNOWN"
    except Exception:
        tmp_output.seek(0)
        print(tmp_output.read())
        raise
    installed_entity = doc.entity(
        f"local:{path}", {"host": platform.node(), "git_commit": git_hash}
    )
    installed_entity.specializationOf(version_entity)
    for dep in package.requires():
        dep_general_entity, dep_installed_entity = package_provenance(
            doc, dep.project_name
        )
        version_entity.wasDerivedFrom(dep_general_entity)
        installed_entity.wasDerivedFrom(dep_installed_entity)
    return general_entity, installed_entity
class Session:
    """Manages a particular run of the software.

    Has the following uses:

    - keep information about version of package and dependencies
    - hold provenance information
    - track the data read/calculated and operators instantiated
    - allow that data to be exported and reloaded

    TODO: Consider whether some of these behaviours should be spun off
    into separate classes which are then aggregated into this one.

    Using a session as a context manager makes it the module-level
    ``global_session`` until the ``with`` block exits.

    Parameters
    ----------
    user_id: str
        Something with which to identify the user. Recommend either an email
        address or an ORCiD ID. Passing an empty string skips gathering
        package provenance and records the user as ``example@example.com``.

    Attributes
    ----------
    data: typing.Dict[str, typing.Union[DataArray, Dataset]]
        All of the data which has been read in or calculated during this
        session.
    equilibria: typing.Dict[str, Equilibrium]
        All of the equilibrium objects which have been created during this
        session.
    operators: typing.Dict[str, Operator]
        All of the operators which have been instantiated during this session.
    prov: prov.model.ProvDocument
        The document containing all of the provenance information for this
        session.
    readers: typing.Dict[str, DataReader]
        The data readers used during this session.
    session: prov.model.ProvActivity
        The provenance Activity object representing this session. It records
        the platform, working directory, host and Python version. Unless
        ``user_id`` was empty, it is also marked as using the installed
        ``indica`` package.
    """

    def __init__(self, user_id: str):
        self.prov = prov.ProvDocument()
        self.prov.set_default_namespace("https://ccfe.ukaea.uk/")
        # Stack of agents: the first entry is the user who started the
        # session, the last is the agent currently in control.
        if ORCID_RE.match(user_id):
            self.prov.add_namespace("orcid", "https://orcid.org/")
            self._user = [self.prov.agent("orcid:" + user_id)]
        else:
            self._user = [
                self.prov.agent(user_id if user_id else "example@example.com")
            ]
        date = datetime.datetime.now()
        session_properties = {
            "os": platform.platform(),
            "directory": os.getcwd(),
            "host": platform.node(),
            "python": platform.python_version(),
        }
        session_id = hash_vals(startTime=date, **session_properties)
        self.session = self.prov.activity(session_id, date, None, session_properties)
        # Use an empty ID to short-circuit all of the provenance
        # calculation. This is useful to prevent provenance being
        # built whenever this module is imported.
        if user_id != "":
            self.indica_prov = package_provenance(self.prov, "indica")[1]
            self.session.used(self.indica_prov)
        self.prov.association(self.session, self._user[0])
        self.data: typing.Dict[str, typing.Union[DataArray, Dataset]] = {}
        self.equilibria: typing.Dict[str, Equilibrium] = {}
        self.operators: typing.Dict[str, Operator] = {}
        self.readers: typing.Dict[str, DataReader] = {}

    def __enter__(self):
        """Make this the global session, remembering the one it replaces."""
        global global_session
        self.old_global_session = global_session
        global_session = self
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Restore the global session that was active before ``__enter__``.

        Returns ``False`` so exceptions raised in the ``with`` block
        propagate.
        """
        global global_session
        global_session = self.old_global_session
        return False

    @property
    def agent(self) -> prov.ProvAgent:
        """The agent (person or piece of software) currently in immediate
        control of execution.

        :returntype: prov.model.ProvAgent
        """
        return self._user[-1]

    def push_agent(self, agent: prov.ProvAgent):
        """Delegate responsibility to another agent.

        They will appear to be in control of execution now and will be
        returned by the :py:meth:`agent` property.

        Parameters
        ----------
        agent
            The new agent to delegate responsibility to. It is recorded as
            acting on behalf of the agent that was previously in control.
        """
        agent.actedOnBehalfOf(self._user[-1])
        self._user.append(agent)

    def pop_agent(self) -> prov.ProvAgent:
        """Take responsibility back from the Agent that it was most recently
        delegated to.

        The Agent which the responsibility was delegated by will now
        appear to be in control of execution and will be the one
        returned by the :py:meth:`agent` property.

        Returns
        -------
        prov.ProvAgent
            The agent that responsibility was taken away from.

        Raises
        ------
        IndexError
            If no agent has been delegated to. The user who started the
            session can not be removed; otherwise :py:attr:`agent` would
            have nothing to return.
        """
        if len(self._user) <= 1:
            raise IndexError("No delegated agent to take responsibility back from.")
        return self._user.pop()

    @contextmanager
    def new_agent(self, agent: prov.ProvAgent) -> typing.Iterator[prov.ProvAgent]:
        """A context manager for temporarily adding an agent to the
        session. This ensures the agent will be removed even if an
        exception is raised.

        Parameters
        ----------
        agent
            The agent to delegate responsibility to for the ``with`` block.
            It is also the value bound by ``as``.
        """
        self.push_agent(agent)
        try:
            yield agent
        finally:
            self.pop_agent()

    def export(self, filename: str):
        """Write all of the data and operators from this session into a file,
        for reuse later.

        Not yet implemented: always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @classmethod
    def begin(cls, user_id: str):
        """Sets up a global session, without bothering with a context
        manager.

        Parameters
        ----------
        user_id
            An identifier, such as an email address or ORCiD ID, for the person
            using the software.
        """
        global global_session
        global_session = cls(user_id)

    @classmethod
    def reload(cls, filename: str) -> "Session":
        """Create a session from a file which was written to ``filename``.

        The intent is that any local variables in ``__main__`` will be
        recreated. Not yet implemented: always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
# Placeholder default session. The empty user ID makes ``Session`` skip
# package provenance collection, which keeps importing this module cheap.
global_session = Session(user_id="")
def generate_prov(pass_sess: bool = False):
    """Decorator to be applied to functions generating
    :py:class:`xarray.DataArray` output. It will produce PROV data and
    attach it as the ``"provenance"`` attribute of each output array.

    This should only be applied to stateless functions, as the PROV
    data it generates will not accurately describe anything else.

    The decorated function accepts an optional ``sess`` keyword argument
    naming the :py:class:`Session` to record provenance in. The module's
    ``global_session`` (as bound at call time) is used by default.

    DataArray inputs must carry a ``"provenance"`` attribute. Their entities
    are recorded as used by the new activity. Other arguments are stored as
    string attributes of the activity.

    Parameters
    ----------
    pass_sess
        Indicates whether, if a keyword argument called ``sess`` is present,
        it should be passed to ``func``.

    Raises
    ------
    ValueError
        Raised by the decorated function if the wrapped function returns
        neither a DataArray nor a tuple containing at least one DataArray.
    """

    def outer_wrapper(func):
        param_names, var_positional = positional_parameters(func)
        num_positional = len(param_names)

        @wraps(func)
        def prov_generator(*args, **kwargs):
            session = kwargs.get("sess", global_session)
            if "sess" in kwargs and not pass_sess:
                kwargs = dict(kwargs)
                del kwargs["sess"]
            start_time = datetime.datetime.now()
            result = func(*args, **kwargs)
            end_time = datetime.datetime.now()

            # Provenance entities of the DataArrays the function consumed.
            input_entities = []
            activity_attrs = {prov.PROV_TYPE: func.__name__}
            # Identifiers of input entities keyed by argument name; these
            # feed into the activity and output-entity hashes.
            id_attrs = {}

            def record_argument(argname, value):
                if isinstance(value, DataArray):
                    entity = value.attrs["provenance"]
                    input_entities.append(entity)
                    id_attrs[argname] = entity.identifier
                else:
                    activity_attrs[argname] = str(value)

            for i, arg in enumerate(args):
                if i < num_positional:
                    argname = param_names[i]
                else:
                    # Surplus positional arguments are named after the
                    # ``*args`` parameter plus their offset within it.
                    argname = var_positional + str(i - num_positional)
                record_argument(argname, arg)
            for key, val in kwargs.items():
                if key == "sess":
                    # The session is bookkeeping, not an input to the
                    # calculation (only present here when pass_sess is set).
                    continue
                record_argument(key, val)

            activity_id = hash_vals(agent=session.agent, date=end_time, **id_attrs)
            activity = session.prov.activity(
                activity_id, start_time, end_time, activity_attrs
            )
            for entity in input_entities:
                activity.used(entity)

            def attach_provenance(array, **position):
                entity_id = hash_vals(
                    activity=activity_id, **position, name=array.name, **id_attrs
                )
                entity = session.prov.entity(entity_id)
                entity.wasGeneratedBy(activity, end_time)
                entity.wasAttributedTo(session.agent)
                array.attrs["provenance"] = entity

            generated_array = False
            if isinstance(result, DataArray):
                attach_provenance(result)
                generated_array = True
            elif isinstance(result, tuple):
                for i, item in enumerate(result):
                    if isinstance(item, DataArray):
                        attach_provenance(item, position=str(i))
                        generated_array = True
            if not generated_array:
                raise ValueError(
                    "No DataArray object was produced by the "
                    "function. Can not assign PROV data."
                )
            return result

        return prov_generator

    return outer_wrapper