From 8a23d38d922fe51be3981da206d46ec43a221170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 3 Jan 2026 17:47:15 +0000 Subject: [PATCH 01/12] Configure mypy with strict settings --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index 270324a..a49d0c9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,3 +30,8 @@ show_contexts = true [mypy] allow_any_generics = false +allow_subclassing_any = true +allow_untyped_calls = true +# FIXME: Would be better to actually type the libraries (if under our control), +# or write our own stubs. For now, silence errors +strict = true From 90bc84b2e3ec923b978c4ca396eaca4231cd113b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Mon, 29 Dec 2025 22:12:51 +0000 Subject: [PATCH 02/12] Type confuse.util --- confuse/util.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/confuse/util.py b/confuse/util.py index bcfb909..f8ffc34 100644 --- a/confuse/util.py +++ b/confuse/util.py @@ -1,17 +1,26 @@ +from __future__ import annotations + import argparse import importlib.util import optparse import os import platform import sys +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterable + UNIX_DIR_FALLBACK = "~/.config" WINDOWS_DIR_VAR = "APPDATA" WINDOWS_DIR_FALLBACK = "~\\AppData\\Roaming" MAC_DIR = "~/Library/Application Support" +T = TypeVar("T") + -def iter_first(sequence): +def iter_first(sequence: Iterable[T]) -> T: """Get the first element from an iterable or raise a ValueError if the iterator generates no values. """ @@ -22,7 +31,9 @@ def iter_first(sequence): raise ValueError() -def namespace_to_dict(obj): +def namespace_to_dict( + obj: argparse.Namespace | optparse.Values | T, +) -> dict[str, Any] | T: """If obj is argparse.Namespace or optparse.Values we'll return a dict representation of it, else return the original object. @@ -37,7 +48,11 @@ def namespace_to_dict(obj): return obj -def build_dict(obj, sep="", keep_none=False): +def build_dict( + values_obj: argparse.Namespace | optparse.Values | T, + sep: str = "", + keep_none: bool = False, +) -> dict[str, Any] | T: """Recursively builds a dictionary from an argparse.Namespace, optparse.Values, or dict object. @@ -60,19 +75,19 @@ def build_dict(obj, sep="", keep_none=False): """ # We expect our root object to be a dict, but it may come in as # a namespace - obj = namespace_to_dict(obj) + obj = namespace_to_dict(values_obj) # We only deal with dictionaries if not isinstance(obj, dict): return obj # Get keys iterator - keys = obj.keys() + keys: Iterable[str] = obj.keys() if sep: # Splitting keys by `sep` needs sorted keys to prevent parents # from clobbering children keys = sorted(list(keys)) - output = {} + output: dict[str, Any] = {} for key in keys: value = obj[key] if value is None and not keep_none: # Avoid unset options. @@ -108,7 +123,7 @@ def build_dict(obj, sep="", keep_none=False): # defaults. -def find_package_path(name): +def find_package_path(name: str) -> str | None: """Returns the path to the package containing the named module or None if the path could not be identified (e.g., if ``name == "__main__"``). @@ -120,21 +135,21 @@ def find_package_path(name): except (ImportError, ValueError): return None - loader = spec.loader - if loader is None or name == "__main__": + if not spec or (loader := spec.loader) is None or name == "__main__": return None + filepath: str if hasattr(loader, "get_filename"): filepath = loader.get_filename(name) else: # Fall back to importing the specified module. __import__(name) - filepath = sys.modules[name].__file__ + filepath = sys.modules[name].__file__ # type: ignore[assignment] return os.path.dirname(os.path.abspath(filepath)) -def xdg_config_dirs(): +def xdg_config_dirs() -> list[str]: """Returns a list of paths taken from the XDG_CONFIG_DIRS and XDG_CONFIG_HOME environment varibables if they exist """ @@ -149,7 +164,7 @@ def xdg_config_dirs(): return paths -def config_dirs(): +def config_dirs() -> list[str]: """Return a platform-specific list of candidates for user configuration directories on the system. From 46b430fef3e8aeb0fbfc70b49f53cc0750691fcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Tue, 30 Dec 2025 02:48:48 +0000 Subject: [PATCH 03/12] Type confuse.yaml_util --- confuse/yaml_util.py | 62 ++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/confuse/yaml_util.py b/confuse/yaml_util.py index 3796a5a..ea124cb 100644 --- a/confuse/yaml_util.py +++ b/confuse/yaml_util.py @@ -1,7 +1,15 @@ +from __future__ import annotations + from collections import OrderedDict +from typing import TYPE_CHECKING, Any import yaml +if TYPE_CHECKING: + from collections.abc import Hashable, Iterable, Iterator + + from _typeshed import SupportsItems + from .exceptions import ConfigReadError # YAML loading. @@ -17,18 +25,22 @@ class Loader(yaml.SafeLoader): """ # All strings should be Unicode objects, regardless of contents. - def _construct_unicode(self, node): + def _construct_unicode(self, node: yaml.ScalarNode) -> str: return self.construct_scalar(node) # Use ordered dictionaries for every YAML map. # From https://gist.github.com/844388 - def construct_yaml_map(self, node): - data = OrderedDict() + def construct_yaml_map( + self, node: yaml.MappingNode + ) -> Iterator[OrderedDict[object, object]]: + data: OrderedDict[object, object] = OrderedDict() yield data value = self.construct_mapping(node) data.update(value) - def construct_mapping(self, node, deep=False): + def construct_mapping( + self, node: yaml.MappingNode, deep: bool = False + ) -> dict[Hashable, Any]: if isinstance(node, yaml.MappingNode): self.flatten_mapping(node) else: @@ -39,7 +51,7 @@ def construct_mapping(self, node, deep=False): node.start_mark, ) - mapping = OrderedDict() + mapping: OrderedDict[object, object] = OrderedDict() for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) try: @@ -56,12 +68,12 @@ def construct_mapping(self, node, deep=False): return mapping # Allow bare strings to begin with %. Directives are still detected. - def check_plain(self): + def check_plain(self) -> bool: plain = super().check_plain() - return plain or self.peek() == "%" + return plain or self.peek() == "%" # type: ignore[no-any-return] @staticmethod - def add_constructors(loader): + def add_constructors(loader: type[yaml.SafeLoader]) -> None: """Modify a PyYAML Loader class to add extra constructors for strings and maps. Call this method on a custom Loader class to make it behave like Confuse's own Loader @@ -74,7 +86,7 @@ def add_constructors(loader): Loader.add_constructors(Loader) -def load_yaml(filename, loader=Loader): +def load_yaml(filename: str, loader: type[yaml.SafeLoader] = Loader) -> Any: """Read a YAML document from a file. If the file cannot be read or parsed, a ConfigReadError is raised. loader is the PyYAML Loader class to use to parse the YAML. By default, @@ -88,7 +100,9 @@ def load_yaml(filename, loader=Loader): raise ConfigReadError(filename, exc) -def load_yaml_string(yaml_string, name, loader=Loader): +def load_yaml_string( + yaml_string: str | bytes, name: str, loader: type[yaml.SafeLoader] = Loader +) -> Any: """Read a YAML document from a string. If the string cannot be parsed, a ConfigReadError is raised. `yaml_string` is a string to be parsed as a YAML document. @@ -103,7 +117,7 @@ def load_yaml_string(yaml_string, name, loader=Loader): raise ConfigReadError(name, exc) -def parse_as_scalar(value, loader=Loader): +def parse_as_scalar(value: object, loader: type[yaml.SafeLoader] = Loader) -> object: """Parse a value as if it were a YAML scalar to perform type conversion that is consistent with YAML documents. `value` should be a string. Non-string inputs or strings that raise YAML @@ -121,10 +135,10 @@ def parse_as_scalar(value, loader=Loader): if not isinstance(value, str): return value try: - loader = loader("") - tag = loader.resolve(yaml.ScalarNode, value, (True, False)) + loader_instance = loader("") + tag = loader_instance.resolve(yaml.ScalarNode, value, (True, False)) node = yaml.ScalarNode(tag, value) - return loader.construct_object(node) + return loader_instance.construct_object(node) except yaml.error.YAMLError: # Fallback to returning the value unchanged return value @@ -139,8 +153,13 @@ class Dumper(yaml.SafeDumper): """ # From http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py - def represent_mapping(self, tag, mapping, flow_style=None): - value = [] + def represent_mapping( + self, + tag: str, + mapping: SupportsItems[Any, Any] | Iterable[tuple[Any, Any]], + flow_style: bool | None = None, + ) -> yaml.MappingNode: + value: list[tuple[yaml.Node, yaml.Node]] = [] node = yaml.MappingNode(tag, value, flow_style=flow_style) if self.alias_key is not None: self.represented_objects[self.alias_key] = node @@ -162,19 +181,18 @@ def represent_mapping(self, tag, mapping, flow_style=None): node.flow_style = best_style return node - def represent_list(self, data): + def represent_list(self, data: Iterable[Any]) -> yaml.SequenceNode: """If a list has less than 4 items, represent it in inline style (i.e. comma separated, within square brackets). """ node = super().represent_list(data) - length = len(data) - if self.default_flow_style is None and length < 4: + if self.default_flow_style is None and len(list(data)) < 4: node.flow_style = True elif self.default_flow_style is None: node.flow_style = False return node - def represent_bool(self, data): + def represent_bool(self, data: bool) -> yaml.ScalarNode: """Represent bool as 'yes' or 'no' instead of 'true' or 'false'.""" if data: value = "yes" @@ -182,7 +200,7 @@ def represent_bool(self, data): value = "no" return self.represent_scalar("tag:yaml.org,2002:bool", value) - def represent_none(self, data): + def represent_none(self, data: Any) -> yaml.ScalarNode: """Represent a None value with nothing instead of 'none'.""" return self.represent_scalar("tag:yaml.org,2002:null", "") @@ -193,7 +211,7 @@ def represent_none(self, data): Dumper.add_representer(list, Dumper.represent_list) -def restore_yaml_comments(data, default_data): +def restore_yaml_comments(data: str, default_data: str) -> str: """Scan default_data for comments (we include empty lines in our definition of comments) and place them before the same keys in data. Only works with comments that are on one or more own lines, i.e. From e1a8d0ff51180989f2a8d89732fb68340b8a9070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Tue, 30 Dec 2025 06:30:58 +0000 Subject: [PATCH 04/12] Type confuse.sources --- confuse/sources.py | 75 ++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/confuse/sources.py b/confuse/sources.py index a4961e5..5439478 100644 --- a/confuse/sources.py +++ b/confuse/sources.py @@ -1,15 +1,28 @@ +from __future__ import annotations + import os +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import yaml from . import yaml_util from .util import build_dict -class ConfigSource(dict[str, object]): +class ConfigSource(dict[str, Any]): """A dictionary augmented with metadata about the source of the configuration. """ - def __init__(self, value, filename=None, default=False, base_for_paths=False): + def __init__( + self, + value: Mapping[str, object], + filename: str | None = None, + default: bool = False, + base_for_paths: bool = False, + ): """Create a configuration source from a dictionary. :param filename: The file with the data for this configuration source. @@ -32,14 +45,14 @@ def __init__(self, value, filename=None, default=False, base_for_paths=False): self.default = default self.base_for_paths = base_for_paths if filename is not None else False - def __repr__(self): + def __repr__(self) -> str: return ( f"ConfigSource({super()!r}, {self.filename!r}, {self.default!r}, " f"{self.base_for_paths!r})" ) @classmethod - def of(cls, value): + def of(cls, value: Mapping[str, object] | ConfigSource) -> ConfigSource: """Given either a dictionary or a `ConfigSource` object, return a `ConfigSource` object. This lets a function accept either type of object as an argument. @@ -57,11 +70,11 @@ class YamlSource(ConfigSource): def __init__( self, - filename=None, - default=False, - base_for_paths=False, - optional=False, - loader=yaml_util.Loader, + filename: str | None = None, + default: bool = False, + base_for_paths: bool = False, + optional: bool = False, + loader: type[yaml.SafeLoader] = yaml_util.Loader, ): """Create a YAML data source by reading data from a file. @@ -70,19 +83,29 @@ def __init__( file does not exist---instead, the source will be silently empty. """ - filename = os.path.abspath(filename) + if filename is not None: + filename = os.path.abspath(filename) super().__init__({}, filename, default, base_for_paths) self.loader = loader self.optional = optional self.load() - def load(self): + def load(self) -> None: """Load YAML data from the source's filename.""" - if self.optional and not os.path.isfile(self.filename): - value = {} + if self.optional and ( + self.filename is None or not os.path.isfile(self.filename) + ): + value: object = {} + elif self.filename is None: + raise TypeError("filename is required for YamlSource") else: value = yaml_util.load_yaml(self.filename, loader=self.loader) or {} - self.update(value) + + if isinstance(value, Mapping): + self.update(value) + else: + # We enforce that the loaded YAML is a mapping (dict) + raise TypeError(f"YAML config must be a mapping, got {type(value)}") class EnvSource(ConfigSource): @@ -90,12 +113,12 @@ class EnvSource(ConfigSource): def __init__( self, - prefix, - sep="__", - lower=True, - handle_lists=True, - parse_yaml_docs=False, - loader=yaml_util.Loader, + prefix: str, + sep: str = "__", + lower: bool = True, + handle_lists: bool = True, + parse_yaml_docs: bool = False, + loader: type[yaml.SafeLoader] = yaml_util.Loader, ): """Create a configuration source from the environment. @@ -126,10 +149,10 @@ def __init__( self.loader = loader self.load() - def load(self): + def load(self) -> None: """Load configuration data from the environment.""" # Read config variables with prefix from the environment. - config_vars = {} + config_vars: dict[str, object] = {} for var, value in os.environ.items(): if var.startswith(self.prefix): key = var[len(self.prefix) :] @@ -140,7 +163,7 @@ def load(self): # string representations of dicts and lists into the # appropriate object (ie, '{foo: bar}' to {'foo': 'bar'}). # Will raise a ConfigReadError if YAML parsing fails. - value = yaml_util.load_yaml_string( + val = yaml_util.load_yaml_string( value, f"env variable {var}", loader=self.loader ) else: @@ -148,8 +171,8 @@ def load(self): # converted using the same rules as the YAML Loader (ie, # numeric string to int/float, 'true' to True, etc.). Will # not raise a ConfigReadError. - value = yaml_util.parse_as_scalar(value, loader=self.loader) - config_vars[key] = value + val = yaml_util.parse_as_scalar(value, loader=self.loader) + config_vars[key] = val if self.sep: # Build a nested dict, keeping keys with `None` values to allow # environment variables to unset values from lower priority sources @@ -160,7 +183,7 @@ def load(self): self.update(config_vars) @classmethod - def _convert_dict_lists(cls, obj): + def _convert_dict_lists(cls, obj: object) -> object: """Recursively search for dicts where all of the keys are integers from 0 to the length of the dict, and convert them to lists. """ From 9dc651e62104fa98530aa3a6bb84da79d79e3872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 3 Jan 2026 17:41:39 +0000 Subject: [PATCH 05/12] Type confuse.exceptions --- confuse/exceptions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/confuse/exceptions.py b/confuse/exceptions.py index 1a4a87e..0bc5144 100644 --- a/confuse/exceptions.py +++ b/confuse/exceptions.py @@ -1,4 +1,4 @@ -import yaml +from yaml.scanner import ScannerError __all__ = [ "ConfigError", @@ -37,14 +37,15 @@ class ConfigTemplateError(ConfigError): class ConfigReadError(ConfigError): """A configuration source could not be read.""" - def __init__(self, name, reason=None): + def __init__(self, name: str, reason: Exception | None = None) -> None: self.name = name self.reason = reason message = f"{name} could not be read" if ( - isinstance(reason, yaml.scanner.ScannerError) + isinstance(reason, ScannerError) and reason.problem == YAML_TAB_PROBLEM + and reason.problem_mark ): # Special-case error message for tab indentation in YAML markup. message += ( From 77360a03f4e3644c9dde5acd38a47769f0ed8a41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 3 Jan 2026 17:46:12 +0000 Subject: [PATCH 06/12] Type confuse.templates --- confuse/__init__.py | 2 +- confuse/templates.py | 521 +++++++++++++++++++++++++++++-------------- test/test_valid.py | 2 +- 3 files changed, 356 insertions(+), 169 deletions(-) diff --git a/confuse/__init__.py b/confuse/__init__.py index 6c41fda..264ed60 100644 --- a/confuse/__init__.py +++ b/confuse/__init__.py @@ -4,5 +4,5 @@ from .exceptions import * # noqa: F403 from .sources import * # noqa: F403 from .templates import * # noqa: F403 -from .util import * # noqa: F403 +from .util import * # type: ignore[no-redef] # noqa: F403 from .yaml_util import * # noqa: F403 diff --git a/confuse/templates.py b/confuse/templates.py index 4c4c830..9f71c84 100644 --- a/confuse/templates.py +++ b/confuse/templates.py @@ -1,18 +1,56 @@ +from __future__ import annotations + import enum import os import pathlib import re from collections import abc +from collections.abc import Hashable, Iterable, Mapping +from functools import singledispatchmethod +from typing import TYPE_CHECKING, Any, Generic, NoReturn, overload + +from typing_extensions import TypeVar from . import exceptions, util -REQUIRED = object() +if TYPE_CHECKING: + from .core import ConfigView, Subview + +T = TypeVar("T") +K = TypeVar("K", bound=Hashable, default=str) +Kstr = TypeVar("Kstr", bound=str, default=str) +P = TypeVar("P", bound=pathlib.Path | str, default=str) +V = TypeVar("V", default=object) + + +class _Required: + """Marker class for required sentinel.""" + + pass + + +REQUIRED = _Required() """A sentinel indicating that there is no default value and an exception should be raised when the value is missing. """ -class Template: +class AttrDict(dict[Kstr, V]): + """A `dict` subclass that can be accessed via attributes (dot + notation) for convenience. + """ + + def __getattr__(self, key: str) -> V: + if key in self: + return self[key] # type: ignore[index] + else: + raise AttributeError(key) + + def __setattr__(self, key: str, value: V) -> None: + self[key] = value # type: ignore[index] + + +class Template(Generic[T]): """A value template for configuration fields. The template works like a type and instructs Confuse about how to @@ -21,7 +59,7 @@ class Template: filepath type might expand tildes and check that the file exists. """ - def __init__(self, default=REQUIRED): + def __init__(self, default: object | _Required = REQUIRED) -> None: """Create a template with a given default value. If `default` is the sentinel `REQUIRED` (as it is by default), @@ -30,13 +68,15 @@ def __init__(self, default=REQUIRED): """ self.default = default - def __call__(self, view): + def __call__(self, view: ConfigView) -> T: """Invoking a template on a view gets the view's value according to the template. """ return self.value(view, self) - def value(self, view, template=None): + def value( + self, view: ConfigView, template: Template[T] | object | None = None + ) -> T: """Get the value for a `ConfigView`. May raise a `NotFoundError` if the value is missing (and the @@ -51,7 +91,7 @@ def value(self, view, template=None): # Get default value, or raise if required. return self.get_default_value(view.name) - def get_default_value(self, key_name="default"): + def get_default_value(self, key_name: str = "default") -> T: """Get the default value to return when the value is missing. May raise a `NotFoundError` if the value is required. @@ -60,9 +100,9 @@ def get_default_value(self, key_name="default"): # The value is required. A missing value is an error. raise exceptions.NotFoundError(f"{key_name} not found") # The value is not required. - return self.default + return self.default # type: ignore[return-value] - def convert(self, value, view): + def convert(self, value: Any, view: ConfigView) -> T: """Convert the YAML-deserialized value to a value of the desired type. @@ -70,9 +110,11 @@ def convert(self, value, view): May raise a `ConfigValueError` when the configuration is wrong. """ # Default implementation does no conversion. - return value + return value # type: ignore[no-any-return] - def fail(self, message, view, type_error=False): + def fail( + self, message: str, view: ConfigView, type_error: bool = False + ) -> NoReturn: """Raise an exception indicating that a value cannot be accepted. @@ -85,17 +127,17 @@ def fail(self, message, view, type_error=False): ) raise exc_class(f"{view.name}: {message}") - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( type(self).__name__, "" if self.default is REQUIRED else repr(self.default), ) -class Integer(Template): +class Integer(Template[int]): """An integer configuration value template.""" - def convert(self, value, view): + def convert(self, value: int | float, view: ConfigView) -> int: """Check that the value is an integer. Floats are rounded.""" if isinstance(value, int): return value @@ -105,10 +147,13 @@ def convert(self, value, view): self.fail("must be a number", view, True) -class Number(Template): +Numeric = TypeVar("Numeric", int, float) + + +class Number(Template[Numeric]): """A numeric type: either an integer or a floating-point number.""" - def convert(self, value, view): + def convert(self, value: Numeric, view: ConfigView) -> Numeric: """Check that the value is an int or a float.""" if isinstance(value, (int, float)): return value @@ -116,57 +161,65 @@ def convert(self, value, view): self.fail(f"must be numeric, not {type(value).__name__}", view, True) -class MappingTemplate(Template): +class MappingTemplate(Template[AttrDict[Kstr, V]]): """A template that uses a dictionary to specify other types for the values for a set of keys and produce a validated `AttrDict`. """ - def __init__(self, mapping): + def __init__(self, mapping: Mapping[Kstr, Template[V] | type[V]]) -> None: """Create a template according to a dict (mapping). The mapping's values should themselves either be Types or convertible to Types. """ - subtemplates = {} + subtemplates: dict[Kstr, Template[V]] = {} for key, typ in mapping.items(): subtemplates[key] = as_template(typ) self.subtemplates = subtemplates - def value(self, view, template=None): + def value( + self, + view: ConfigView, + template: Template[AttrDict[Kstr, V]] | object | None = None, + ) -> AttrDict[Kstr, V]: """Get a dict with the same keys as the template and values validated according to the value types. """ - out = AttrDict() - for key, typ in self.subtemplates.items(): - out[key] = typ.value(view[key], self) - return out + return AttrDict( + {k: v.value(view[k], self) for k, v in self.subtemplates.items()} + ) - def __repr__(self): + def __repr__(self) -> str: return f"MappingTemplate({self.subtemplates!r})" -class Sequence(Template): +class Sequence(Template[list[T]]): """A template used to validate lists of similar items, based on a given subtemplate. """ - def __init__(self, subtemplate): + subtemplate: Template[T] + + def __init__(self, subtemplate: Template[T] | object): """Create a template for a list with items validated on a given subtemplate. """ + super().__init__() self.subtemplate = as_template(subtemplate) - def value(self, view, template=None): + def value( + self, view: ConfigView, template: Template[list[T]] | object | None = None + ) -> list[T]: """Get a list of items validated against the template.""" out = [] for item in view.sequence(): out.append(self.subtemplate.value(item, self)) return out - def __repr__(self): + def __repr__(self) -> str: return f"Sequence({self.subtemplate!r})" -class MappingValues(Template): +class MappingValues(Template[dict[str, T]]): """A template used to validate mappings of similar items, based on a given subtemplate applied to the values. @@ -175,13 +228,18 @@ class MappingValues(Template): Sequence template but for mappings. """ - def __init__(self, subtemplate): + subtemplate: Template[T] + + def __init__(self, subtemplate: Template[T] | object): """Create a template for a mapping with variable keys and item values validated on a given subtemplate. """ + super().__init__() self.subtemplate = as_template(subtemplate) - def value(self, view, template=None): + def value( + self, view: ConfigView, template: Template[dict[str, T]] | object | None = None + ) -> dict[str, T]: """Get a dict with the same keys as the view and the value of each item validated against the subtemplate. """ @@ -190,14 +248,19 @@ def value(self, view, template=None): out[key] = self.subtemplate.value(item, self) return out - def __repr__(self): + def __repr__(self) -> str: return f"MappingValues({self.subtemplate!r})" -class String(Template): +class String(Template[str]): """A string configuration value template.""" - def __init__(self, default=REQUIRED, pattern=None, expand_vars=False): + def __init__( + self, + default: str | _Required = REQUIRED, + pattern: str | None = None, + expand_vars: bool = False, + ): """Create a template with the added optional `pattern` argument, a regular expression string that the value should match. """ @@ -207,7 +270,7 @@ def __init__(self, default=REQUIRED, pattern=None, expand_vars=False): if pattern: self.regex = re.compile(pattern) - def __repr__(self): + def __repr__(self) -> str: args = [] if self.default is not REQUIRED: @@ -218,7 +281,7 @@ def __repr__(self): return f"String({', '.join(args)})" - def convert(self, value, view): + def convert(self, value: object, view: ConfigView) -> str: """Check that the value is a string and matches the pattern.""" if not isinstance(value, str): self.fail("must be a string", view, True) @@ -232,14 +295,20 @@ def convert(self, value, view): return value -class Choice(Template): +class Choice(Template[T], Generic[T, K]): """A template that permits values from a sequence of choices. Sequences, dictionaries and :class:`Enum` types are supported, see :meth:`__init__` for usage. """ - def __init__(self, choices, default=REQUIRED): + choices: abc.Sequence[T] | dict[K, T] | type[T] + + def __init__( + self, + choices: abc.Sequence[T] | dict[K, T] | type[T], + default: T | _Required = REQUIRED, + ) -> None: """Create a template that validates any of the values from the iterable `choices`. @@ -252,43 +321,67 @@ def __init__(self, choices, default=REQUIRED): super().__init__(default) self.choices = choices - def convert(self, value, view): + @singledispatchmethod + def convert_choices( + self, choices: abc.Sequence[T] | dict[K, T] | type[T], value: str + ) -> T: + raise NotImplementedError + + @convert_choices.register(type) + def _(self, choices: type[T], value: str) -> T: + return choices(value) # type: ignore[call-arg] + + @convert_choices.register(dict) + def _(self, choices: dict[K, T], value: K) -> T: + return choices[value] + + @convert_choices.register(abc.Sequence) + def _(self, choices: abc.Sequence[T], value: T) -> T: + return choices[choices.index(value)] + + @singledispatchmethod + def format_choices(self, choices: abc.Sequence[T] | enum.Enum) -> list[str]: + raise NotImplementedError + + @format_choices.register(type) + def _(self, choices: type[enum.Enum]) -> list[str]: + return [c.value for c in choices] + + @format_choices.register(abc.Sequence) + @format_choices.register(Mapping) + def _(self, choices: Iterable[T]) -> list[str]: + return list(map(str, choices)) + + def convert(self, value: object, view: ConfigView) -> T: """Ensure that the value is among the choices (and remap if the choices are a mapping). """ - if isinstance(self.choices, type) and issubclass(self.choices, enum.Enum): - try: - return self.choices(value) - except ValueError: - self.fail( - f"must be one of {[c.value for c in self.choices]!r}, not " - f"{value!r}", - view, - ) - - if value not in self.choices: + try: + return self.convert_choices(self.choices, value) + except (KeyError, ValueError): self.fail( - f"must be one of {list(self.choices)!r}, not {value!r}", + f"must be one of {self.format_choices(self.choices)!r}, not {value!r}", view, ) - if isinstance(self.choices, abc.Mapping): - return self.choices[value] - else: - return value - - def __repr__(self): + def __repr__(self) -> str: return f"Choice({self.choices!r})" -class OneOf(Template): +class OneOf(Template[T]): """A template that permits values complying to one of the given templates.""" - def __init__(self, allowed, default=REQUIRED): + allowed: list[Template[T]] + template: Template[Any] | None + + def __init__( + self, allowed: Iterable[Template[T] | object], default: T | _Required = REQUIRED + ): super().__init__(default) - self.allowed = list(allowed) + self.allowed = [as_template(t) for t in allowed] + self.template = None - def __repr__(self): + def __repr__(self) -> str: args = [] if self.allowed is not None: @@ -299,19 +392,28 @@ def __repr__(self): return f"OneOf({', '.join(args)})" - def value(self, view, template): - self.template = template + def value( + self, view: ConfigView, template: Template[T] | object | None = None + ) -> T: + self.template = template if isinstance(template, Template) else None return super().value(view, template) - def convert(self, value, view): + def convert(self, value: object, view: Subview) -> T: # type: ignore[override] """Ensure that the value follows at least one template.""" is_mapping = isinstance(self.template, MappingTemplate) for candidate in self.allowed: try: if is_mapping: - next_template = MappingTemplate({view.key: candidate}) - return view.parent.get(next_template)[view.key] + assert self.template is not None + from .core import Subview + + if isinstance(view, Subview): + # Use a new MappingTemplate to check the sibling value + next_template = MappingTemplate({view.key: candidate}) + return view.parent.get(next_template)[view.key] + else: + self.fail("MappingTemplate must be used with a Subview", view) else: return view.get(candidate) except exceptions.ConfigTemplateError: @@ -324,14 +426,23 @@ def convert(self, value, view): self.fail(f"must be one of {self.allowed!r}, not {value!r}", view) -class StrSeq(Template): +class BytesToStrMixin: + def normalize_bytes(self, x: str | bytes) -> str: + if isinstance(x, bytes): + return x.decode("utf-8", "ignore") + return x + + +class StrSeq(BytesToStrMixin, Template[list[str]]): """A template for values that are lists of strings. Validates both actual YAML string lists and single strings. Strings can optionally be split on whitespace. """ - def __init__(self, split=True, default=REQUIRED): + def __init__( + self, split: bool = True, default: list[str] | _Required = REQUIRED + ) -> None: """Create a new template. `split` indicates whether, when the underlying value is a single @@ -341,32 +452,32 @@ def __init__(self, split=True, default=REQUIRED): super().__init__(default) self.split = split - def _convert_value(self, x, view): - if isinstance(x, str): - return x - elif isinstance(x, bytes): - return x.decode("utf-8", "ignore") - else: + def _convert_value(self, x: object, view: ConfigView) -> str: + if not isinstance(x, (str, bytes)): self.fail("must be a list of strings", view, True) - def convert(self, value, view): + return self.normalize_bytes(x) + + def convert( + self, value: str | bytes | list[str | bytes], view: ConfigView + ) -> list[str]: if isinstance(value, bytes): value = value.decode("utf-8", "ignore") if isinstance(value, str): if self.split: - value = value.split() + values: Iterable[object] = value.split() else: - value = [value] + values = [value] + elif isinstance(value, Iterable): + values = value else: - try: - value = list(value) - except TypeError: - self.fail("must be a whitespace-separated string or a list", view, True) - return [self._convert_value(v, view) for v in value] + self.fail("must be a whitespace-separated string or a list", view, True) + return [self._convert_value(v, view) for v in values] -class Pairs(StrSeq): + +class Pairs(BytesToStrMixin, Template[list[tuple[str, object]]]): """A template for ordered key-value pairs. This can either be given with the same syntax as for `StrSeq` (i.e. without @@ -380,35 +491,41 @@ class Pairs(StrSeq): `default_value` will be returned as the second element. """ - def __init__(self, default_value=None): + def __init__(self, default_value: object | None = None) -> None: """Create a new template. `default` is the dictionary value returned for items that are not a mapping, but a single string. """ - super().__init__(split=True) + super().__init__() self.default_value = default_value - def _convert_value(self, x, view): - try: - return (super()._convert_value(x, view), self.default_value) - except exceptions.ConfigTypeError: - if isinstance(x, abc.Mapping): - if len(x) != 1: - self.fail("must be a single-element mapping", view, True) - k, v = util.iter_first(x.items()) - elif isinstance(x, abc.Sequence): - if len(x) != 2: - self.fail("must be a two-element list", view, True) - k, v = x - else: - # Is this even possible? -> Likely, if some !directive cause - # YAML to parse this to some custom type. - self.fail(f"must be a single string, mapping, or a list{x}", view, True) - return (super()._convert_value(k, view), super()._convert_value(v, view)) + def _convert_value(self, x: object, view: ConfigView) -> tuple[str, object]: + if isinstance(x, (str, bytes)): + return self.normalize_bytes(x), self.default_value + + if isinstance(x, Mapping): + if len(x) != 1: + self.fail("must be a single-element mapping", view, True) + k, v = util.iter_first(x.items()) + elif isinstance(x, abc.Sequence): + if len(x) != 2: + self.fail("must be a two-element list", view, True) + k, v = x + else: + # Is this even possible? -> Likely, if some !directive cause + # YAML to parse this to some custom type. + self.fail(f"must be a single string, mapping, or a list{x}", view, True) + + return self.normalize_bytes(k), self.normalize_bytes(v) + + def convert( + self, value: list[abc.Sequence[str] | Mapping[str, str]], view: ConfigView + ) -> list[tuple[str, object]]: + return [self._convert_value(v, view) for v in value] -class Filename(Template): +class Filename(Template[P]): """A template that validates strings as filenames. Filenames are returned as absolute, tilde-free paths. @@ -425,12 +542,12 @@ class Filename(Template): def __init__( self, - default=REQUIRED, - cwd=None, - relative_to=None, - in_app_dir=False, - in_source_dir=False, - ): + default: T | _Required = REQUIRED, + cwd: str | None = None, + relative_to: str | None = None, + in_app_dir: bool = False, + in_source_dir: bool = False, + ) -> None: """`relative_to` is the name of a sibling value that is being validated at the same time. @@ -448,7 +565,7 @@ def __init__( self.in_app_dir = in_app_dir self.in_source_dir = in_source_dir - def __repr__(self): + def __repr__(self) -> str: args = [] if self.default is not REQUIRED: @@ -468,8 +585,10 @@ def __repr__(self): return f"Filename({', '.join(args)})" - def resolve_relative_to(self, view, template): - if not isinstance(template, (abc.Mapping, MappingTemplate)): + def resolve_relative_to( + self, view: Subview, template: MappingTemplate | Mapping[str, Any] | None + ) -> str: + if not isinstance(template, (Mapping, MappingTemplate)): # disallow config.get(Filename(relative_to='foo')) raise exceptions.ConfigTemplateError( "relative_to may only be used when getting multiple values." @@ -485,12 +604,18 @@ def resolve_relative_to(self, view, template): view, ) - old_template = {} - old_template.update(template.subtemplates) + # Use a safe way to access subtemplates + if isinstance(template, MappingTemplate): + subtemplates = template.subtemplates + else: + # template is a Mapping + subtemplates = {k: as_template(v) for k, v in template.items()} + + old_template = dict(subtemplates) # save time by skipping MappingTemplate's init loop next_template = MappingTemplate({}) - next_relative = self.relative_to + next_relative: str | None = self.relative_to # gather all the needed templates and nothing else while next_relative is not None: @@ -499,7 +624,7 @@ def resolve_relative_to(self, view, template): # relative paths rel_to_template = old_template.pop(next_relative) except KeyError: - if next_relative in template.subtemplates: + if next_relative in subtemplates: # we encountered this config key previously raise exceptions.ConfigTemplateError( f"{view.name} and {self.relative_to} are recursively relative" @@ -511,61 +636,70 @@ def resolve_relative_to(self, view, template): ) next_template.subtemplates[next_relative] = rel_to_template - next_relative = rel_to_template.relative_to + next_relative_val = getattr(rel_to_template, "relative_to", None) + next_relative = ( + next_relative_val if isinstance(next_relative_val, str) else None + ) - return view.parent.get(next_template)[self.relative_to] + return view.parent.get(next_template)[self.relative_to] # type: ignore[return-value] - def value(self, view, template=None): + def value( + self, view: ConfigView, template: Template[T] | object | None = None + ) -> P: try: path, source = view.first() except exceptions.NotFoundError: return self.get_default_value(view.name) - if not isinstance(path, str): + if not isinstance(path, (str, bytes)): self.fail(f"must be a filename, not {type(path).__name__}", view, True) - path = os.path.expanduser(str(path)) - if not os.path.isabs(path): + if isinstance(path, bytes): + path_str = path.decode("utf-8", "ignore") + else: + path_str = path + + path_str = os.path.expanduser(path_str) + + if not os.path.isabs(path_str): if self.cwd is not None: # relative to the template's argument - path = os.path.join(self.cwd, path) + path_str = os.path.join(self.cwd, path_str) elif self.relative_to is not None: - path = os.path.join( + path_str = os.path.join( self.resolve_relative_to(view, template), - path, + path_str, ) elif (source.filename and self.in_source_dir) or ( source.base_for_paths and not self.in_app_dir ): # relative to the directory the source file is in. - path = os.path.join(os.path.dirname(source.filename), path) + path_str = os.path.join(os.path.dirname(source.filename), path_str) elif source.filename or self.in_app_dir: # From defaults: relative to the app's directory. - path = os.path.join(view.root().config_dir(), path) + path_str = os.path.join(view.root().config_dir(), path_str) - return os.path.abspath(path) + return os.path.abspath(path_str) -class Path(Filename): +class Path(Filename[pathlib.Path]): """A template that validates strings as `pathlib.Path` objects. Filenames are parsed equivalent to the `Filename` template and then converted to `pathlib.Path` objects. """ - def value(self, view, template=None): - value = super().value(view, template) - if value is None: - return - import pathlib - - return pathlib.Path(value) + def value( + self, view: ConfigView, template: Template[pathlib.Path] | object | None = None + ) -> pathlib.Path: + val = super().value(view, template) + return pathlib.Path(val) if val is not None else None -class Optional(Template): +class Optional(Template[T | None]): """A template that makes a subtemplate optional. If the value is present and not null, it must validate against the @@ -574,7 +708,38 @@ class Optional(Template): the template will not allow missing values while still permitting null. """ - def __init__(self, subtemplate, default=None, allow_missing=True): + subtemplate: Template[T] + + @overload + def __init__( + self, + subtemplate: type[T], + default: T | None = None, + allow_missing: bool = True, + ) -> None: ... + + @overload + def __init__( + self, + subtemplate: Template[T], + default: T | None = None, + allow_missing: bool = True, + ) -> None: ... + + @overload + def __init__( + self, + subtemplate: T, + default: T | None = None, + allow_missing: bool = True, + ) -> None: ... + + def __init__( + self, + subtemplate: Template[T] | type[T] | T, + default: T | None = None, + allow_missing: bool = True, + ): self.subtemplate = as_template(subtemplate) if default is None: # When no default is passed, try to use the subtemplate's @@ -586,13 +751,15 @@ def __init__(self, subtemplate, default=None, allow_missing=True): self.default = default self.allow_missing = allow_missing - def value(self, view, template=None): + def value( + self, view: ConfigView, template: Template[T | None] | object | None = None + ) -> T | None: try: value, _ = view.first() except exceptions.NotFoundError: if self.allow_missing: # Value is missing but not required - return self.default + return self.default # type: ignore[return-value] # Value must be present even though it can be null. Raise an error. raise exceptions.NotFoundError(f"{view.name} not found") @@ -601,51 +768,71 @@ def value(self, view, template=None): return self.default return self.subtemplate.value(view, self) - def __repr__(self): + def __repr__(self) -> str: return ( f"Optional({self.subtemplate!r}, {self.default!r}, " f"allow_missing={self.allow_missing})" ) -class TypeTemplate(Template): +class TypeTemplate(Template[T]): """A simple template that checks that a value is an instance of a desired Python type. """ - def __init__(self, typ, default=REQUIRED): + def __init__(self, typ: type[T], default: object = REQUIRED): """Create a template that checks that the value is an instance of `typ`. """ super().__init__(default) self.typ = typ - def convert(self, value, view): - if not isinstance(value, self.typ): - self.fail( - f"must be a {self.typ.__name__}, not {type(value).__name__}", - view, - True, - ) - return value - - -class AttrDict(dict[str, object]): - """A `dict` subclass that can be accessed via attributes (dot - notation) for convenience. - """ - - def __getattr__(self, key): - if key in self: - return self[key] - else: - raise AttributeError(key) + def convert(self, value: Any, view: ConfigView) -> T: + if isinstance(value, self.typ): + return value - def __setattr__(self, key, value): - self[key] = value + self.fail( + f"must be a {self.typ.__name__}, not {type(value).__name__}", + view, + True, + ) -def as_template(value): +@overload +def as_template(value: Template[T]) -> Template[T]: ... +@overload +def as_template(value: Mapping[str, object]) -> MappingTemplate: ... +@overload +def as_template(value: type[int]) -> Integer: ... +@overload +def as_template(value: int) -> Integer: ... +@overload +def as_template(value: type[str]) -> String: ... +@overload +def as_template(value: str) -> String: ... +@overload +def as_template(value: type[enum.Enum]) -> Choice[enum.Enum]: ... +@overload +def as_template(value: set[T]) -> Choice[T]: ... +@overload +def as_template(value: list[T]) -> OneOf[T]: ... +@overload +def as_template(value: type[float]) -> Number[float]: ... +@overload +def as_template(value: float) -> Number[float]: ... +@overload +def as_template(value: pathlib.PurePath) -> Path: ... +@overload +def as_template(value: None) -> Template[None]: ... +@overload +def as_template(value: type[dict[str, T]]) -> TypeTemplate[abc.Mapping[str, T]]: ... +@overload +def as_template(value: type[list[T]]) -> TypeTemplate[abc.Sequence[T]]: ... +@overload +def as_template(value: object) -> Template[Any]: ... + + +def as_template(value: Any) -> Template[Any]: """Convert a simple "shorthand" Python value to a `Template`.""" if isinstance(value, Template): # If it's already a Template, pass it through. @@ -672,7 +859,7 @@ def as_template(value): return Number() elif isinstance(value, float): return Number(value) - elif isinstance(value, pathlib.PurePath): + elif isinstance(value, pathlib.Path): return Path(value) elif value is None: return Template(None) diff --git a/test/test_valid.py b/test/test_valid.py index fad61a6..14c33f8 100644 --- a/test/test_valid.py +++ b/test/test_valid.py @@ -314,7 +314,7 @@ class BadTemplate: pass config = _root({}) - with pytest.raises(confuse.ConfigTemplateError): + with pytest.raises(ValueError, match="cannot convert to template"): config.get(confuse.OneOf([BadTemplate()])) del BadTemplate From ed0b0f108e24bf79200479d3e7f001dee8cb453b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 3 Jan 2026 17:46:28 +0000 Subject: [PATCH 07/12] Finish typing confuse.core --- confuse/core.py | 178 ++++++++++++++++++++++++++++++------------------ 1 file changed, 110 insertions(+), 68 deletions(-) diff --git a/confuse/core.py b/confuse/core.py index eef04be..ca48908 100644 --- a/confuse/core.py +++ b/confuse/core.py @@ -31,16 +31,21 @@ import errno import os from collections import OrderedDict -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, overload import yaml +from typing_extensions import Self from . import templates, util, yaml_util from .exceptions import ConfigError, ConfigTypeError, NotFoundError from .sources import ConfigSource, EnvSource, YamlSource if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + import builtins + import pathlib + from argparse import Namespace + from collections.abc import Iterable, Iterator, Mapping, Sequence + from optparse import Values from pathlib import Path CONFIG_FILENAME = "config.yaml" @@ -63,12 +68,12 @@ class ConfigView: ``view[key]``). """ - name = None + name: str """The name of the view, depicting the path taken through the configuration in Python-like syntax (e.g., ``foo['bar'][42]``). """ - def resolve(self): + def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: """The core (internal) data retrieval method. Generates (value, source) pairs for each source that contains a value for this view. May raise `ConfigTypeError` if a type error occurs while @@ -76,7 +81,7 @@ def resolve(self): """ raise NotImplementedError - def first(self): + def first(self) -> tuple[dict[str, Any], ConfigSource]: """Return a (value, source) pair for the first object found for this view. This amounts to the first element returned by `resolve`. If no values are available, a `NotFoundError` is @@ -88,7 +93,7 @@ def first(self): except ValueError: raise NotFoundError(f"{self.name} not found") - def exists(self): + def exists(self) -> bool: """Determine whether the view has a setting in any source.""" try: self.first() @@ -96,28 +101,28 @@ def exists(self): return False return True - def add(self, value): + def add(self, value: dict[str, Any]) -> None: """Set the *default* value for this configuration view. The specified value is added as the lowest-priority configuration data source. """ raise NotImplementedError - def set(self, value): + def set(self, value: dict[str, Any]) -> None: """*Override* the value for this configuration view. The specified value is added as the highest-priority configuration data source. """ raise NotImplementedError - def root(self): + def root(self) -> RootView: """The RootView object from which this view is descended.""" raise NotImplementedError - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self.name}>" - def __iter__(self): + def __iter__(self) -> Iterator[Subview | str]: """Iterate over the keys of a dictionary view or the *subviews* of a list view. """ @@ -137,20 +142,22 @@ def __iter__(self): f"{type(item).__name__}" ) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Subview: """Get a subview of this view.""" return Subview(self, key) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: """Create an overlay source to assign a given key under this view. """ self.set({key: value}) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return self[key].exists() - def set_args(self, namespace, dots=False): + def set_args( + self, namespace: dict[str, Any] | Namespace | Values, dots: bool = False + ) -> None: """Overlay parsed command-line arguments, generated by a library like argparse or optparse, onto this view's value. @@ -173,17 +180,17 @@ def set_args(self, namespace, dots=False): # example, rather than using ``view.get(bool)``, it's possible to # just say ``bool(view)`` or use ``view`` in a conditional. - def __str__(self): + def __str__(self) -> str: """Get the value for this view as a bytestring.""" return str(self.get()) - def __bool__(self): + def __bool__(self) -> bool: """Gets the value for this view as a bool.""" return bool(self.get()) # Dictionary emulation methods. - def keys(self): + def keys(self) -> list[str]: """Returns a list containing all the keys available as subviews of the current views. This enumerates all the keys in *all* dictionaries matching the current view, in contrast to @@ -208,7 +215,7 @@ def keys(self): return keys - def items(self): + def items(self) -> Iterator[tuple[str, Subview]]: """Iterates over (key, subview) pairs contained in dictionaries from *all* sources at this view. If the object for this view in any source is not a dict, then a `ConfigTypeError` is raised. @@ -216,7 +223,7 @@ def items(self): for key in self.keys(): yield key, self[key] - def values(self): + def values(self) -> Iterator[Subview]: """Iterates over all the subviews contained in dictionaries from *all* sources at this view. If the object for this view in any source is not a dict, then a `ConfigTypeError` is raised. @@ -226,7 +233,7 @@ def values(self): # List/sequence emulation. - def sequence(self): + def sequence(self) -> Iterator[Subview]: """Iterates over the subviews contained in lists from the *first* source at this view. If the object for this view in the first source is not a list or tuple, then a `ConfigTypeError` is raised. @@ -244,7 +251,7 @@ def sequence(self): for index in range(len(collection)): yield self[index] - def all_contents(self): + def all_contents(self) -> Iterator[str]: """Iterates over all subviews from collections at this view from *all* sources. If the object for this view in any source is not iterable, then a `ConfigTypeError` is raised. This method is @@ -262,7 +269,7 @@ def all_contents(self): # Validation and conversion. - def flatten(self, redact=False): + def flatten(self, redact: bool = False) -> OrderedDict[str, Any]: """Create a hierarchy of OrderedDicts containing the data from this view, recursively reifying all views to get their represented values. @@ -270,7 +277,7 @@ def flatten(self, redact=False): If `redact` is set, then sensitive values are replaced with the string "REDACTED". """ - od = OrderedDict() + od: OrderedDict[str, Any] = OrderedDict() for key, view in self.items(): if redact and view.redact: od[key] = REDACTED_TOMBSTONE @@ -281,7 +288,30 @@ def flatten(self, redact=False): od[key] = view.get() return od - def get(self, template=templates.REQUIRED) -> Any: + @overload + def get(self, template: templates.Pairs) -> list[tuple[str, str]]: ... + @overload + def get(self, template: templates.MappingTemplate) -> templates.AttrDict: ... + @overload + def get(self, template: templates.Optional[R]) -> R | None: ... + @overload + def get(self, template: type[R]) -> R: ... + @overload + def get(self, template: Mapping[str, Any]) -> templates.AttrDict: ... + # Overload for list (OneOf) + @overload + def get(self, template: list[R]) -> R: ... + @overload + def get(self, template: pathlib.PurePath) -> pathlib.Path: ... + @overload + def get(self, template: None) -> None: ... + @overload + def get(self, template: templates.Template[R]) -> R: ... + # Overload for REQUIRED sentinel + @overload + def get(self, template: templates._Required = ...) -> Any: ... + + def get(self, template: object = templates.REQUIRED) -> Any: """Retrieve the value for this view according to the template. The `template` against which the values are checked can be @@ -306,7 +336,7 @@ def as_path(self) -> Path: """Get the value as a `pathlib.Path` object. Equivalent to `get(Path())`.""" return self.get(templates.Path()) - def as_choice(self, choices: Iterable[R]) -> R: + def as_choice(self, choices: Sequence[R] | dict[str, R] | type[R]) -> R: """Get the value from a list of choices. Equivalent to `get(Choice(choices))`. @@ -315,19 +345,23 @@ def as_choice(self, choices: Iterable[R]) -> R: """ return self.get(templates.Choice(choices)) + def as_int(self) -> int: + """Get the value as an integer.""" + return self.get(templates.Integer()) + def as_number(self) -> int | float: """Get the value as any number type: int or float. Equivalent to `get(Number())`. """ return self.get(templates.Number()) - def as_str_seq(self, split=True) -> Sequence[str]: + def as_str_seq(self, split: bool = True) -> list[str]: """Get the value as a sequence of strings. Equivalent to `get(StrSeq(split=split))`. """ return self.get(templates.StrSeq(split=split)) - def as_pairs(self, default_value=None) -> Sequence[tuple[str, str]]: + def as_pairs(self, default_value: str | None = None) -> list[tuple[str, str]]: """Get the value as a sequence of pairs of two strings. Equivalent to `get(Pairs(default_value=default_value))`. """ @@ -348,23 +382,23 @@ def as_str_expanded(self) -> str: # Redaction. @property - def redact(self): + def redact(self) -> bool: """Whether the view contains sensitive information and should be redacted from output. """ return () in self.get_redactions() @redact.setter - def redact(self, flag): + def redact(self, flag: bool) -> None: self.set_redaction((), flag) - def set_redaction(self, path, flag): + def set_redaction(self, path: tuple[str, ...], flag: bool) -> None: """Add or remove a redaction for a key path, which should be an iterable of keys. """ raise NotImplementedError() - def get_redactions(self): + def get_redactions(self) -> Iterable[tuple[str, ...]]: """Get the set of currently-redacted sub-key-paths at this view.""" raise NotImplementedError() @@ -374,48 +408,48 @@ class RootView(ConfigView): sources that may be accessed by subviews. """ - def __init__(self, sources): + def __init__(self, sources: Iterable[ConfigSource]) -> None: """Create a configuration hierarchy for a list of sources. At least one source must be provided. The first source in the list has the highest priority. """ - self.sources = list(sources) + self.sources: list[ConfigSource] = list(sources) self.name = ROOT_NAME - self.redactions = set() + self.redactions: set[tuple[str, ...]] = set() - def add(self, value): - self.sources.append(ConfigSource.of(value=value)) + def add(self, value: dict[str, Any]) -> None: + self.sources.append(ConfigSource.of(value)) - def set(self, value): + def set(self, value: dict[str, Any]) -> None: self.sources.insert(0, ConfigSource.of(value)) - def resolve(self): + def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: return ((dict(s), s) for s in self.sources) - def clear(self): + def clear(self) -> None: """Remove all sources (and redactions) from this configuration. """ del self.sources[:] self.redactions.clear() - def root(self): + def root(self) -> Self: return self - def set_redaction(self, path, flag): + def set_redaction(self, path: tuple[str, ...], flag: bool) -> None: if flag: self.redactions.add(path) elif path in self.redactions: self.redactions.remove(path) - def get_redactions(self): + def get_redactions(self) -> builtins.set[tuple[str, ...]]: return self.redactions class Subview(ConfigView): """A subview accessed via a subscript of a parent view.""" - def __init__(self, parent, key): + def __init__(self, parent: ConfigView, key: str) -> None: """Make a subview of a parent view for a given subscript key.""" self.parent = parent self.key = key @@ -436,7 +470,7 @@ def __init__(self, parent, key): else: self.name += repr(self.key) - def resolve(self): + def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: for collection, source in self.parent.resolve(): try: value = collection[self.key] @@ -454,19 +488,19 @@ def resolve(self): ) yield value, source - def set(self, value): + def set(self, value: dict[str, Any]) -> None: self.parent.set({self.key: value}) - def add(self, value): + def add(self, value: dict[str, Any]) -> None: self.parent.add({self.key: value}) - def root(self): + def root(self) -> RootView: return self.parent.root() - def set_redaction(self, path, flag): + def set_redaction(self, path: tuple[str, ...], flag: bool) -> None: self.parent.set_redaction((self.key, *path), flag) - def get_redactions(self): + def get_redactions(self) -> Iterable[tuple[str, ...]]: return ( kp[1:] for kp in self.parent.get_redactions() if kp and kp[0] == self.key ) @@ -476,7 +510,13 @@ def get_redactions(self): class Configuration(RootView): - def __init__(self, appname, modname=None, read=True, loader=yaml_util.Loader): + def __init__( + self, + appname: str, + modname: str | None = None, + read: bool = True, + loader: type[yaml_util.Loader] = yaml_util.Loader, + ): """Create a configuration object by reading the automatically-discovered config files for the application for a given name. If `modname` is specified, it should be the import @@ -504,14 +544,14 @@ def __init__(self, appname, modname=None, read=True, loader=yaml_util.Loader): if read: self.read() - def user_config_path(self): + def user_config_path(self) -> str: """Points to the location of the user configuration. The file may not exist. """ return os.path.join(self.config_dir(), CONFIG_FILENAME) - def _add_user_source(self): + def _add_user_source(self) -> None: """Add the configuration options from the YAML file in the user's configuration directory (given by `config_dir`) if it exists. @@ -519,7 +559,7 @@ def _add_user_source(self): filename = self.user_config_path() self.add(YamlSource(filename, loader=self.loader, optional=True)) - def _add_default_source(self): + def _add_default_source(self) -> None: """Add the package's default configuration settings. This looks for a YAML file located inside the package for the module `modname` if it was given. @@ -533,7 +573,7 @@ def _add_default_source(self): ) ) - def read(self, user=True, defaults=True): + def read(self, user: bool = True, defaults: bool = True) -> None: """Find and read the files for this configuration and set them as the sources for this configuration. To disable either discovered user configuration files or the in-package defaults, @@ -544,7 +584,7 @@ def read(self, user=True, defaults=True): if defaults: self._add_default_source() - def config_dir(self): + def config_dir(self) -> str: """Get the path to the user configuration directory. The directory is guaranteed to exist as a postcondition (one may be created if none exist). @@ -582,7 +622,7 @@ def config_dir(self): return appdir - def set_file(self, filename, base_for_paths=False): + def set_file(self, filename: str, base_for_paths: bool = False) -> None: """Parses the file as YAML and inserts it into the configuration sources with highest priority. @@ -596,7 +636,7 @@ def set_file(self, filename, base_for_paths=False): YamlSource(filename, base_for_paths=base_for_paths, loader=self.loader) ) - def set_env(self, prefix=None, sep="__"): + def set_env(self, prefix: str | None = None, sep: str = "__") -> None: """Create a configuration overlay at the highest priority from environment variables. @@ -618,7 +658,7 @@ def set_env(self, prefix=None, sep="__"): prefix = f"{self.appname.upper()}_" self.set(EnvSource(prefix, sep=sep, loader=self.loader)) - def dump(self, full=True, redact=False): + def dump(self, full: bool = True, redact: bool = False) -> str: """Dump the Configuration object to a YAML file. The order of the keys is determined from the default @@ -662,7 +702,7 @@ def dump(self, full=True, redact=False): return yaml_out - def reload(self): + def reload(self) -> None: """Reload all sources from the file system. This only affects sources that come from files (i.e., @@ -680,17 +720,19 @@ class LazyConfig(Configuration): the module level. """ - def __init__(self, appname, modname=None): + def __init__(self, appname: str, modname: str | None = None) -> None: super().__init__(appname, modname, False) self._materialized = False # Have we read the files yet? - self._lazy_prefix = [] # Pre-materialization calls to set(). - self._lazy_suffix = [] # Calls to add(). + self._lazy_prefix: list[ + ConfigSource + ] = [] # Pre-materialization calls to set(). + self._lazy_suffix: list[ConfigSource] = [] # Calls to add(). - def read(self, user=True, defaults=True): + def read(self, user: bool = True, defaults: bool = True) -> None: self._materialized = True super().read(user, defaults) - def resolve(self): + def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: if not self._materialized: # Read files and unspool buffers. self.read() @@ -698,21 +740,21 @@ def resolve(self): self.sources[:0] = self._lazy_prefix return super().resolve() - def add(self, value): + def add(self, value: dict[str, Any]) -> None: super().add(value) if not self._materialized: # Buffer additions to end. self._lazy_suffix += self.sources del self.sources[:] - def set(self, value): + def set(self, value: dict[str, Any]) -> None: super().set(value) if not self._materialized: # Buffer additions to beginning. self._lazy_prefix[:0] = self.sources del self.sources[:] - def clear(self): + def clear(self) -> None: """Remove all sources from this configuration.""" super().clear() self._lazy_suffix = [] From a0efe66565787b37df934e6a7a66283245a6e665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sat, 3 Jan 2026 17:47:22 +0000 Subject: [PATCH 08/12] Add py.typed --- confuse/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 confuse/py.typed diff --git a/confuse/py.typed b/confuse/py.typed new file mode 100644 index 0000000..e69de29 From e75859771b80a9b7984a0a9fce910f3046d1c69d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 4 Jan 2026 09:51:19 +0000 Subject: [PATCH 09/12] core: fix config key type --- confuse/core.py | 54 +++++++++++++++++++++++--------------------- confuse/templates.py | 16 +++++++------ 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/confuse/core.py b/confuse/core.py index ca48908..c6141c1 100644 --- a/confuse/core.py +++ b/confuse/core.py @@ -48,6 +48,8 @@ from optparse import Values from pathlib import Path + from .templates import ConfigKey + CONFIG_FILENAME = "config.yaml" DEFAULT_FILENAME = "config_default.yaml" ROOT_NAME = "root" @@ -73,7 +75,7 @@ class ConfigView: configuration in Python-like syntax (e.g., ``foo['bar'][42]``). """ - def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: + def resolve(self) -> Iterator[tuple[dict[str, Any] | list[Any], ConfigSource]]: """The core (internal) data retrieval method. Generates (value, source) pairs for each source that contains a value for this view. May raise `ConfigTypeError` if a type error occurs while @@ -81,7 +83,7 @@ def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: """ raise NotImplementedError - def first(self) -> tuple[dict[str, Any], ConfigSource]: + def first(self) -> tuple[dict[str, Any] | list[Any], ConfigSource]: """Return a (value, source) pair for the first object found for this view. This amounts to the first element returned by `resolve`. If no values are available, a `NotFoundError` is @@ -101,14 +103,14 @@ def exists(self) -> bool: return False return True - def add(self, value: dict[str, Any]) -> None: + def add(self, value: Any) -> None: """Set the *default* value for this configuration view. The specified value is added as the lowest-priority configuration data source. """ raise NotImplementedError - def set(self, value: dict[str, Any]) -> None: + def set(self, value: Any) -> None: """*Override* the value for this configuration view. The specified value is added as the highest-priority configuration data source. @@ -142,17 +144,17 @@ def __iter__(self) -> Iterator[Subview | str]: f"{type(item).__name__}" ) - def __getitem__(self, key: str) -> Subview: + def __getitem__(self, key: ConfigKey) -> Subview: """Get a subview of this view.""" return Subview(self, key) - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key: ConfigKey, value: Any) -> None: """Create an overlay source to assign a given key under this view. """ self.set({key: value}) - def __contains__(self, key: str) -> bool: + def __contains__(self, key: ConfigKey) -> bool: return self[key].exists() def set_args( @@ -203,7 +205,7 @@ def keys(self) -> list[str]: for dic, _ in self.resolve(): try: - cur_keys = dic.keys() + cur_keys = dic.keys() # type: ignore[union-attr] except AttributeError: raise ConfigTypeError( f"{self.name} must be a dict, not {type(dic).__name__}" @@ -392,13 +394,13 @@ def redact(self) -> bool: def redact(self, flag: bool) -> None: self.set_redaction((), flag) - def set_redaction(self, path: tuple[str, ...], flag: bool) -> None: + def set_redaction(self, path: tuple[ConfigKey, ...], flag: bool) -> None: """Add or remove a redaction for a key path, which should be an iterable of keys. """ raise NotImplementedError() - def get_redactions(self) -> Iterable[tuple[str, ...]]: + def get_redactions(self) -> Iterable[tuple[ConfigKey, ...]]: """Get the set of currently-redacted sub-key-paths at this view.""" raise NotImplementedError() @@ -415,15 +417,15 @@ def __init__(self, sources: Iterable[ConfigSource]) -> None: """ self.sources: list[ConfigSource] = list(sources) self.name = ROOT_NAME - self.redactions: set[tuple[str, ...]] = set() + self.redactions: set[tuple[ConfigKey, ...]] = set() - def add(self, value: dict[str, Any]) -> None: + def add(self, value: Any) -> None: self.sources.append(ConfigSource.of(value)) - def set(self, value: dict[str, Any]) -> None: + def set(self, value: Any) -> None: self.sources.insert(0, ConfigSource.of(value)) - def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: + def resolve(self) -> Iterator[tuple[dict[str, Any] | list[Any], ConfigSource]]: return ((dict(s), s) for s in self.sources) def clear(self) -> None: @@ -436,20 +438,20 @@ def clear(self) -> None: def root(self) -> Self: return self - def set_redaction(self, path: tuple[str, ...], flag: bool) -> None: + def set_redaction(self, path: tuple[ConfigKey, ...], flag: bool) -> None: if flag: self.redactions.add(path) elif path in self.redactions: self.redactions.remove(path) - def get_redactions(self) -> builtins.set[tuple[str, ...]]: + def get_redactions(self) -> builtins.set[tuple[ConfigKey, ...]]: return self.redactions class Subview(ConfigView): """A subview accessed via a subscript of a parent view.""" - def __init__(self, parent: ConfigView, key: str) -> None: + def __init__(self, parent: ConfigView, key: ConfigKey) -> None: """Make a subview of a parent view for a given subscript key.""" self.parent = parent self.key = key @@ -470,10 +472,10 @@ def __init__(self, parent: ConfigView, key: str) -> None: else: self.name += repr(self.key) - def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: + def resolve(self) -> Iterator[tuple[dict[str, Any] | list[Any], ConfigSource]]: for collection, source in self.parent.resolve(): try: - value = collection[self.key] + value = collection[self.key] # type: ignore[index] except IndexError: # List index out of bounds. continue @@ -488,19 +490,19 @@ def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: ) yield value, source - def set(self, value: dict[str, Any]) -> None: + def set(self, value: Any) -> None: self.parent.set({self.key: value}) - def add(self, value: dict[str, Any]) -> None: + def add(self, value: Any) -> None: self.parent.add({self.key: value}) def root(self) -> RootView: return self.parent.root() - def set_redaction(self, path: tuple[str, ...], flag: bool) -> None: + def set_redaction(self, path: tuple[ConfigKey, ...], flag: bool) -> None: self.parent.set_redaction((self.key, *path), flag) - def get_redactions(self) -> Iterable[tuple[str, ...]]: + def get_redactions(self) -> Iterable[tuple[ConfigKey, ...]]: return ( kp[1:] for kp in self.parent.get_redactions() if kp and kp[0] == self.key ) @@ -732,7 +734,7 @@ def read(self, user: bool = True, defaults: bool = True) -> None: self._materialized = True super().read(user, defaults) - def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: + def resolve(self) -> Iterator[tuple[dict[str, Any] | list[Any], ConfigSource]]: if not self._materialized: # Read files and unspool buffers. self.read() @@ -740,14 +742,14 @@ def resolve(self) -> Iterator[tuple[dict[str, Any], ConfigSource]]: self.sources[:0] = self._lazy_prefix return super().resolve() - def add(self, value: dict[str, Any]) -> None: + def add(self, value: Any) -> None: super().add(value) if not self._materialized: # Buffer additions to end. self._lazy_suffix += self.sources del self.sources[:] - def set(self, value: dict[str, Any]) -> None: + def set(self, value: Any) -> None: super().set(value) if not self._materialized: # Buffer additions to beginning. diff --git a/confuse/templates.py b/confuse/templates.py index 9f71c84..b487449 100644 --- a/confuse/templates.py +++ b/confuse/templates.py @@ -16,11 +16,13 @@ if TYPE_CHECKING: from .core import ConfigView, Subview + T = TypeVar("T") K = TypeVar("K", bound=Hashable, default=str) -Kstr = TypeVar("Kstr", bound=str, default=str) P = TypeVar("P", bound=pathlib.Path | str, default=str) V = TypeVar("V", default=object) +ConfigKey = int | str | bytes +ConfigKeyT = TypeVar("ConfigKeyT", bound=ConfigKey, default=str) class _Required: @@ -35,7 +37,7 @@ class _Required: """ -class AttrDict(dict[Kstr, V]): +class AttrDict(dict[ConfigKeyT, V]): """A `dict` subclass that can be accessed via attributes (dot notation) for convenience. """ @@ -161,17 +163,17 @@ def convert(self, value: Numeric, view: ConfigView) -> Numeric: self.fail(f"must be numeric, not {type(value).__name__}", view, True) -class MappingTemplate(Template[AttrDict[Kstr, V]]): +class MappingTemplate(Template[AttrDict[ConfigKeyT, V]]): """A template that uses a dictionary to specify other types for the values for a set of keys and produce a validated `AttrDict`. """ - def __init__(self, mapping: Mapping[Kstr, Template[V] | type[V]]) -> None: + def __init__(self, mapping: Mapping[ConfigKeyT, Template[V] | type[V]]) -> None: """Create a template according to a dict (mapping). The mapping's values should themselves either be Types or convertible to Types. """ - subtemplates: dict[Kstr, Template[V]] = {} + subtemplates: dict[ConfigKeyT, Template[V]] = {} for key, typ in mapping.items(): subtemplates[key] = as_template(typ) self.subtemplates = subtemplates @@ -179,8 +181,8 @@ def __init__(self, mapping: Mapping[Kstr, Template[V] | type[V]]) -> None: def value( self, view: ConfigView, - template: Template[AttrDict[Kstr, V]] | object | None = None, - ) -> AttrDict[Kstr, V]: + template: Template[AttrDict[ConfigKeyT, V]] | object | None = None, + ) -> AttrDict[ConfigKeyT, V]: """Get a dict with the same keys as the template and values validated according to the value types. """ From 30781af8f98d07b05719461068ee648560248686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 4 Jan 2026 16:42:37 +0000 Subject: [PATCH 10/12] templates: use covariant type var --- confuse/core.py | 28 +++------ confuse/templates.py | 145 ++++++++++++++++++++++++++----------------- 2 files changed, 98 insertions(+), 75 deletions(-) diff --git a/confuse/core.py b/confuse/core.py index c6141c1..017c96a 100644 --- a/confuse/core.py +++ b/confuse/core.py @@ -42,7 +42,6 @@ if TYPE_CHECKING: import builtins - import pathlib from argparse import Namespace from collections.abc import Iterable, Iterator, Mapping, Sequence from optparse import Values @@ -291,28 +290,15 @@ def flatten(self, redact: bool = False) -> OrderedDict[str, Any]: return od @overload - def get(self, template: templates.Pairs) -> list[tuple[str, str]]: ... - @overload - def get(self, template: templates.MappingTemplate) -> templates.AttrDict: ... - @overload - def get(self, template: templates.Optional[R]) -> R | None: ... + def get(self, template: templates.Template[R]) -> R: ... @overload def get(self, template: type[R]) -> R: ... @overload - def get(self, template: Mapping[str, Any]) -> templates.AttrDict: ... - # Overload for list (OneOf) + def get(self, template: Mapping[str, object]) -> templates.AttrDict[str, Any]: ... @overload def get(self, template: list[R]) -> R: ... @overload - def get(self, template: pathlib.PurePath) -> pathlib.Path: ... - @overload - def get(self, template: None) -> None: ... - @overload - def get(self, template: templates.Template[R]) -> R: ... - # Overload for REQUIRED sentinel - @overload def get(self, template: templates._Required = ...) -> Any: ... - def get(self, template: object = templates.REQUIRED) -> Any: """Retrieve the value for this view according to the template. @@ -363,11 +349,17 @@ def as_str_seq(self, split: bool = True) -> list[str]: """ return self.get(templates.StrSeq(split=split)) - def as_pairs(self, default_value: str | None = None) -> list[tuple[str, str]]: + @overload + def as_pairs(self, default_value: str) -> list[tuple[str, str]]: ... + @overload + def as_pairs(self, default_value: None = None) -> list[tuple[str, None]]: ... + def as_pairs( + self, default_value: str | None = None + ) -> list[tuple[str, str]] | list[tuple[str, None]]: """Get the value as a sequence of pairs of two strings. Equivalent to `get(Pairs(default_value=default_value))`. """ - return self.get(templates.Pairs(default_value=default_value)) + return self.get(templates.Pairs(default_value=default_value)) # type: ignore[return-value] def as_str(self) -> str: """Get the value as a (Unicode) string. Equivalent to diff --git a/confuse/templates.py b/confuse/templates.py index b487449..11420ed 100644 --- a/confuse/templates.py +++ b/confuse/templates.py @@ -18,6 +18,7 @@ T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) K = TypeVar("K", bound=Hashable, default=str) P = TypeVar("P", bound=pathlib.Path | str, default=str) V = TypeVar("V", default=object) @@ -52,7 +53,7 @@ def __setattr__(self, key: str, value: V) -> None: self[key] = value # type: ignore[index] -class Template(Generic[T]): +class Template(Generic[T_co]): """A value template for configuration fields. The template works like a type and instructs Confuse about how to @@ -61,24 +62,30 @@ class Template(Generic[T]): filepath type might expand tildes and check that the file exists. """ - def __init__(self, default: object | _Required = REQUIRED) -> None: + default: T_co | _Required + + @overload + def __init__(self: Template[T], default: T) -> None: ... + @overload + def __init__(self: Template[Any], default: _Required = ...) -> None: ... + def __init__(self, default: object = REQUIRED) -> None: """Create a template with a given default value. If `default` is the sentinel `REQUIRED` (as it is by default), then an error will be raised when a value is missing. Otherwise, missing values will instead return `default`. """ - self.default = default + self.default = default # type: ignore[assignment] - def __call__(self, view: ConfigView) -> T: + def __call__(self, view: ConfigView) -> T_co: """Invoking a template on a view gets the view's value according to the template. """ return self.value(view, self) def value( - self, view: ConfigView, template: Template[T] | object | None = None - ) -> T: + self, view: ConfigView, template: Template[T_co] | object | None = None + ) -> T_co: """Get the value for a `ConfigView`. May raise a `NotFoundError` if the value is missing (and the @@ -93,7 +100,7 @@ def value( # Get default value, or raise if required. return self.get_default_value(view.name) - def get_default_value(self, key_name: str = "default") -> T: + def get_default_value(self, key_name: str = "default") -> T_co: """Get the default value to return when the value is missing. May raise a `NotFoundError` if the value is required. @@ -104,7 +111,7 @@ def get_default_value(self, key_name: str = "default") -> T: # The value is not required. return self.default # type: ignore[return-value] - def convert(self, value: Any, view: ConfigView) -> T: + def convert(self, value: Any, view: ConfigView) -> T_co: """Convert the YAML-deserialized value to a value of the desired type. @@ -201,12 +208,21 @@ class Sequence(Template[list[T]]): subtemplate: Template[T] - def __init__(self, subtemplate: Template[T] | object): + @overload + def __init__( + self: Sequence[dict[str, V]], subtemplate: Mapping[str, Template[V] | type[V]] + ) -> None: ... + @overload + def __init__(self, subtemplate: type[T]) -> None: ... + @overload + def __init__(self, subtemplate: Template[T]) -> None: ... + + def __init__(self, subtemplate: Template[T] | type[T] | Mapping[str, object]): """Create a template for a list with items validated on a given subtemplate. """ super().__init__() - self.subtemplate = as_template(subtemplate) + self.subtemplate = as_template(subtemplate) # type: ignore[assignment] def value( self, view: ConfigView, template: Template[list[T]] | object | None = None @@ -232,12 +248,12 @@ class MappingValues(Template[dict[str, T]]): subtemplate: Template[T] - def __init__(self, subtemplate: Template[T] | object): + def __init__(self, subtemplate: Template[T] | type[T] | Mapping[str, object]): """Create a template for a mapping with variable keys and item values validated on a given subtemplate. """ super().__init__() - self.subtemplate = as_template(subtemplate) + self.subtemplate = as_template(subtemplate) # type: ignore[assignment] def value( self, view: ConfigView, template: Template[dict[str, T]] | object | None = None @@ -371,14 +387,20 @@ def __repr__(self) -> str: class OneOf(Template[T]): - """A template that permits values complying to one of the given templates.""" + """A template that permits values complying to one of the given templates. - allowed: list[Template[T]] + When using templates that produce different types, explicitly specify + the type parameter: ``OneOf[bool | str]([bool, String()])`` + """ + + allowed: list[Template[Any]] template: Template[Any] | None def __init__( - self, allowed: Iterable[Template[T] | object], default: T | _Required = REQUIRED - ): + self, + allowed: Iterable[Template[Any] | type[Any] | Mapping[str, object] | T], + default: T | _Required = REQUIRED, + ) -> None: super().__init__(default) self.allowed = [as_template(t) for t in allowed] self.template = None @@ -405,6 +427,7 @@ def convert(self, value: object, view: Subview) -> T: # type: ignore[override] is_mapping = isinstance(self.template, MappingTemplate) for candidate in self.allowed: + result: T try: if is_mapping: assert self.template is not None @@ -413,11 +436,13 @@ def convert(self, value: object, view: Subview) -> T: # type: ignore[override] if isinstance(view, Subview): # Use a new MappingTemplate to check the sibling value next_template = MappingTemplate({view.key: candidate}) - return view.parent.get(next_template)[view.key] + result = view.parent.get(next_template)[view.key] + return result else: self.fail("MappingTemplate must be used with a Subview", view) else: - return view.get(candidate) + result = view.get(candidate) + return result except exceptions.ConfigTemplateError: raise except exceptions.ConfigError: @@ -479,7 +504,7 @@ def convert( return [self._convert_value(v, view) for v in values] -class Pairs(BytesToStrMixin, Template[list[tuple[str, object]]]): +class Pairs(BytesToStrMixin, Template[list[tuple[str, V]]]): """A template for ordered key-value pairs. This can either be given with the same syntax as for `StrSeq` (i.e. without @@ -493,16 +518,34 @@ class Pairs(BytesToStrMixin, Template[list[tuple[str, object]]]): `default_value` will be returned as the second element. """ - def __init__(self, default_value: object | None = None) -> None: + default_value: V + + @overload + def __init__( + self: Pairs[str], + default_value: str, + default: list[tuple[str, str]] | _Required = REQUIRED, + ) -> None: ... + @overload + def __init__( + self: Pairs[None], + default_value: None = None, + default: list[tuple[str, None]] | _Required = REQUIRED, + ) -> None: ... + def __init__( + self, + default_value: str | None = None, + default: list[tuple[str, str]] | list[tuple[str, None]] | _Required = REQUIRED, + ) -> None: """Create a new template. - `default` is the dictionary value returned for items that are not + `default_value` is the dictionary value returned for items that are not a mapping, but a single string. """ - super().__init__() - self.default_value = default_value + super().__init__(default) # type: ignore[arg-type] + self.default_value = default_value # type: ignore[assignment] - def _convert_value(self, x: object, view: ConfigView) -> tuple[str, object]: + def _convert_value(self, x: object, view: ConfigView) -> tuple[str, V]: if isinstance(x, (str, bytes)): return self.normalize_bytes(x), self.default_value @@ -519,11 +562,11 @@ def _convert_value(self, x: object, view: ConfigView) -> tuple[str, object]: # YAML to parse this to some custom type. self.fail(f"must be a single string, mapping, or a list{x}", view, True) - return self.normalize_bytes(k), self.normalize_bytes(v) + return self.normalize_bytes(k), self.normalize_bytes(v) # type: ignore[return-value] def convert( self, value: list[abc.Sequence[str] | Mapping[str, str]], view: ConfigView - ) -> list[tuple[str, object]]: + ) -> list[tuple[str, V]]: return [self._convert_value(v, view) for v in value] @@ -544,7 +587,7 @@ class Filename(Template[P]): def __init__( self, - default: T | _Required = REQUIRED, + default: P | str | None | _Required = REQUIRED, cwd: str | None = None, relative_to: str | None = None, in_app_dir: bool = False, @@ -561,7 +604,10 @@ def __init__( relative to the directory containing the source file, if there is one, taking precedence over the application's config directory. """ - super().__init__(default) + if default is None: + self.default: P | _Required = default # type: ignore[assignment] + else: + super().__init__(default) # type: ignore[arg-type] self.cwd = cwd self.relative_to = relative_to self.in_app_dir = in_app_dir @@ -711,6 +757,7 @@ class Optional(Template[T | None]): """ subtemplate: Template[T] + default: T | None @overload def __init__( @@ -731,18 +778,18 @@ def __init__( @overload def __init__( self, - subtemplate: T, - default: T | None = None, + subtemplate: Mapping[str, object], + default: Mapping[str, Any] | None = None, allow_missing: bool = True, ) -> None: ... def __init__( self, - subtemplate: Template[T] | type[T] | T, - default: T | None = None, + subtemplate: Template[T] | type[T] | Mapping[str, object], + default: T | Mapping[str, Any] | None = None, allow_missing: bool = True, - ): - self.subtemplate = as_template(subtemplate) + ) -> None: + self.subtemplate: Template[T] = as_template(subtemplate) # type: ignore[assignment] if default is None: # When no default is passed, try to use the subtemplate's # default value as the default for this template @@ -750,7 +797,7 @@ def __init__( default = self.subtemplate.get_default_value() except exceptions.NotFoundError: pass - self.default = default + self.default = default # type: ignore[assignment] self.allow_missing = allow_missing def value( @@ -761,7 +808,7 @@ def value( except exceptions.NotFoundError: if self.allow_missing: # Value is missing but not required - return self.default # type: ignore[return-value] + return self.default # Value must be present even though it can be null. Raise an error. raise exceptions.NotFoundError(f"{view.name} not found") @@ -782,7 +829,7 @@ class TypeTemplate(Template[T]): desired Python type. """ - def __init__(self, typ: type[T], default: object = REQUIRED): + def __init__(self, typ: type[T], default: T | _Required = REQUIRED) -> None: """Create a template that checks that the value is an instance of `typ`. """ @@ -803,35 +850,19 @@ def convert(self, value: Any, view: ConfigView) -> T: @overload def as_template(value: Template[T]) -> Template[T]: ... @overload -def as_template(value: Mapping[str, object]) -> MappingTemplate: ... -@overload -def as_template(value: type[int]) -> Integer: ... -@overload -def as_template(value: int) -> Integer: ... -@overload -def as_template(value: type[str]) -> String: ... -@overload -def as_template(value: str) -> String: ... +def as_template(value: type[T]) -> Template[T]: ... @overload -def as_template(value: type[enum.Enum]) -> Choice[enum.Enum]: ... +def as_template(value: Mapping[str, object]) -> MappingTemplate: ... @overload -def as_template(value: set[T]) -> Choice[T]: ... +def as_template(value: set[T]) -> Choice[T, T]: ... @overload def as_template(value: list[T]) -> OneOf[T]: ... @overload -def as_template(value: type[float]) -> Number[float]: ... -@overload -def as_template(value: float) -> Number[float]: ... -@overload def as_template(value: pathlib.PurePath) -> Path: ... @overload def as_template(value: None) -> Template[None]: ... @overload -def as_template(value: type[dict[str, T]]) -> TypeTemplate[abc.Mapping[str, T]]: ... -@overload -def as_template(value: type[list[T]]) -> TypeTemplate[abc.Sequence[T]]: ... -@overload -def as_template(value: object) -> Template[Any]: ... +def as_template(value: T) -> Template[T]: ... def as_template(value: Any) -> Template[Any]: From ca1ee26b2f6c33ef9bbff090ccd6d010e0a0c82c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 4 Jan 2026 15:49:05 +0000 Subject: [PATCH 11/12] Fix types in test directory --- setup.cfg | 3 +++ test/__init__.py | 8 +++++- test/test_dump.py | 3 ++- test/test_env.py | 27 +++++++++++---------- test/test_paths.py | 21 ++++++++-------- test/test_utils.py | 13 +++++++--- test/test_valid.py | 54 ++++++++++++++++------------------------- test/test_validation.py | 8 +++--- test/test_views.py | 13 +++++++++- test/test_yaml.py | 3 ++- 10 files changed, 84 insertions(+), 69 deletions(-) diff --git a/setup.cfg b/setup.cfg index a49d0c9..8b9b467 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,3 +35,6 @@ allow_untyped_calls = true # FIXME: Would be better to actually type the libraries (if under our control), # or write our own stubs. For now, silence errors strict = true + +[mypy-test.*] +allow_untyped_defs = true diff --git a/test/__init__.py b/test/__init__.py index 8f69bb0..ef9495a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,11 +1,17 @@ +from __future__ import annotations + import os import shutil import tempfile +from typing import TYPE_CHECKING, Any import confuse +if TYPE_CHECKING: + from collections.abc import Mapping + -def _root(*sources): +def _root(*sources: Mapping[str, Any]) -> confuse.RootView: return confuse.RootView([confuse.ConfigSource.of(s) for s in sources]) diff --git a/test/test_dump.py b/test/test_dump.py index 875b472..4f6e265 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -1,5 +1,6 @@ import textwrap import unittest +from collections import OrderedDict import confuse @@ -32,7 +33,7 @@ def test_dump_short_list(self): assert yaml == "foo: [bar, baz]" def test_dump_ordered_dict(self): - odict = confuse.OrderedDict() + odict = OrderedDict() odict["foo"] = "bar" odict["bar"] = "baz" odict["baz"] = "qux" diff --git a/test/test_env.py b/test/test_env.py index b3a9b7c..fe476a8 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1,5 +1,6 @@ import os import unittest +from unittest.mock import patch import pytest @@ -7,15 +8,14 @@ from . import _root -ENVIRON = os.environ - class EnvSourceTest(unittest.TestCase): def setUp(self): - os.environ = {} + self.env_patcher = patch.dict("os.environ", {}) + self.env_patcher.start() def tearDown(self): - os.environ = ENVIRON + self.env_patcher.stop() def test_prefix(self): os.environ["TEST_FOO"] = "a" @@ -235,16 +235,17 @@ def test_parse_yaml_docs_false(self): class ConfigEnvTest(unittest.TestCase): def setUp(self): + self.env_patcher = patch.dict( + "os.environ", + { + "TESTAPP_FOO": "a", + "TESTAPP_BAR__NESTED": "b", + "TESTAPP_BAZ_SEP_NESTED": "c", + "MYAPP_QUX_SEP_NESTED": "d", + }, + ) + self.env_patcher.start() self.config = confuse.Configuration("TestApp", read=False) - os.environ = { - "TESTAPP_FOO": "a", - "TESTAPP_BAR__NESTED": "b", - "TESTAPP_BAZ_SEP_NESTED": "c", - "MYAPP_QUX_SEP_NESTED": "d", - } - - def tearDown(self): - os.environ = ENVIRON def test_defaults(self): self.config.set_env() diff --git a/test/test_paths.py b/test/test_paths.py index d31bf44..085a9f1 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -10,19 +10,18 @@ import confuse import confuse.yaml_util -DEFAULT = [platform.system, os.environ, os.path] - +DEFAULT = (platform.system, os.environ, os.path) SYSTEMS = { - "Linux": [{"HOME": "/home/test", "XDG_CONFIG_HOME": "~/xdgconfig"}, posixpath], - "Darwin": [{"HOME": "/Users/test"}, posixpath], - "Windows": [ + "Linux": ({"HOME": "/home/test", "XDG_CONFIG_HOME": "~/xdgconfig"}, posixpath), + "Darwin": ({"HOME": "/Users/test"}, posixpath), + "Windows": ( { "APPDATA": "~\\winconfig", "HOME": "C:\\Users\\test", "USERPROFILE": "C:\\Users\\test", }, ntpath, - ], + ), } @@ -48,10 +47,10 @@ class FakeSystem(unittest.TestCase): def setUp(self): super().setUp() self.os_path = os.path - os.environ = {} + os.environ = {} # type: ignore[assignment] environ, os.path = SYSTEMS[self.SYS_NAME] - os.environ.update(environ) # copy + os.environ.update(environ) platform.system = lambda: self.SYS_NAME def tearDown(self): @@ -128,11 +127,11 @@ def test_fallback_dir(self): class ConfigFilenamesTest(unittest.TestCase): def setUp(self): self._old = os.path.isfile, confuse.yaml_util.load_yaml - os.path.isfile = lambda x: True + os.path.isfile = lambda x: True # type: ignore[assignment] confuse.yaml_util.load_yaml = lambda *args, **kwargs: {} def tearDown(self): - confuse.yaml_util.load_yaml, os.path.isfile = self._old + os.path.isfile, confuse.yaml_util.load_yaml = self._old def test_no_sources_when_files_missing(self): config = confuse.Configuration("myapp", read=False) @@ -173,7 +172,7 @@ def test_env_var_missing(self): assert self.config.config_dir() != self.home -@unittest.skipUnless(os.system == "Linux", "Linux-specific tests") +@unittest.skipUnless(platform.system() == "Linux", "Linux-specific tests") class PrimaryConfigDirTest(FakeHome, FakeSystem): SYS_NAME = "Linux" # conversion from posix to nt is easy diff --git a/test/test_utils.py b/test/test_utils.py index 94a3387..176fcdf 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,7 @@ import unittest from argparse import Namespace from collections import OrderedDict +from typing import Any import pytest @@ -20,7 +21,7 @@ def test_namespaces(self): assert 1 == result["another"] def test_dot_sep_keys(self): - config = {"foo.bar": 1} + config: dict[str, Any] = {"foo.bar": 1} result = confuse.util.build_dict(config.copy()) assert 1 == result["foo.bar"] @@ -28,7 +29,7 @@ def test_dot_sep_keys(self): assert 1 == result["foo"]["bar"] def test_dot_sep_keys_clobber(self): - args = [("foo.bar", 1), ("foo.bar.zar", 2)] + args: list[tuple[str, Any]] = [("foo.bar", 1), ("foo.bar.zar", 2)] config = OrderedDict(args) result = confuse.util.build_dict(config.copy(), sep=".") assert {"zar": 2} == result["foo"]["bar"] @@ -42,7 +43,11 @@ def test_dot_sep_keys_clobber(self): assert 2 == result["foo"]["bar"]["zar"] def test_dot_sep_keys_no_clobber(self): - args = [("foo.bar", 1), ("foo.far", 2), ("foo.zar.dar", 4)] + args: list[tuple[str, Any]] = [ + ("foo.bar", 1), + ("foo.far", 2), + ("foo.zar.dar", 4), + ] config = OrderedDict(args) result = confuse.util.build_dict(config.copy(), sep=".") assert 1 == result["foo"]["bar"] @@ -50,7 +55,7 @@ def test_dot_sep_keys_no_clobber(self): assert 4 == result["foo"]["zar"]["dar"] def test_adjacent_underscores_sep_keys(self): - config = {"foo__bar_baz": 1} + config: dict[str, Any] = {"foo__bar_baz": 1} result = confuse.util.build_dict(config.copy()) assert 1 == result["foo__bar_baz"] diff --git a/test/test_valid.py b/test/test_valid.py index 14c33f8..0b340a6 100644 --- a/test/test_valid.py +++ b/test/test_valid.py @@ -2,6 +2,7 @@ import os import unittest from collections.abc import Mapping, Sequence +from typing import Any import pytest @@ -69,29 +70,13 @@ def test_validate_individual_value(self): assert valid == 5 def test_nested_dict_template(self): - config = _root( - { - "foo": {"bar": 9}, - } - ) - valid = config.get( - { - "foo": {"bar": confuse.Integer()}, - } - ) + config = _root({"foo": {"bar": 9}}) + valid = config.get({"foo": {"bar": confuse.Integer()}}) assert valid["foo"]["bar"] == 9 def test_nested_attribute_access(self): - config = _root( - { - "foo": {"bar": 8}, - } - ) - valid = config.get( - { - "foo": {"bar": confuse.Integer()}, - } - ) + config = _root({"foo": {"bar": 8}}) + valid = config.get({"foo": {"bar": confuse.Integer()}}) assert valid.foo.bar == 8 @@ -132,12 +117,12 @@ def test_nested_dict_as_template(self): assert typ.subtemplates["outer"].subtemplates["inner"].default == 2 def test_list_as_template(self): - typ = confuse.as_template(list()) + typ: confuse.OneOf[Any] = confuse.as_template(list()) assert isinstance(typ, confuse.OneOf) assert typ.default == confuse.REQUIRED def test_set_as_template(self): - typ = confuse.as_template(set()) + typ: confuse.Choice[Any] = confuse.as_template(set()) assert isinstance(typ, confuse.Choice) def test_enum_type_as_template(self): @@ -275,7 +260,7 @@ def test_default_value(self): def test_validate_good_choice_in_list(self): config = _root({"foo": 2}) - valid = config["foo"].get( + valid: str | int = config["foo"].get( confuse.OneOf( [ confuse.String(), @@ -287,7 +272,7 @@ def test_validate_good_choice_in_list(self): def test_validate_first_good_choice_in_list(self): config = _root({"foo": 3.14}) - valid = config["foo"].get( + valid: str | int = config["foo"].get( confuse.OneOf( [ confuse.Integer(), @@ -416,7 +401,7 @@ def test_filename_with_non_file_source(self): def test_filename_with_file_source(self): source = confuse.ConfigSource({"foo": "foo/bar"}, filename="/baz/config.yaml") config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename()) assert valid == os.path.realpath("/config/path/foo/bar") @@ -425,7 +410,7 @@ def test_filename_with_default_source(self): {"foo": "foo/bar"}, filename="/baz/config.yaml", default=True ) config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename()) assert valid == os.path.realpath("/config/path/foo/bar") @@ -434,28 +419,28 @@ def test_filename_use_config_source_dir(self): {"foo": "foo/bar"}, filename="/baz/config.yaml", base_for_paths=True ) config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename()) assert valid == os.path.realpath("/baz/foo/bar") def test_filename_in_source_dir(self): source = confuse.ConfigSource({"foo": "foo/bar"}, filename="/baz/config.yaml") config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename(in_source_dir=True)) assert valid == os.path.realpath("/baz/foo/bar") def test_filename_in_source_dir_overrides_in_app_dir(self): source = confuse.ConfigSource({"foo": "foo/bar"}, filename="/baz/config.yaml") config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename(in_source_dir=True, in_app_dir=True)) assert valid == os.path.realpath("/baz/foo/bar") def test_filename_in_app_dir_non_file_source(self): source = confuse.ConfigSource({"foo": "foo/bar"}) config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename(in_app_dir=True)) assert valid == os.path.realpath("/config/path/foo/bar") @@ -464,7 +449,7 @@ def test_filename_in_app_dir_overrides_config_source_dir(self): {"foo": "foo/bar"}, filename="/baz/config.yaml", base_for_paths=True ) config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] valid = config["foo"].get(confuse.Filename(in_app_dir=True)) assert valid == os.path.realpath("/config/path/foo/bar") @@ -503,7 +488,7 @@ def test_missing_required_value(self): class BaseTemplateTest(unittest.TestCase): def test_base_template_accepts_any_value(self): config = _root({"foo": 4.2}) - valid = config["foo"].get(confuse.Template()) + valid: float = config["foo"].get(confuse.Template()) assert valid == 4.2 def test_base_template_required(self): @@ -576,7 +561,9 @@ def test_dict_dict(self): config = _root( {"foo": {"first": {"bar": 1, "baz": 2}, "second": {"bar": 3, "baz": 4}}} ) - valid = config["foo"].get(confuse.MappingValues({"bar": int, "baz": int})) + valid: dict[str, dict[str, int]] = config["foo"].get( + confuse.MappingValues({"bar": int, "baz": int}) + ) assert valid == {"first": {"bar": 1, "baz": 2}, "second": {"bar": 3, "baz": 4}} def test_invalid_item(self): @@ -656,6 +643,7 @@ def test_optional_mapping_template_valid(self): config = _root({"foo": {"bar": 5, "baz": "bak"}}) template = {"bar": confuse.Integer(), "baz": confuse.String()} valid = config.get({"foo": confuse.Optional(template)}) + assert valid["foo"] assert valid["foo"]["bar"] == 5 assert valid["foo"]["baz"] == "bak" diff --git a/test/test_validation.py b/test/test_validation.py index 793af12..4103dbd 100644 --- a/test/test_validation.py +++ b/test/test_validation.py @@ -40,7 +40,7 @@ def test_as_filename_with_non_file_source(self): def test_as_filename_with_file_source(self): source = confuse.ConfigSource({"foo": "foo/bar"}, filename="/baz/config.yaml") config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] value = config["foo"].as_filename() assert value == os.path.realpath("/config/path/foo/bar") @@ -49,7 +49,7 @@ def test_as_filename_with_default_source(self): {"foo": "foo/bar"}, filename="/baz/config.yaml", default=True ) config = _root(source) - config.config_dir = lambda: "/config/path" + config.config_dir = lambda: "/config/path" # type: ignore[attr-defined] value = config["foo"].as_filename() assert value == os.path.realpath("/config/path/foo/bar") @@ -60,7 +60,7 @@ def test_as_filename_wrong_type(self): def test_as_path(self): config = _root({"foo": "foo/bar"}) - path = os.path.join(os.getcwd(), "foo", "bar") + path_str = os.path.join(os.getcwd(), "foo", "bar") try: import pathlib except ImportError: @@ -68,7 +68,7 @@ def test_as_path(self): value = config["foo"].as_path() else: value = config["foo"].as_path() - path = pathlib.Path(path) + path = pathlib.Path(path_str) assert value == path def test_as_choice_correct(self): diff --git a/test/test_views.py b/test/test_views.py index ddd825e..2ead690 100644 --- a/test/test_views.py +++ b/test/test_views.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import unittest +from typing import TYPE_CHECKING import pytest @@ -6,6 +9,11 @@ from . import _root +if TYPE_CHECKING: + from collections.abc import Iterator + + from confuse import Subview + class SingleSourceTest(unittest.TestCase): def test_dict_access(self): @@ -35,7 +43,10 @@ def test_dict_iter(self): def test_list_iter(self): config = _root({"l": ["foo", "bar"]}) - items = [subview.get() for subview in config["l"]] + # TODO(@snejus): we need to split Subview to SequenceView and MappingView in + # order to have automatic resolution here + _items: Iterator[Subview] = iter(config["l"]) + items = [subview.get() for subview in _items] assert items == ["foo", "bar"] def test_int_iter(self): diff --git a/test/test_yaml.py b/test/test_yaml.py index ea30409..7b636ac 100644 --- a/test/test_yaml.py +++ b/test/test_yaml.py @@ -1,4 +1,5 @@ import unittest +from collections import OrderedDict import pytest import yaml @@ -15,7 +16,7 @@ def load(s): class ParseTest(unittest.TestCase): def test_dict_parsed_as_ordereddict(self): v = load("a: b\nc: d") - assert isinstance(v, confuse.OrderedDict) + assert isinstance(v, OrderedDict) assert list(v) == ["a", "c"] def test_string_beginning_with_percent(self): From 2ca46c3e3df4f8178afd3d972a48976e19364d95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0ar=C5=ABnas=20Nejus?= Date: Sun, 4 Jan 2026 16:53:55 +0000 Subject: [PATCH 12/12] Fix types in examples --- example/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/__init__.py b/example/__init__.py index 87e1465..f5a548f 100644 --- a/example/__init__.py +++ b/example/__init__.py @@ -24,7 +24,7 @@ config = confuse.LazyConfig("ConfuseExample", __name__) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="example Confuse program") parser.add_argument( "--library",