"""Run Control object for regolith
"""
from __future__ import print_function
import json
import os
import io
from pprint import pformat
from collections.abc import (
Mapping,
Iterable,
Hashable,
)
from warnings import warn
from regolith.validators import always_true, noop, DEFAULT_VALIDATORS
from regolith.database import connect
# Names considered forbidden; ``warn_forbidden_name`` below builds the warning
# issued when such a name is found.
# NOTE(review): both entries are Python keywords, which is presumably why they
# cannot be used as-is -- confirm against the callers that consult this set.
FORBIDDEN_NAMES = frozenset(["del", "global"])
def warn_forbidden_name(forname, inname=None, rename=None):
    """Emit a ``RuntimeWarning`` reporting that a forbidden name was found.

    Parameters
    ----------
    forname :
        The forbidden name that was encountered.
    inname : optional
        Where the name was found; appended to the message when given.
    rename : optional
        The replacement name; appended to the message when given.
    """
    # Assemble the message from its optional fragments, then join once.
    fragments = ["found forbidden name {0!r}".format(forname)]
    if inname is not None:
        fragments.append(" in {0!r}".format(inname))
    if rename is not None:
        fragments.append(", renaming to {0!r}".format(rename))
    warn("".join(fragments), RuntimeWarning)
def ensuredirs(f):
    """For a file path, ensure that its directory path exists.

    Parameters
    ----------
    f : str
        Path to a file.  Only the directory part is created; the file itself
        is not touched.

    Notes
    -----
    A path without a directory component (e.g. ``"file.txt"``) refers to the
    current directory, so there is nothing to create and the call is a no-op.
    """
    d = os.path.split(f)[0]
    if not d:
        # os.makedirs("") raises FileNotFoundError; a bare file name needs
        # no directory.
        return
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(d, exist_ok=True)
def touch(filename):
    """Create ``filename`` if needed and refresh its timestamps.

    Behaves like the posix command of the same name: the file is opened in
    append mode (so existing content is left alone) and its access and
    modification times are set to the current time.
    """
    # The handle itself is not needed; opening in append mode is what creates
    # a missing file.
    with io.open(filename, mode="a"):
        os.utime(filename, times=None)
def exec_file(filename, glb=None, loc=None):
    """Read a Python source file and execute it, like Python 2.x ``execfile``.

    Parameters
    ----------
    filename : str
        Path of the source file; also used as the file name recorded in the
        compiled code.
    glb : dict, optional
        Globals namespace handed to ``exec``.
    loc : mapping, optional
        Locals namespace handed to ``exec``.
    """
    # NOTE: this runs arbitrary code from the file; only use on trusted input.
    with io.open(filename, "r") as stream:
        code = compile(stream.read(), filename, "exec")
    exec(code, glb, loc)
#
# Run Control
#
class NotSpecifiedType:
    """Type of the ``NotSpecified`` sentinel used by run control to mark that
    no 'real' value has been given."""

    def __repr__(self):
        return "NotSpecified"


