diff --git a/confuse.py b/confuse.py index 18b18f0..7c4633d 100644 --- a/confuse.py +++ b/confuse.py @@ -32,6 +32,7 @@ import yaml import re from collections import OrderedDict +from functools import wraps if sys.version_info >= (3, 3): from collections import abc else: @@ -134,36 +135,193 @@ def __init__(self, filename, reason=None): # Views and sources. + +UNSET = object() # sentinel + + +def _load_first(func): + '''Call self.load() before the function is called - used for lazy source + loading''' + def inner(self, *a, **kw): + self.load() + return func(self, *a, **kw) + + try: + return wraps(func)(inner) + except AttributeError: + # in v2 they don't ignore missing attributes + # v3: https://github.com/python/cpython/blob/3.8/Lib/functools.py + # v2: https://github.com/python/cpython/blob/2.7/Lib/functools.py + inner.__name__ = func.__name__ + return inner + + +def all_subclasses(cls): + return set(cls.__subclasses__()).union( + s for c in cls.__subclasses__() for s in all_subclasses(c)) + + class ConfigSource(dict): - """A dictionary augmented with metadata about the source of the + '''A dictionary augmented with metadata about the source of the configuration. - """ - def __init__(self, value, filename=None, default=False): - super(ConfigSource, self).__init__(value) + ''' + def __init__(self, value=UNSET, filename=None, default=False): + # track whether a config source has been set yet + self.loaded = value is not UNSET + super(ConfigSource, self).__init__(value if self.loaded else {}) if filename is not None and not isinstance(filename, BASESTRING): raise TypeError(u'filename must be a string or None') self.filename = filename self.default = default def __repr__(self): - return 'ConfigSource({0!r}, {1!r}, {2!r})'.format( - super(ConfigSource, self), - self.filename, - self.default, - ) + return '{}({}, filename={}, default={})'.format( + self.__class__.__name__, + dict.__repr__(self) + if self.loaded else '[Unloaded]' + if self.exists else "[Source doesn't exist]", + self.filename, self.default) + + @property + def exists(self): + '''Does this config have access to usable configuration values?''' + return self.loaded or self.filename and os.path.isfile(self.filename) + + def load(self): + '''Ensure that the source is loaded.''' + if not self.loaded: + self.config_dir() + self.loaded = self._load() is not False + return self + + def _load(self): + '''Load config from source and update self. + If it doesn't load, return False to keep it marked as unloaded. + Otherwise it will be assumed to be loaded. + ''' + + def config_dir(self, create=True): + '''Create the config dir, if there's a filename associated with the + source.''' + if self.filename: + dirname = os.path.dirname(self.filename) + if create and dirname and not os.path.isdir(dirname): + os.makedirs(dirname) + return dirname + return None @classmethod - def of(cls, value): - """Given either a dictionary or a `ConfigSource` object, return - a `ConfigSource` object. This lets a function accept either type - of object as an argument. - """ + def is_of_type(cls, value): + '''Determine if value is a valid parameter for this Source.''' + return + + # overriding dict methods so that the configuration is loaded before any + # of them are run + __getitem__ = _load_first(dict.__getitem__) + __iter__ = _load_first(dict.__iter__) + # __len__ = _load_first(dict.__len__) + keys = _load_first(dict.keys) + values = _load_first(dict.values) + + @classmethod + def of(cls, value, **kw): + '''Try to convert value to a `ConfigSource` object. This lets a + function accept values that are convertable to a source. + ''' + # ignore if already a source if isinstance(value, ConfigSource): return value - elif isinstance(value, dict): - return ConfigSource(value) - else: - raise TypeError(u'source value must be a dict') + + # convert using one of the configsource subtypes + for subcls in all_subclasses(ConfigSource): + if subcls.is_of_type(value): + return subcls(value, **kw) + + # if none matched, use convert dict to ConfigSource + if isinstance(value, dict): + return ConfigSource(value, **kw) + raise TypeError( + u'ConfigSource.of value unable to cast to ConfigSource.') + + +class YamlSource(ConfigSource): + '''A config source pulled from yaml files.''' + EXTENSIONS = '.yaml', '.yml' + + def __init__(self, filename=None, value=UNSET, default=False, + ignore_missing=False): + self.ignore_missing = ignore_missing + super(YamlSource, self).__init__(value, filename, default) + + @classmethod + def is_of_type(cls, value): + '''Does value look like a .yaml file?''' + return ( + isinstance(value, BASESTRING) + and os.path.splitext(value)[1] in cls.EXTENSIONS) + + _loader = lambda self, f: load_yaml(f) + + def _load(self): + '''Load the file if it exists.''' + if self.ignore_missing and not os.path.isfile(self.filename): + return False + self.update(self._loader(self.filename)) + + +class EnvSource(ConfigSource): + '''A config source pulled from environment variables.''' + DEFAULT_PREFIX = 'CONFUSE' + + def __init__(self, prefix=None, sep='__', value=UNSET): + self._prefix = prefix + self._sep = sep + super(EnvSource, self).__init__( + value=value, filename=None, default=False) + + def __repr__(self): + return '{}({}, prefix={}, separator={})'.format( + self.__class__.__name__, + dict.__repr__(self) + if self.loaded else '[Unloaded]', + self._prefix, self._sep) + + def _load(self): + self.update(nest_keys( + os.environ, sep=self._sep, + prefix=self._prefix or self.DEFAULT_PREFIX)) + + +def nest_keys(value, sep='.', prefix='', overwrite_higher=False, + overwrite_lower=False): + '''Convert a flat dict with splittable keys to a nested dict split by + that separator.''' + out = {} + # filter by prefix and split by separator + items = ((k[len(prefix):].split(sep), v) + for k, v in value.items() + if k.startswith(prefix)) + + # for each environment variable + for ks, val in sorted(items, key=lambda x: -len(x[0])): + # recursively set keys + dct = out + for k in ks[:-1]: + if (not overwrite_higher and k in dct + and not isinstance(dct[k], dict)): + raise ValueError( + 'Trying to overwrite another value with a nested dict.') + + if k not in dct: + dct[k] = {} + dct = dct[k] + + k = ks[-1] + if not overwrite_lower and k in dct and isinstance(dct[k], dict): + raise ValueError( + 'Trying to overwrite a nested dict with another value.') + dct[k] = val + return out class ConfigView(object): @@ -566,13 +724,17 @@ def __init__(self, sources): self.redactions = set() def add(self, obj): - self.sources.append(ConfigSource.of(obj)) + src = ConfigSource.of(obj) + self.sources.append(src) + return src def set(self, value): - self.sources.insert(0, ConfigSource.of(value)) + src = ConfigSource.of(value) + self.sources.insert(0, src) + return src def resolve(self): - return ((dict(s), s) for s in self.sources) + return ((dict(s.load()), s) for s in self.sources) def clear(self): """Remove all sources (and redactions) from this @@ -638,10 +800,10 @@ def resolve(self): yield value, source def set(self, value): - self.parent.set({self.key: value}) + return self.parent.set({self.key: value}) def add(self, value): - self.parent.add({self.key: value}) + return self.parent.add({self.key: value}) def root(self): return self.parent.root() @@ -726,6 +888,66 @@ def config_dirs(): return out +def find_user_config_files(appname, env_var=None, config_fname=CONFIG_FILENAME, + first=True): + """Get the path to the user configuration directory. The + directory is guaranteed to exist as a postcondition (one may be + created if none exist). + + If the application's ``...DIR`` environment variable is set, it + is used as the configuration directory. Otherwise, + platform-specific standard configuration locations are searched + for a ``config.yaml`` file. If no configuration file is found, a + fallback path is used. + + Arguments: + appname (str): the subdirectory to search for in default config + locations. + env_var (str, optional): the environment variable to look for + config_fname (str): the config filename to look for. + first (bool): only return the first config file. Set to False to + return all matching config files. This will create directories + for all files returned. + + Returns: + config_file (str) if ``first == True`` else config_files (list(str)). + """ + foundcfgs = [] + + # If environment variable is set, use it. + if env_var and env_var in os.environ: + appdir = os.path.abspath(os.path.expanduser(os.environ[env_var])) + foundcfgs.append( + appdir if os.path.isfile(appdir) else + os.path.join(appdir, config_fname)) + + # Search platform-specific locations. If no config file is + # found, fall back to the first directory in the list. + cfgfiles = [os.path.join(d, appname, config_fname) for d in config_dirs()] + foundcfgs.extend( + [f for f in cfgfiles if os.path.isfile(f)] or cfgfiles[:1]) + + return foundcfgs[0] if first else foundcfgs + + +def find_package_config(modname, config_fname=DEFAULT_FILENAME): + '''Return a package default config file if it exists.''' + package_path = _package_path(modname) + if package_path: + default_config_file = os.path.join(package_path, config_fname) + if os.path.isfile(default_config_file): + return default_config_file + return None + + +def _ensure_list(x): + '''Convert to list. e.g. 1 => [1], (1, 2) => [1, 2], None => [].''' + return ( + x if isinstance(x, list) else + list(x) if isinstance(x, tuple) else + [x] if x else []) + + # YAML loading. class Loader(yaml.SafeLoader): @@ -895,7 +1117,8 @@ def restore_yaml_comments(data, default_data): # Main interface. class Configuration(RootView): - def __init__(self, appname, modname=None, read=True): + def __init__(self, appname, modname=None, source=None, read=True, + config_filename=None, default_filename=None, user=None): """Create a configuration object by reading the automatically-discovered config files for the application for a given name. If `modname` is specified, it should be the import @@ -909,45 +1132,40 @@ def __init__(self, appname, modname=None, read=True): super(Configuration, self).__init__([]) self.appname = appname self.modname = modname - - # Resolve default source location. We do this ahead of time to - # avoid unexpected problems if the working directory changes. - if self.modname: - self._package_path = _package_path(self.modname) - else: - self._package_path = None - - self._env_var = '{0}DIR'.format(self.appname.upper()) + self.config_filename = config_filename or CONFIG_FILENAME + self.default_filename = default_filename or DEFAULT_FILENAME + self._env_var = env_var = '{}DIR'.format(self.appname.upper()) + self._base_sources = [] + + # convert user-provided sources to a list of config files + for source in _ensure_list(source): + self.add(source, ignore_missing=True) + + # search the users system for config files + if user is not False or not self.sources: + self.add( + find_user_config_files( + self.appname, env_var, + config_fname=self.config_filename, + first=True), ignore_missing=True) + + # if user specified a module name, load the config + if modname: + self.add(find_package_config( + modname, self.default_filename), default=True) if read: self.read() - def user_config_path(self): - """Points to the location of the user configuration. - - The file may not exist. - """ - return os.path.join(self.config_dir(), CONFIG_FILENAME) - - def _add_user_source(self): - """Add the configuration options from the YAML file in the - user's configuration directory (given by `config_dir`) if it - exists. - """ - filename = self.user_config_path() - if os.path.isfile(filename): - self.add(ConfigSource(load_yaml(filename) or {}, filename)) - - def _add_default_source(self): - """Add the package's default configuration settings. This looks - for a YAML file located inside the package for the module - `modname` if it was given. + def config_dir(self): + """Get the path to the user configuration directory. This + looks for the first source that has a filename and uses the + file's parent directory. Returns None if none are found. """ - if self.modname: - if self._package_path: - filename = os.path.join(self._package_path, DEFAULT_FILENAME) - if os.path.isfile(filename): - self.add(ConfigSource(load_yaml(filename), filename, True)) + for source in self.sources + self._base_sources: + if source.filename: + return source.config_dir() + return None def read(self, user=True, defaults=True): """Find and read the files for this configuration and set them @@ -955,53 +1173,38 @@ def read(self, user=True, defaults=True): discovered user configuration files or the in-package defaults, set `user` or `defaults` to `False`. """ - if user: - self._add_user_source() - if defaults: - self._add_default_source() - - def config_dir(self): - """Get the path to the user configuration directory. The - directory is guaranteed to exist as a postcondition (one may be - created if none exist). - - If the application's ``...DIR`` environment variable is set, it - is used as the configuration directory. Otherwise, - platform-specific standard configuration locations are searched - for a ``config.yaml`` file. If no configuration file is found, a - fallback path is used. - """ - # If environment variable is set, use it. - if self._env_var in os.environ: - appdir = os.environ[self._env_var] - appdir = os.path.abspath(os.path.expanduser(appdir)) - if os.path.isfile(appdir): - raise ConfigError(u'{0} must be a directory'.format( - self._env_var - )) - - else: - # Search platform-specific locations. If no config file is - # found, fall back to the first directory in the list. - configdirs = config_dirs() - for confdir in configdirs: - appdir = os.path.join(confdir, self.appname) - if os.path.isfile(os.path.join(appdir, CONFIG_FILENAME)): - break - else: - appdir = os.path.join(configdirs[0], self.appname) - - # Ensure that the directory exists. - if not os.path.isdir(appdir): - os.makedirs(appdir) - return appdir - - def set_file(self, filename): - """Parses the file as YAML and inserts it into the configuration - sources with highest priority. - """ - filename = os.path.abspath(filename) - self.set(ConfigSource(load_yaml(filename), filename)) + for source in self.sources: + if user and not source.default: + source.load() + if defaults and source.default: + source.load() + + def set(self, source, **kw): + return super(Configuration, self).set( + ConfigSource.of(self._to_filename(source), **kw)) + + def add(self, source, **kw): + return super(Configuration, self).add( + ConfigSource.of(self._to_filename(source), **kw)) + + def _to_filename(self, source, default=False): + '''Convert a config directory/file to an absolute config file.''' + if isinstance(source, ConfigSource): + return source + if isinstance(source, BASESTRING): + source = os.path.abspath(os.path.expanduser(source)) + # if the source is a directory, look for a config file inside + if (os.path.isdir(source) + or os.path.splitext(source)[1] not in {'.yaml', '.yml'}): + source = os.path.join(source, ( + self.default_filename if default + else self.config_filename)) + + # ensure directory exists + cfgdir = os.path.dirname(source) + if not os.path.isdir(cfgdir): + os.makedirs(cfgdir) + return source def dump(self, full=True, redact=False): """Dump the Configuration object to a YAML file. @@ -1048,43 +1251,9 @@ class LazyConfig(Configuration): accessed. This is appropriate for using as a global config object at the module level. """ - def __init__(self, appname, modname=None): - super(LazyConfig, self).__init__(appname, modname, False) - self._materialized = False # Have we read the files yet? - self._lazy_prefix = [] # Pre-materialization calls to set(). - self._lazy_suffix = [] # Calls to add(). - - def read(self, user=True, defaults=True): - self._materialized = True - super(LazyConfig, self).read(user, defaults) - - def resolve(self): - if not self._materialized: - # Read files and unspool buffers. - self.read() - self.sources += self._lazy_suffix - self.sources[:0] = self._lazy_prefix - return super(LazyConfig, self).resolve() - - def add(self, value): - super(LazyConfig, self).add(value) - if not self._materialized: - # Buffer additions to end. - self._lazy_suffix += self.sources - del self.sources[:] - - def set(self, value): - super(LazyConfig, self).set(value) - if not self._materialized: - # Buffer additions to beginning. - self._lazy_prefix[:0] = self.sources - del self.sources[:] - - def clear(self): - """Remove all sources from this configuration.""" - super(LazyConfig, self).clear() - self._lazy_suffix = [] - self._lazy_prefix = [] + def __init__(self, appname, modname=None, *a, **kw): + super(LazyConfig, self).__init__( + appname, modname, *a, read=False, **kw) # "Validated" configuration views: experimental! diff --git a/test/test_dump.py b/test/test_dump.py index e40add7..43be3c1 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -49,8 +49,8 @@ def test_dump_ordered_dict(self): def test_dump_sans_defaults(self): config = confuse.Configuration('myapp', read=False) - config.add({'foo': 'bar'}) - config.sources[0].default = True + src = config.add({'foo': 'bar'}) + src.default = True config.add({'baz': 'qux'}) yaml = config.dump().strip() @@ -95,8 +95,8 @@ def test_dump_unredacted(self): def test_dump_redacted_sans_defaults(self): config = confuse.Configuration('myapp', read=False) - config.add({'foo': 'bar'}) - config.sources[0].default = True + src = config.add({'foo': 'bar'}) + src.default = True config.add({'baz': 'qux'}) config['baz'].redact = True diff --git a/test/test_paths.py b/test/test_paths.py index ad6d10b..e16e961 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -115,12 +115,12 @@ def tearDown(self): def test_no_sources_when_files_missing(self): config = confuse.Configuration('myapp', read=False) - filenames = [s.filename for s in config.sources] + filenames = [s.filename for s in config.sources if s.loaded] self.assertEqual(filenames, []) def test_search_package(self): config = confuse.Configuration('myapp', __name__, read=False) - config._add_default_source() + config.read(user=False, defaults=True) for source in config.sources: if source.default: @@ -141,9 +141,12 @@ class EnvVarTest(FakeSystem): def setUp(self): super(EnvVarTest, self).setUp() - self.config = confuse.Configuration('myapp', read=False) os.environ['MYAPPDIR'] = self.home # use the tmp home as a config dir + @property + def config(self): + return confuse.Configuration('myapp', read=False) + def test_env_var_name(self): self.assertEqual(self.config._env_var, 'MYAPPDIR') @@ -176,7 +179,9 @@ def setUp(self): os.path.join = self.join os.makedirs, self._makedirs = self.makedirs, os.makedirs - self.config = confuse.Configuration('test', read=False) + @property + def config(self): + return confuse.Configuration('test', read=False) def tearDown(self): super(PrimaryConfigDirTest, self).tearDown() @@ -207,3 +212,8 @@ def test_do_not_create_dir_if_lower_priority_exists(self): self.assertEqual(self.config.config_dir(), path2) self.assertFalse(os.path.isdir(path1)) self.assertTrue(os.path.isdir(path2)) + + def test_override_config_dir(self): + path = os.path.join(self.home, 'asdfasdfasdfd', 'test') + config = confuse.Configuration('test', source=path, read=False) + self.assertEqual(config.config_dir(), path) diff --git a/test/test_sources.py b/test/test_sources.py new file mode 100644 index 0000000..8413236 --- /dev/null +++ b/test/test_sources.py @@ -0,0 +1,79 @@ +from __future__ import division, absolute_import, print_function + +import os +import confuse +import sys +import unittest + +PY3 = sys.version_info[0] == 3 + + +class ConfigSourceTest(unittest.TestCase): + def _load_yaml(self, file): + return {'a': 5, 'file': file} + + def setUp(self): + self._orig_load_yaml = confuse.load_yaml + confuse.load_yaml = self._load_yaml + + def tearDown(self): + confuse.load_yaml = self._orig_load_yaml + + def test_source_conversion(self): + # test pure dict source + src = confuse.ConfigSource.of({'a': 5}) + self.assertIsInstance(src, confuse.ConfigSource) + self.assertEqual(src.loaded, True) + # test yaml filename + src = confuse.ConfigSource.of('asdf/asfdd.yml') + self.assertIsInstance(src, confuse.YamlSource) + self.assertEqual(src.loaded, False) + self.assertEqual(src.exists, False) + self.assertEqual(src.config_dir(create=False), 'asdf') + + def test_explicit_load(self): + src = confuse.ConfigSource.of('asdf.yml') + self.assertEqual(src.loaded, False) + src.load() + self.assertEqual(src.loaded, True) + self.assertEqual(src['a'], 5) + + def test_load_getitem(self): + src = confuse.ConfigSource.of('asdf.yml') + self.assertEqual(src.loaded, False) + self.assertEqual(src['a'], 5) + self.assertEqual(src.loaded, True) + + # def test_load_cast_dict(self): + # src = confuse.ConfigSource.of('asdf.yml') + # self.assertEqual(src.loaded, False) + # src.load() + # self.assertEqual(dict(src), {'a': 5, 'file': 'asdf.yml'}) + # self.assertEqual(src.loaded, True) + + def test_load_keys(self): + src = confuse.ConfigSource.of('asdf.yml') + self.assertEqual(src.loaded, False) + self.assertEqual(set(src.keys()), {'a', 'file'}) + self.assertEqual(src.loaded, True) + + +class EnvSourceTest(unittest.TestCase): + def setenv(self, *a, **kw): + for dct in a + (kw,): + os.environ.update({k: str(v) for k, v in dct.items()}) + + def test_env_var_load(self): + prefix = 'asdf' + expected = {'a': {'b': '5', 'c': '6'}, 'b': '7'} + + # setup environment + before = set(os.environ.keys()) + self.setenv({prefix + 'a__b': 5, prefix + 'a__c': 6, prefix + 'b': 7}) + self.assertGreater(len(set(os.environ.keys()) - before), 0) + + src = confuse.EnvSource(prefix) + self.assertEqual(src.loaded, False) + src.load() + self.assertEqual(src.loaded, True) + self.assertEqual(dict(src), expected)