diff --git a/README.md b/README.md index ed9b83c..29fdc4f 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,12 @@ This project builds on top of `argparse` by adding type-safety and allowing a mo ## Example usage ```py -from arcparse import arcparser, flag +from arcparse import arcparser, flag, positional from pathlib import Path @arcparser class Args: - path: Path + path: Path = positional() recurse: bool = flag("-r") item_limit: int = 100 output_path: Path | None @@ -19,6 +19,21 @@ args = Args.parse() print(f"Scanning {args.path}...") ... ``` +
+ Help output of this parser + + usage: program.py [-h] [-r] [--item-limit ITEM_LIMIT] [--output-path OUTPUT_PATH] path + + positional arguments: + path + + options: + -h, --help show this help message and exit + -r, --recurse + --item-limit ITEM_LIMIT + --output-path OUTPUT_PATH + +
For more examples see [Examples](examples/). @@ -29,13 +44,14 @@ $ pip install arcparse ``` ## Features -- Positional, Option and Flag arguments -- Multiple values per argument +- Positional, Option and [Flag](./examples/flag.py) arguments - Name overriding -- Type conversions -- Mutually exclusive groups -- Subparsers -- Parser inheritance (with overriding) +- [Multiple values per argument](./examples/multiple.py) +- [Type conversions](./examples/conversion.py) +- [Mutually exclusive groups](./examples/mutual_exclusion.py) +- [Subparsers](./examples/subparsers.py) +- [Parser inheritance](./examples/inheritance.py) (with [overriding](./examples/override.py)) +- [Presence validation](./examples/presence_validation.py) ## Credits This project was inspired by [swansonk14/typed-argument-parser](https://github.com/swansonk14/typed-argument-parser). diff --git a/arcparse/_partial_arguments.py b/arcparse/_partial_arguments.py index 0fad2f3..e742f65 100644 --- a/arcparse/_partial_arguments.py +++ b/arcparse/_partial_arguments.py @@ -116,10 +116,6 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]: class PartialPositional[T](BasePartialValueArgument[T, Positional]): def resolve_with_typehint(self, name: str, typehint: type) -> Positional: kwargs = self.resolve_to_kwargs(name, typehint) - return Positional(**kwargs) - - def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().resolve_to_kwargs(name, typehint) type_is_optional = bool(extract_optional_type(typehint)) type_is_collection = bool(extract_collection_type(typehint)) @@ -135,7 +131,7 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]: kwargs["nargs"] = "+" if self.at_least_one else "*" kwargs["metavar"] = self.name_override - return kwargs + return Positional(name=name, **kwargs) @dataclass @@ -151,10 +147,6 @@ def resolve_with_typehint(self, name: str, typehint: type) -> Option: ) kwargs = self.resolve_to_kwargs(name, typehint) - return Option(**kwargs) - - def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]: - kwargs = super().resolve_to_kwargs(name, typehint) kwargs["short"] = self.short kwargs["short_only"] = self.short_only @@ -170,6 +162,7 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]: if self.choices is None: # choices generate custom `{foo,bar}` metavar in argparse kwargs["metavar"] = self.name_override.replace("-", "_").upper() elif self.short_only and self.short is not None: + kwargs["name"] = name kwargs["dest"] = name else: kwargs["name"] = name @@ -181,8 +174,7 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]: kwargs["optional"] = True elif self.mx_group is not None: raise InvalidArgument("Arguments in mutually exclusive group have to have a default") - - return kwargs + return Option(**kwargs) @dataclass @@ -201,7 +193,7 @@ def resolve_with_typehint(self, name: str, typehint: type) -> Flag: kwargs["short"] = self.short kwargs["short_only"] = self.short_only kwargs["no_flag"] = self.no_flag - return Flag(**kwargs) + return Flag(name=name, **kwargs) class PartialTriFlag(BasePartialArgument[TriFlag]): diff --git a/arcparse/_validations.py b/arcparse/_validations.py new file mode 100644 index 0000000..eac8e9c --- /dev/null +++ b/arcparse/_validations.py @@ -0,0 +1,89 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Collection, Iterable, Sequence +from dataclasses import dataclass + +from arcparse.arguments import BaseArgument, BaseValueArgument, Flag + + +class ArgumentAccessor: + def __init__(self, arguments: dict[str, BaseArgument]): + self._arguments = arguments + + def __getattribute__(self, key: str) -> BaseArgument: + arguments = super().__getattribute__("_arguments") + if key not in arguments: + raise Exception(f'Argument "{key}" doesn\'t exist') + return arguments[key] + + +class Constraint(ABC): + @abstractmethod + def validate(self, arguments: dict[str, str]) -> bool: ... + + @staticmethod + def is_provided(argument: BaseArgument, arguments: dict[str, str]) -> bool: + if (provided_value := arguments.get(argument.name)) is None: + return False + + if isinstance(argument, Flag): + defined_default = argument.no_flag + elif isinstance(argument, BaseValueArgument): + defined_default = argument.default + else: + raise Exception(f"is_provided is not defined for {argument.__class__.__name__}") + return provided_value != defined_default + + +@dataclass +class ImplyConstraint(Constraint): + arg: BaseArgument + required: Collection[BaseArgument] + disallowed: Collection[BaseArgument] + + def validate(self, arguments: dict[str, str]) -> bool: + if not self.is_provided(self.arg, arguments): + return False + + for arg in self.required: + if not self.is_provided(arg, arguments): + raise Exception(f'Argument "{arg.display_name}" is required when "{self.arg.display_name}" is passed') + + for arg in self.disallowed: + if self.is_provided(arg, arguments): + raise Exception(f'Argument "{arg.display_name}" is incompatible with "{self.arg.display_name}"') + return True + + +@dataclass +class RequireConstraint(Constraint): + args: Collection[BaseArgument] + + def validate(self, arguments: dict[str, str]) -> bool: + def and_join(names: Sequence[str]) -> str: + if len(names) == 0: + return "" + elif len(names) == 1: + return names[0] + return f"{', '.join(names[:-1])} and {names[-1]}" + + not_provided = [arg for arg in self.args if not self.is_provided(arg, arguments)] + if not_provided: + provided_text = "none" if len(not_provided) == len(self.args) else "only some" + raise Exception( + f"Arguments {and_join([arg.display_name for arg in self.args])} are required together, but {provided_text} were provided" + ) + return True + + +def validate_with( + defined_arguments: dict[str, BaseArgument], + validations_callable: Callable[[ArgumentAccessor], Iterable[Constraint]], + provided_arguments: dict[str, str], +) -> None: + for constraint in validations_callable(ArgumentAccessor(defined_arguments)): + if not isinstance(constraint, Constraint): + raise TypeError("Items returned from __validations__() have to be of type Constrant") + + matched = constraint.validate(provided_arguments) + if matched: + break diff --git a/arcparse/arguments.py b/arcparse/arguments.py index e7c20aa..f9e4c85 100644 --- a/arcparse/arguments.py +++ b/arcparse/arguments.py @@ -36,17 +36,22 @@ def apply(self, actions_container: _ActionsContainer, name: str) -> None: ... @dataclass(kw_only=True) class BaseArgument(ABC, ContainerApplicable): + name: str help: str | None = None + @property + @abstractmethod + def display_name(self) -> str: ... + def apply(self, actions_container: _ActionsContainer, name: str) -> None: - args = self.get_argparse_args(name) - kwargs = self.get_argparse_kwargs(name) + args = self.get_argparse_args() + kwargs = self.get_argparse_kwargs() actions_container.add_argument(*args, **kwargs) @abstractmethod - def get_argparse_args(self, name: str) -> list[str]: ... + def get_argparse_args(self) -> list[str]: ... - def get_argparse_kwargs(self, name: str) -> dict[str, Any]: + def get_argparse_kwargs(self) -> dict[str, Any]: kwargs = {} if self.help is not None: kwargs["help"] = self.help @@ -59,31 +64,38 @@ class Flag(BaseArgument): short_only: bool = False no_flag: bool = False - def get_argparse_args(self, name: str) -> list[str]: + @property + def display_name(self) -> str: if self.no_flag: - args = [f"--no-{name.replace("_", "-")}"] - else: - args = [f"--{name.replace("_", "-")}"] + return f"--no-{self.name.replace("_", "-")}" + return f"--{self.name.replace("_", "-")}" + def get_argparse_args(self) -> list[str]: if self.short_only: if self.short is not None: return [self.short] - else: - return [f"-{name}"] - elif self.short is not None: - args.insert(0, self.short) + return [f"-{self.name}"] + args = [self.display_name] + if self.short is not None: + args.insert(0, self.short) return args - def get_argparse_kwargs(self, name: str) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name) + def get_argparse_kwargs(self) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs() kwargs["action"] = "store_true" if not self.no_flag else "store_false" - kwargs["dest"] = name + kwargs["dest"] = self.name return kwargs class TriFlag(ContainerApplicable): + name: str + + @property + def display_name(self) -> str: + return f"--(no-){self.name.replace("_", "-")}" + def apply(self, actions_container: _ActionsContainer, name: str) -> None: # if actions_container is not an mx group, make it one, argparse # doesn't support mx group nesting @@ -104,8 +116,8 @@ class BaseValueArgument[T](BaseArgument): optional: bool = False metavar: str | None = None - def get_argparse_kwargs(self, name: str) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name) + def get_argparse_kwargs(self) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs() if self.default is not void: kwargs["default"] = self.default @@ -121,11 +133,15 @@ def get_argparse_kwargs(self, name: str) -> dict[str, Any]: @dataclass class Positional[T](BaseValueArgument[T]): - def get_argparse_args(self, name: str) -> list[str]: - return [name] + @property + def display_name(self) -> str: + return self.name + + def get_argparse_args(self) -> list[str]: + return [self.name] - def get_argparse_kwargs(self, name: str) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name) + def get_argparse_kwargs(self) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs() if self.nargs is None and (self.optional or self.default is not void): kwargs["nargs"] = "?" @@ -135,26 +151,28 @@ def get_argparse_kwargs(self, name: str) -> dict[str, Any]: @dataclass class Option[T](BaseValueArgument[T]): - name: str | None = None dest: str | None = None short: str | None = None short_only: bool = False append: bool = False - def get_argparse_args(self, name: str) -> list[str]: - args = [f"--{(self.name or name).replace("_", "-")}"] + @property + def display_name(self) -> str: + return f"--{(self.name).replace("_", "-")}" + + def get_argparse_args(self) -> list[str]: if self.short_only: if self.short is not None: return [self.short] - else: - return [f"-{self.name or name}"] - elif self.short is not None: - args.insert(0, self.short) + return [f"-{self.name}"] + args = [self.display_name] + if self.short is not None: + args.insert(0, self.short) return args - def get_argparse_kwargs(self, name: str) -> dict[str, Any]: - kwargs = super().get_argparse_kwargs(name) + def get_argparse_kwargs(self) -> dict[str, Any]: + kwargs = super().get_argparse_kwargs() if self.dest is not None: kwargs["dest"] = self.dest diff --git a/arcparse/parser.py b/arcparse/parser.py index 9378774..eea5dfd 100644 --- a/arcparse/parser.py +++ b/arcparse/parser.py @@ -19,6 +19,7 @@ extract_subparsers_from_typehint, union_contains_none, ) +from ._validations import validate_with from .arguments import ( BaseArgument, BaseValueArgument, @@ -106,6 +107,9 @@ def _construct_object_with_parsed(self, parsed: dict[str, Any]) -> T: sub_parser = subparsers.sub_parsers[chosen_subparser] parsed[name] = sub_parser._construct_object_with_parsed(parsed) + if (validations := getattr(self.shape, "__presence_validations__", None)) and callable(validations): + validate_with(self.arguments, cast(Any, validations).__func__, parsed) + # apply argument converters for name, argument in self.all_arguments: if not isinstance(argument, BaseValueArgument) or argument.converter is None: @@ -139,9 +143,10 @@ def __getattribute__(self, name: str): if name in __dict__: return __dict__[name] - value = super().__getattribute__(name) - if not isinstance(value, BasePartialArgument): - return value + if hasattr(super(), name): + value = getattr(super(), name) + if not isinstance(value, BasePartialArgument): + return value raise AttributeError(f"'{cls.__name__}' parser didn't define argument '{name}'", name=name, obj=self) diff --git a/arcparse/validations.py b/arcparse/validations.py new file mode 100644 index 0000000..f2ff884 --- /dev/null +++ b/arcparse/validations.py @@ -0,0 +1,33 @@ +from collections.abc import Collection +from typing import Any + +from ._validations import Constraint, ImplyConstraint, RequireConstraint + + +__all__ = [ + "Constraint", + "ImplyConstraint", + "RequireConstraint", + "imply", + "require", +] + + +def imply(arg: Any, required: Collection[Any] = (), disallowed: Collection[Any] = ()) -> ImplyConstraint: + """ + Require and disallow arguments when arg is passed. + + If `arg` is present and the constraint is fulfilled, no subsequent constraints are checked. + + :param Any arg: argument to check existence for + :param Collection[Any] required: required arguments when arg is passed + :param Collection[Any] disallowed: disallowed arguments when arg is passed + """ + return ImplyConstraint(arg, required=required, disallowed=disallowed) + + +def require(*args: Any) -> RequireConstraint: + """ + Require arguments to be present. + """ + return RequireConstraint(args) diff --git a/examples/override.py b/examples/override.py index 4a8790a..da9ea96 100644 --- a/examples/override.py +++ b/examples/override.py @@ -20,7 +20,7 @@ def __post_init__(parser: Parser) -> None: del parser.arguments["bar"] # create new argument - parser.arguments["baz"] = Option("baz", default="baz") + parser.arguments["baz"] = Option("baz", name="baz", default="baz") if __name__ == "__main__": diff --git a/examples/presence_validation.py b/examples/presence_validation.py new file mode 100644 index 0000000..9b7b356 --- /dev/null +++ b/examples/presence_validation.py @@ -0,0 +1,49 @@ +from typing import Iterator + +from arcparse import arcparser, positional +from arcparse.validations import Constraint, imply, require + + +@arcparser +class Args: + """ + This parser implements a config subcommand similar to `git config`. + + Usage: + `--list` displays configuration + `--unset ` unsets the key + ` ` sets the key to the value + + It would be impossible to create this parser using subparsers, since subparsers require a name to select + the subparser used, while this example uses flags to select the action. + + This example could be partially implemented by just using a mutually exclusive group with `--list` and `--unset` + but the presence of `key` wouldn't be required with `--unset` and similarly the presence of `value` wouldn't be + required when neither `--list` nor `--unset` are present. + """ + + list: bool + unset: bool + key: str | None = positional() + value: str | None = positional() + + @classmethod + def __presence_validations__(cls) -> Iterator[Constraint]: + """ + Generate argument presence constraints. + + The return value can be any iterable. + """ + # --list is incompatible with --unset, key and value + yield imply(cls.list, disallowed=[cls.unset, cls.key, cls.value]) + + # --unset requires key, but is incompatible with value (incompatibility with --list is verified by + # the previous constraint) + yield imply(cls.unset, required=[cls.key], disallowed=[cls.value]) + + # if none of the previous constraints matched (--list and --unset are not provided), require both key and value + yield require(cls.key, cls.value) + + +if __name__ == "__main__": + print(vars(Args.parse())) diff --git a/examples/subparsers.py b/examples/subparsers.py index 85918b8..af31287 100644 --- a/examples/subparsers.py +++ b/examples/subparsers.py @@ -27,4 +27,7 @@ class Args: if __name__ == "__main__": - print(vars(Args.parse())) + parsed = Args.parse() + print(vars(parsed)) + + parsed.action.run_action() diff --git a/pyproject.toml b/pyproject.toml index 4be5db4..5c4678d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arcparse" -version = "1.2.0" +version = "1.3.0" description = "Declare program arguments in a type-safe way" license = "MIT" authors = ["Jakub Rozek "] diff --git a/tests/test_examples.py b/tests/test_examples.py index 201c58c..c5c837e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -12,6 +12,12 @@ def test_example_functional(path: Path) -> None: try: subprocess.run(["poetry", "run", "python3", str(path)], check=True, stderr=subprocess.PIPE) except subprocess.CalledProcessError as e: - if b"usage:" not in e.stderr: + ignore_strings = [ + # some required arguments were not provided + "usage:", + # some required arguments were not provided together (caused by presence validation) + "are required together", + ] + if not any(msg in e.stderr.decode() for msg in ignore_strings): # process didn't exit from a parser error, reraise raise diff --git a/tests/test_override.py b/tests/test_override.py index 26c1cae..ab29161 100644 --- a/tests/test_override.py +++ b/tests/test_override.py @@ -32,7 +32,7 @@ class Args(BaseArgs): @staticmethod def __post_init__(parser: Parser) -> None: - parser.arguments["bar"] = Option("bar") + parser.arguments["bar"] = Option("bar", name="bar") with pytest.raises(SystemExit): Args.parse("--foo foo".split()) diff --git a/tests/test_presence_validations.py b/tests/test_presence_validations.py new file mode 100644 index 0000000..987dad5 --- /dev/null +++ b/tests/test_presence_validations.py @@ -0,0 +1,97 @@ +from typing import Any, Iterator + +import pytest + +from arcparse import Parser, arcparser, positional, subparsers +from arcparse.validations import Constraint, imply, require + + +@pytest.fixture(scope="session") +def args_cls() -> Parser: + @arcparser + class ConfigArgs: + list: bool + unset: bool + key: str | None = positional() + value: str | None = positional() + + @classmethod + def __presence_validations__(cls) -> Iterator[Constraint]: + yield imply(cls.list, disallowed=[cls.unset, cls.key, cls.value]) + yield imply(cls.unset, required=[cls.key], disallowed=[cls.value]) + yield require(cls.key, cls.value) + + return ConfigArgs + + +@pytest.fixture(scope="session") +def subparsers_args_cls() -> Parser: + class FooBar: + foo: bool + bar: bool + + @classmethod + def __presence_validations__(cls) -> Iterator[Constraint]: + yield imply(cls.foo, required=[cls.bar]) + + class BarFoo: + bar: bool + foo: bool + + @classmethod + def __presence_validations__(cls) -> Iterator[Constraint]: + yield imply(cls.bar, required=[cls.foo]) + + @arcparser + class Args: + arg: FooBar | BarFoo = subparsers("foobar", "barfoo") + + return Args + + +@pytest.mark.parametrize( + "string,result", + [ + ("--list", {"list": True}), + ("--unset foo", {"unset": True, "key": "foo"}), + ("foo bar", {"key": "foo", "value": "bar"}), + ("--list --unset", Exception), + ("--list foo", Exception), + ("--unset foo bar", Exception), + ("foo bar baz", SystemExit), + ], +) +def test_validation(args_cls: Parser, string: str, result: dict[str, Any] | type[BaseException]) -> None: + if isinstance(result, type) and issubclass(result, BaseException): + with pytest.raises(result): + args_cls.parse(string.split()) + else: + args = args_cls.parse(string.split()) + for k, v in result.items(): + assert getattr(args, k) == v + + +@pytest.mark.parametrize( + "string,result", + [ + ("foobar", {"foo": False, "bar": False}), + ("barfoo", {"foo": False, "bar": False}), + ("foobar --foo --bar", {"foo": True, "bar": True}), + ("barfoo --bar --foo", {"bar": True, "foo": True}), + ("foobar --bar", {"foo": False, "bar": True}), + ("barfoo --foo", {"bar": False, "foo": True}), + ("", SystemExit), + ("foobar --foo", Exception), + ("barfoo --bar", Exception), + ], +) +def test_validation_subparsers( + subparsers_args_cls: Parser, string: str, result: dict[str, Any] | type[BaseException] +) -> None: + if isinstance(result, type) and issubclass(result, BaseException): + with pytest.raises(result): + subparsers_args_cls.parse(string.split()) + else: + args = subparsers_args_cls.parse(string.split()) + for k, v in result.items(): + assert getattr(args.arg, k) == v