NotSpecified = NotSpecifiedType()
"""A helper class singleton for run control meaning that a 'real' value
has not been given.
"""
class RunControl(object):
    """A composable configuration class. Unlike argparse.Namespace,
    this keeps the object dictionary (__dict__) separate from the run
    control attributes dictionary (_dict).

    Attribute names starting with an underscore are stored directly in
    ``__dict__``; every other name is validated and stored in ``_dict``.
    Values in ``_dict`` that are ``property`` objects are evaluated against
    the instance each time they are read.
    """

    def __init__(self, _updaters=None, _validators=None, **kwargs):
        """Parameters
        -------------
        _updaters : dict, optional
            Maps a key to a callable ``f(current, new)`` that ``_update``
            uses to combine the current and incoming values for that key.
        _validators : dict, optional
            Maps a key -- either an exact name or an object with a ``match``
            method such as a compiled regular expression -- to a
            ``(validator, convertor)`` pair used by ``_validate``.
        kwargs : optional
            Items to place into run control.
        """
        self._dict = {}
        self._updaters = _updaters or {}
        self._validators = _validators or {}
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup has failed.  The storage
        # is read through __dict__ rather than ``self._dict`` so that an
        # instance without ``_dict`` (e.g. one created by copy.deepcopy or
        # pickle before its state is restored) raises AttributeError instead
        # of recursing back into __getattr__ until RecursionError.
        own = self.__dict__
        store = own.get("_dict", {})
        if key in store:
            value = store[key]
        elif key in own:
            value = own[key]
        elif key in self.__class__.__dict__:
            value = self.__class__.__dict__[key]
        else:
            msg = "RunControl object has no attribute {0!r}.".format(key)
            raise AttributeError(msg)
        if isinstance(value, property):
            # Stored properties are computed lazily against this instance.
            value = value.fget(self)
        return value

    def __setattr__(self, key, value):
        if key.startswith("_"):
            # Private names bypass validation and live on the object itself.
            self.__dict__[key] = value
        else:
            # NotSpecified never overwrites a value that is already present.
            if value is NotSpecified and key in self:
                return
            value = self._validate(key, value)
            self._dict[key] = value

    def __delattr__(self, key):
        if key in self._dict:
            del self._dict[key]
        elif key in self.__dict__:
            del self.__dict__[key]
        elif key in self.__class__.__dict__:
            # A class __dict__ is a read-only mappingproxy, so item deletion
            # on it fails with TypeError.  Report the failure the way Python
            # does for class-level attributes deleted through an instance.
            msg = "RunControl object cannot delete class attribute {0!r}.".format(key)
            raise AttributeError(msg)
        else:
            msg = "RunControl object has no attribute {0!r}.".format(key)
            raise AttributeError(msg)

    def __iter__(self):
        return iter(self._dict)

    def __repr__(self):
        keys = sorted(self._dict.keys())
        s = ", ".join(["{0!s}={1!r}".format(k, self._dict[k]) for k in keys])
        return "{0}({1})".format(self.__class__.__name__, s)

    def _get(self, key, default=None):
        """Return the attribute ``key``, or ``default`` when it is missing."""
        try:
            val = getattr(self, key)
        except (KeyError, AttributeError):
            val = default
        return val

    def _pformat(self):
        """Return a multi-line, pretty-printed variant of ``repr(self)``."""
        keys = sorted(self._dict.keys())
        s = ",\n ".join(
            "{0!s}={1}".format(k, pformat(self._dict[k], indent=2)) for k in keys
        )
        return "{0}({1})".format(self.__class__.__name__, s)

    def __contains__(self, key):
        return (
            key in self._dict
            or key in self.__dict__
            or key in self.__class__.__dict__
        )

    def __eq__(self, other):
        if hasattr(other, "_dict"):
            return self._dict == other._dict
        elif isinstance(other, Mapping):
            return self._dict == other
        else:
            return NotImplemented

    def __ne__(self, other):
        if hasattr(other, "_dict"):
            return self._dict != other._dict
        elif isinstance(other, Mapping):
            return self._dict != other
        else:
            return NotImplemented

    def __copy__(self):
        return type(self)(_updaters=self._updaters, _validators=self._validators, **self._dict)

    def _update(self, other):
        """Updates the rc with values from another mapping.

        ``other`` may be another run control object, a mapping, or anything
        ``dict()`` accepts.  Values that are ``NotSpecified`` are passed
        through unchanged (and therefore never replace an existing value).
        If a key is in self, other, and self._updaters, then the updater is
        called as ``updater(current, new)`` and its result is stored.  The
        updater should return a copy to be safe and not update in-place.
        """
        if hasattr(other, "_dict"):
            other = other._dict
        elif not hasattr(other, "items"):
            other = dict(other)
        for k, v in other.items():
            if v is NotSpecified:
                pass
            elif k in self._updaters and k in self:
                v = self._updaters[k](getattr(self, k), v)
            setattr(self, k, v)

    def _validate(self, key, value):
        """Validates - and possibly converts - a value based on its key and the current
        validators.

        An exact key match in the validators takes precedence.  Otherwise the
        first non-string validator key whose ``match(key)`` succeeds is used.
        When nothing matches, the value is accepted unchanged.
        """
        validators = self._validators
        if key in validators:
            validator, convertor = validators[key]
        else:
            for vld in validators:
                if isinstance(vld, str):
                    # Plain names only ever match exactly, handled above.
                    continue
                if vld.match(key) is not None:
                    validator, convertor = validators[vld]
                    # Stop at the first match so the fallback below cannot
                    # override it.
                    break
            else:
                validator, convertor = always_true, noop
        return value if validator(value) else convertor(value)
def flatten(iterable):
    """Generator which returns flattened version of nested sequences.

    Nested iterables are expanded recursively, depth first.  Strings and
    bytes are yielded whole: they are iterable, but expanding them would
    split them up (and a one-character string would recurse forever).

    Parameters
    ----------
    iterable : iterable
        Arbitrarily nested iterable of elements.

    Yields
    ------
    Each non-iterable element, string, or bytes object, in order.
    """
    for el in iterable:
        # ``basestring`` only exists on Python 2; (str, bytes) is the
        # Python 3 equivalent of the types it covered.
        if isinstance(el, Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el
#
# Memoization
#
def ishashable(x):
    """Tests if a value is hashable.

    Containers that are themselves hashable (e.g. tuples) are only considered
    hashable when all of their elements are, recursively.

    Parameters
    ----------
    x :
        The value to test.

    Returns
    -------
    bool
        True when ``x`` and, for iterables, everything it contains is
        hashable.
    """
    if not isinstance(x, Hashable):
        return False
    # ``basestring`` only exists on Python 2; strings and bytes are hashable
    # as a whole, so there is no need to look at their elements.
    if isinstance(x, (str, bytes)):
        return True
    if isinstance(x, Iterable):
        return all(map(ishashable, x))
    return True
# Module-level default run control, built with the project's default
# validators.
DEFAULT_RC = RunControl(
    _validators=DEFAULT_VALIDATORS,
    builddir="_build",
    # Passed as a property object: RunControl.__getattr__ calls the getter of
    # property values on access, so the path follows the ``builddir`` in
    # effect at lookup time.
    mongodbpath=property(lambda self: os.path.join(self.builddir, "_dbpath")),
    # "~" is expanded once, when this module is imported.
    user_config=os.path.expanduser("~/.config/regolith/user.json"),
    force=False,
)
def load_json_rcfile(fname):
    """Read the UTF-8 encoded JSON run control file ``fname`` and return the
    decoded content."""
    with open(fname, mode="r", encoding="utf-8") as handle:
        return json.load(handle)
def load_rcfile(fname):
    """Load a run control file, picking the loader from the file extension.

    Only ``.json`` files are supported.

    Raises
    ------
    RuntimeError
        If the extension of ``fname`` is not ``.json``.
    """
    _, extension = os.path.splitext(fname)
    if extension != ".json":
        raise RuntimeError(
            "could not detemine run control file type from extension."
        )
    return load_json_rcfile(fname)
def filter_databases(rc):
    """Narrow ``rc.databases`` to the entries that are needed, in place.

    When ``rc.public_only`` is truthy, only entries whose ``"public"`` value
    is truthy are kept.  When ``rc.db`` is set, only entries with that
    ``"name"`` are kept; otherwise, if exactly one entry remains, ``rc.db``
    is set to its name.  The result is stored back on ``rc.databases``.
    """
    selected = rc.databases
    if rc._get("public_only", False):
        selected = [entry for entry in selected if entry["public"]]
    wanted = rc._get("db")
    if wanted is None:
        # No explicit choice: a lone remaining database becomes the default.
        if len(selected) == 1:
            rc.db = selected[0]["name"]
    else:
        selected = [entry for entry in selected if entry["name"] == wanted]
    rc.databases = selected
def connect_db(rc, colls=None):
    """Load up the databases.

    Enters ``connect(rc, dbs=colls)`` as a context manager, binds the object
    it yields to ``rc.client``, and reads the client's ``dbs`` and
    ``chained_db`` attributes.

    Parameters
    ----------
    rc:
        The runcontrol instance
    colls
        The list of collections that should be loaded

    Returns
    -------
    chained_db:
        The chained databases in the form of a document
    dbs:
        The databases in the form of a runcontrol client
    """
    # NOTE(review): the values are returned after the ``with`` block has
    # exited; confirm they remain usable once the client context is closed.
    with connect(rc, dbs=colls) as rc.client:
        dbs, chained_db = rc.client.dbs, rc.client.chained_db
    return chained_db, dbs