diff --git a/README.md b/README.md
index ed9b83c..29fdc4f 100644
--- a/README.md
+++ b/README.md
@@ -5,12 +5,12 @@ This project builds on top of `argparse` by adding type-safety and allowing a mo
## Example usage
```py
-from arcparse import arcparser, flag
+from arcparse import arcparser, flag, positional
from pathlib import Path
@arcparser
class Args:
- path: Path
+ path: Path = positional()
recurse: bool = flag("-r")
item_limit: int = 100
output_path: Path | None
@@ -19,6 +19,21 @@ args = Args.parse()
print(f"Scanning {args.path}...")
...
```
+
+ Help output of this parser
+
+ usage: program.py [-h] [-r] [--item-limit ITEM_LIMIT] [--output-path OUTPUT_PATH] path
+
+ positional arguments:
+ path
+
+ options:
+ -h, --help show this help message and exit
+ -r, --recurse
+ --item-limit ITEM_LIMIT
+ --output-path OUTPUT_PATH
+
+
For more examples see [Examples](examples/).
@@ -29,13 +44,14 @@ $ pip install arcparse
```
## Features
-- Positional, Option and Flag arguments
-- Multiple values per argument
+- Positional, Option and [Flag](./examples/flag.py) arguments
- Name overriding
-- Type conversions
-- Mutually exclusive groups
-- Subparsers
-- Parser inheritance (with overriding)
+- [Multiple values per argument](./examples/multiple.py)
+- [Type conversions](./examples/conversion.py)
+- [Mutually exclusive groups](./examples/mutual_exclusion.py)
+- [Subparsers](./examples/subparsers.py)
+- [Parser inheritance](./examples/inheritance.py) (with [overriding](./examples/override.py))
+- [Presence validation](./examples/presence_validation.py)
## Credits
This project was inspired by [swansonk14/typed-argument-parser](https://github.com/swansonk14/typed-argument-parser).
diff --git a/arcparse/_partial_arguments.py b/arcparse/_partial_arguments.py
index 0fad2f3..e742f65 100644
--- a/arcparse/_partial_arguments.py
+++ b/arcparse/_partial_arguments.py
@@ -116,10 +116,6 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]:
class PartialPositional[T](BasePartialValueArgument[T, Positional]):
def resolve_with_typehint(self, name: str, typehint: type) -> Positional:
kwargs = self.resolve_to_kwargs(name, typehint)
- return Positional(**kwargs)
-
- def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]:
- kwargs = super().resolve_to_kwargs(name, typehint)
type_is_optional = bool(extract_optional_type(typehint))
type_is_collection = bool(extract_collection_type(typehint))
@@ -135,7 +131,7 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]:
kwargs["nargs"] = "+" if self.at_least_one else "*"
kwargs["metavar"] = self.name_override
- return kwargs
+ return Positional(name=name, **kwargs)
@dataclass
@@ -151,10 +147,6 @@ def resolve_with_typehint(self, name: str, typehint: type) -> Option:
)
kwargs = self.resolve_to_kwargs(name, typehint)
- return Option(**kwargs)
-
- def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]:
- kwargs = super().resolve_to_kwargs(name, typehint)
kwargs["short"] = self.short
kwargs["short_only"] = self.short_only
@@ -170,6 +162,7 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]:
if self.choices is None: # choices generate custom `{foo,bar}` metavar in argparse
kwargs["metavar"] = self.name_override.replace("-", "_").upper()
elif self.short_only and self.short is not None:
+ kwargs["name"] = name
kwargs["dest"] = name
else:
kwargs["name"] = name
@@ -181,8 +174,7 @@ def resolve_to_kwargs(self, name: str, typehint: type) -> dict[str, Any]:
kwargs["optional"] = True
elif self.mx_group is not None:
raise InvalidArgument("Arguments in mutually exclusive group have to have a default")
-
- return kwargs
+ return Option(**kwargs)
@dataclass
@@ -201,7 +193,7 @@ def resolve_with_typehint(self, name: str, typehint: type) -> Flag:
kwargs["short"] = self.short
kwargs["short_only"] = self.short_only
kwargs["no_flag"] = self.no_flag
- return Flag(**kwargs)
+ return Flag(name=name, **kwargs)
class PartialTriFlag(BasePartialArgument[TriFlag]):
diff --git a/arcparse/_validations.py b/arcparse/_validations.py
new file mode 100644
index 0000000..eac8e9c
--- /dev/null
+++ b/arcparse/_validations.py
@@ -0,0 +1,89 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Collection, Iterable, Sequence
+from dataclasses import dataclass
+
+from arcparse.arguments import BaseArgument, BaseValueArgument, Flag
+
+
+class ArgumentAccessor:
+ def __init__(self, arguments: dict[str, BaseArgument]):
+ self._arguments = arguments
+
+ def __getattribute__(self, key: str) -> BaseArgument:
+ arguments = super().__getattribute__("_arguments")
+ if key not in arguments:
+ raise Exception(f'Argument "{key}" doesn\'t exist')
+ return arguments[key]
+
+
+class Constraint(ABC):
+ @abstractmethod
+ def validate(self, arguments: dict[str, str]) -> bool: ...
+
+ @staticmethod
+ def is_provided(argument: BaseArgument, arguments: dict[str, str]) -> bool:
+ if (provided_value := arguments.get(argument.name)) is None:
+ return False
+
+ if isinstance(argument, Flag):
+ defined_default = argument.no_flag
+ elif isinstance(argument, BaseValueArgument):
+ defined_default = argument.default
+ else:
+ raise Exception(f"is_provided is not defined for {argument.__class__.__name__}")
+ return provided_value != defined_default
+
+
+@dataclass
+class ImplyConstraint(Constraint):
+ arg: BaseArgument
+ required: Collection[BaseArgument]
+ disallowed: Collection[BaseArgument]
+
+ def validate(self, arguments: dict[str, str]) -> bool:
+ if not self.is_provided(self.arg, arguments):
+ return False
+
+ for arg in self.required:
+ if not self.is_provided(arg, arguments):
+ raise Exception(f'Argument "{arg.display_name}" is required when "{self.arg.display_name}" is passed')
+
+ for arg in self.disallowed:
+ if self.is_provided(arg, arguments):
+ raise Exception(f'Argument "{arg.display_name}" is incompatible with "{self.arg.display_name}"')
+ return True
+
+
+@dataclass
+class RequireConstraint(Constraint):
+ args: Collection[BaseArgument]
+
+ def validate(self, arguments: dict[str, str]) -> bool:
+ def and_join(names: Sequence[str]) -> str:
+ if len(names) == 0:
+ return ""
+ elif len(names) == 1:
+ return names[0]
+ return f"{', '.join(names[:-1])} and {names[-1]}"
+
+ not_provided = [arg for arg in self.args if not self.is_provided(arg, arguments)]
+ if not_provided:
+ provided_text = "none" if len(not_provided) == len(self.args) else "only some"
+ raise Exception(
+ f"Arguments {and_join([arg.display_name for arg in self.args])} are required together, but {provided_text} were provided"
+ )
+ return True
+
+
+def validate_with(
+ defined_arguments: dict[str, BaseArgument],
+ validations_callable: Callable[[ArgumentAccessor], Iterable[Constraint]],
+ provided_arguments: dict[str, str],
+) -> None:
+ for constraint in validations_callable(ArgumentAccessor(defined_arguments)):
+ if not isinstance(constraint, Constraint):
+ raise TypeError("Items returned from __validations__() have to be of type Constrant")
+
+ matched = constraint.validate(provided_arguments)
+ if matched:
+ break
diff --git a/arcparse/arguments.py b/arcparse/arguments.py
index e7c20aa..f9e4c85 100644
--- a/arcparse/arguments.py
+++ b/arcparse/arguments.py
@@ -36,17 +36,22 @@ def apply(self, actions_container: _ActionsContainer, name: str) -> None: ...
@dataclass(kw_only=True)
class BaseArgument(ABC, ContainerApplicable):
+ name: str
help: str | None = None
+ @property
+ @abstractmethod
+ def display_name(self) -> str: ...
+
def apply(self, actions_container: _ActionsContainer, name: str) -> None:
- args = self.get_argparse_args(name)
- kwargs = self.get_argparse_kwargs(name)
+ args = self.get_argparse_args()
+ kwargs = self.get_argparse_kwargs()
actions_container.add_argument(*args, **kwargs)
@abstractmethod
- def get_argparse_args(self, name: str) -> list[str]: ...
+ def get_argparse_args(self) -> list[str]: ...
- def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
+ def get_argparse_kwargs(self) -> dict[str, Any]:
kwargs = {}
if self.help is not None:
kwargs["help"] = self.help
@@ -59,31 +64,38 @@ class Flag(BaseArgument):
short_only: bool = False
no_flag: bool = False
- def get_argparse_args(self, name: str) -> list[str]:
+ @property
+ def display_name(self) -> str:
if self.no_flag:
- args = [f"--no-{name.replace("_", "-")}"]
- else:
- args = [f"--{name.replace("_", "-")}"]
+ return f"--no-{self.name.replace("_", "-")}"
+ return f"--{self.name.replace("_", "-")}"
+ def get_argparse_args(self) -> list[str]:
if self.short_only:
if self.short is not None:
return [self.short]
- else:
- return [f"-{name}"]
- elif self.short is not None:
- args.insert(0, self.short)
+ return [f"-{self.name}"]
+ args = [self.display_name]
+ if self.short is not None:
+ args.insert(0, self.short)
return args
- def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
- kwargs = super().get_argparse_kwargs(name)
+ def get_argparse_kwargs(self) -> dict[str, Any]:
+ kwargs = super().get_argparse_kwargs()
kwargs["action"] = "store_true" if not self.no_flag else "store_false"
- kwargs["dest"] = name
+ kwargs["dest"] = self.name
return kwargs
class TriFlag(ContainerApplicable):
+ name: str
+
+ @property
+ def display_name(self) -> str:
+ return f"--(no-){self.name.replace("_", "-")}"
+
def apply(self, actions_container: _ActionsContainer, name: str) -> None:
# if actions_container is not an mx group, make it one, argparse
# doesn't support mx group nesting
@@ -104,8 +116,8 @@ class BaseValueArgument[T](BaseArgument):
optional: bool = False
metavar: str | None = None
- def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
- kwargs = super().get_argparse_kwargs(name)
+ def get_argparse_kwargs(self) -> dict[str, Any]:
+ kwargs = super().get_argparse_kwargs()
if self.default is not void:
kwargs["default"] = self.default
@@ -121,11 +133,15 @@ def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
@dataclass
class Positional[T](BaseValueArgument[T]):
- def get_argparse_args(self, name: str) -> list[str]:
- return [name]
+ @property
+ def display_name(self) -> str:
+ return self.name
+
+ def get_argparse_args(self) -> list[str]:
+ return [self.name]
- def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
- kwargs = super().get_argparse_kwargs(name)
+ def get_argparse_kwargs(self) -> dict[str, Any]:
+ kwargs = super().get_argparse_kwargs()
if self.nargs is None and (self.optional or self.default is not void):
kwargs["nargs"] = "?"
@@ -135,26 +151,28 @@ def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
@dataclass
class Option[T](BaseValueArgument[T]):
- name: str | None = None
dest: str | None = None
short: str | None = None
short_only: bool = False
append: bool = False
- def get_argparse_args(self, name: str) -> list[str]:
- args = [f"--{(self.name or name).replace("_", "-")}"]
+ @property
+ def display_name(self) -> str:
+ return f"--{(self.name).replace("_", "-")}"
+
+ def get_argparse_args(self) -> list[str]:
if self.short_only:
if self.short is not None:
return [self.short]
- else:
- return [f"-{self.name or name}"]
- elif self.short is not None:
- args.insert(0, self.short)
+ return [f"-{self.name}"]
+ args = [self.display_name]
+ if self.short is not None:
+ args.insert(0, self.short)
return args
- def get_argparse_kwargs(self, name: str) -> dict[str, Any]:
- kwargs = super().get_argparse_kwargs(name)
+ def get_argparse_kwargs(self) -> dict[str, Any]:
+ kwargs = super().get_argparse_kwargs()
if self.dest is not None:
kwargs["dest"] = self.dest
diff --git a/arcparse/parser.py b/arcparse/parser.py
index 9378774..eea5dfd 100644
--- a/arcparse/parser.py
+++ b/arcparse/parser.py
@@ -19,6 +19,7 @@
extract_subparsers_from_typehint,
union_contains_none,
)
+from ._validations import validate_with
from .arguments import (
BaseArgument,
BaseValueArgument,
@@ -106,6 +107,9 @@ def _construct_object_with_parsed(self, parsed: dict[str, Any]) -> T:
sub_parser = subparsers.sub_parsers[chosen_subparser]
parsed[name] = sub_parser._construct_object_with_parsed(parsed)
+ if (validations := getattr(self.shape, "__presence_validations__", None)) and callable(validations):
+ validate_with(self.arguments, cast(Any, validations).__func__, parsed)
+
# apply argument converters
for name, argument in self.all_arguments:
if not isinstance(argument, BaseValueArgument) or argument.converter is None:
@@ -139,9 +143,10 @@ def __getattribute__(self, name: str):
if name in __dict__:
return __dict__[name]
- value = super().__getattribute__(name)
- if not isinstance(value, BasePartialArgument):
- return value
+ if hasattr(super(), name):
+ value = getattr(super(), name)
+ if not isinstance(value, BasePartialArgument):
+ return value
raise AttributeError(f"'{cls.__name__}' parser didn't define argument '{name}'", name=name, obj=self)
diff --git a/arcparse/validations.py b/arcparse/validations.py
new file mode 100644
index 0000000..f2ff884
--- /dev/null
+++ b/arcparse/validations.py
@@ -0,0 +1,33 @@
+from collections.abc import Collection
+from typing import Any
+
+from ._validations import Constraint, ImplyConstraint, RequireConstraint
+
+
+__all__ = [
+ "Constraint",
+ "ImplyConstraint",
+ "RequireConstraint",
+ "imply",
+ "require",
+]
+
+
+def imply(arg: Any, required: Collection[Any] = (), disallowed: Collection[Any] = ()) -> ImplyConstraint:
+ """
+ Require and disallow arguments when arg is passed.
+
+ If `arg` is present and the constraint is fulfilled, no subsequent constraints are checked.
+
+ :param Any arg: argument to check existence for
+ :param Collection[Any] required: required arguments when arg is passed
+ :param Collection[Any] disallowed: disallowed arguments when arg is passed
+ """
+ return ImplyConstraint(arg, required=required, disallowed=disallowed)
+
+
+def require(*args: Any) -> RequireConstraint:
+ """
+ Require arguments to be present.
+ """
+ return RequireConstraint(args)
diff --git a/examples/override.py b/examples/override.py
index 4a8790a..da9ea96 100644
--- a/examples/override.py
+++ b/examples/override.py
@@ -20,7 +20,7 @@ def __post_init__(parser: Parser) -> None:
del parser.arguments["bar"]
# create new argument
- parser.arguments["baz"] = Option("baz", default="baz")
+ parser.arguments["baz"] = Option("baz", name="baz", default="baz")
if __name__ == "__main__":
diff --git a/examples/presence_validation.py b/examples/presence_validation.py
new file mode 100644
index 0000000..9b7b356
--- /dev/null
+++ b/examples/presence_validation.py
@@ -0,0 +1,49 @@
+from typing import Iterator
+
+from arcparse import arcparser, positional
+from arcparse.validations import Constraint, imply, require
+
+
+@arcparser
+class Args:
+ """
+ This parser implements a config subcommand similar to `git config`.
+
+ Usage:
+ `--list` displays configuration
+ `--unset ` unsets the key
+ ` ` sets the key to the value
+
+ It would be impossible to create this parser using subparsers, since subparsers require a name to select
+ the subparser used, while this example uses flags to select the action.
+
+ This example could be partially implemented by just using a mutually exclusive group with `--list` and `--unset`
+ but the presence of `key` wouldn't be required with `--unset` and similarly the presence of `value` wouldn't be
+ required when neither `--list` nor `--unset` are present.
+ """
+
+ list: bool
+ unset: bool
+ key: str | None = positional()
+ value: str | None = positional()
+
+ @classmethod
+ def __presence_validations__(cls) -> Iterator[Constraint]:
+ """
+ Generate argument presence constraints.
+
+ The return value can be any iterable.
+ """
+ # --list is incompatible with --unset, key and value
+ yield imply(cls.list, disallowed=[cls.unset, cls.key, cls.value])
+
+ # --unset requires key, but is incompatible with value (incompatibility with --list is verified by
+ # the previous constraint)
+ yield imply(cls.unset, required=[cls.key], disallowed=[cls.value])
+
+ # if none of the previous constraints matched (--list and --unset are not provided), require both key and value
+ yield require(cls.key, cls.value)
+
+
+if __name__ == "__main__":
+ print(vars(Args.parse()))
diff --git a/examples/subparsers.py b/examples/subparsers.py
index 85918b8..af31287 100644
--- a/examples/subparsers.py
+++ b/examples/subparsers.py
@@ -27,4 +27,7 @@ class Args:
if __name__ == "__main__":
- print(vars(Args.parse()))
+ parsed = Args.parse()
+ print(vars(parsed))
+
+ parsed.action.run_action()
diff --git a/pyproject.toml b/pyproject.toml
index 4be5db4..5c4678d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "arcparse"
-version = "1.2.0"
+version = "1.3.0"
description = "Declare program arguments in a type-safe way"
license = "MIT"
authors = ["Jakub Rozek "]
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 201c58c..c5c837e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -12,6 +12,12 @@ def test_example_functional(path: Path) -> None:
try:
subprocess.run(["poetry", "run", "python3", str(path)], check=True, stderr=subprocess.PIPE)
except subprocess.CalledProcessError as e:
- if b"usage:" not in e.stderr:
+ ignore_strings = [
+ # some required arguments were not provided
+ "usage:",
+ # some required arguments were not provided together (caused by presence validation)
+ "are required together",
+ ]
+ if not any(msg in e.stderr.decode() for msg in ignore_strings):
# process didn't exit from a parser error, reraise
raise
diff --git a/tests/test_override.py b/tests/test_override.py
index 26c1cae..ab29161 100644
--- a/tests/test_override.py
+++ b/tests/test_override.py
@@ -32,7 +32,7 @@ class Args(BaseArgs):
@staticmethod
def __post_init__(parser: Parser) -> None:
- parser.arguments["bar"] = Option("bar")
+ parser.arguments["bar"] = Option("bar", name="bar")
with pytest.raises(SystemExit):
Args.parse("--foo foo".split())
diff --git a/tests/test_presence_validations.py b/tests/test_presence_validations.py
new file mode 100644
index 0000000..987dad5
--- /dev/null
+++ b/tests/test_presence_validations.py
@@ -0,0 +1,97 @@
+from typing import Any, Iterator
+
+import pytest
+
+from arcparse import Parser, arcparser, positional, subparsers
+from arcparse.validations import Constraint, imply, require
+
+
+@pytest.fixture(scope="session")
+def args_cls() -> Parser:
+ @arcparser
+ class ConfigArgs:
+ list: bool
+ unset: bool
+ key: str | None = positional()
+ value: str | None = positional()
+
+ @classmethod
+ def __presence_validations__(cls) -> Iterator[Constraint]:
+ yield imply(cls.list, disallowed=[cls.unset, cls.key, cls.value])
+ yield imply(cls.unset, required=[cls.key], disallowed=[cls.value])
+ yield require(cls.key, cls.value)
+
+ return ConfigArgs
+
+
+@pytest.fixture(scope="session")
+def subparsers_args_cls() -> Parser:
+ class FooBar:
+ foo: bool
+ bar: bool
+
+ @classmethod
+ def __presence_validations__(cls) -> Iterator[Constraint]:
+ yield imply(cls.foo, required=[cls.bar])
+
+ class BarFoo:
+ bar: bool
+ foo: bool
+
+ @classmethod
+ def __presence_validations__(cls) -> Iterator[Constraint]:
+ yield imply(cls.bar, required=[cls.foo])
+
+ @arcparser
+ class Args:
+ arg: FooBar | BarFoo = subparsers("foobar", "barfoo")
+
+ return Args
+
+
+@pytest.mark.parametrize(
+ "string,result",
+ [
+ ("--list", {"list": True}),
+ ("--unset foo", {"unset": True, "key": "foo"}),
+ ("foo bar", {"key": "foo", "value": "bar"}),
+ ("--list --unset", Exception),
+ ("--list foo", Exception),
+ ("--unset foo bar", Exception),
+ ("foo bar baz", SystemExit),
+ ],
+)
+def test_validation(args_cls: Parser, string: str, result: dict[str, Any] | type[BaseException]) -> None:
+ if isinstance(result, type) and issubclass(result, BaseException):
+ with pytest.raises(result):
+ args_cls.parse(string.split())
+ else:
+ args = args_cls.parse(string.split())
+ for k, v in result.items():
+ assert getattr(args, k) == v
+
+
+@pytest.mark.parametrize(
+ "string,result",
+ [
+ ("foobar", {"foo": False, "bar": False}),
+ ("barfoo", {"foo": False, "bar": False}),
+ ("foobar --foo --bar", {"foo": True, "bar": True}),
+ ("barfoo --bar --foo", {"bar": True, "foo": True}),
+ ("foobar --bar", {"foo": False, "bar": True}),
+ ("barfoo --foo", {"bar": False, "foo": True}),
+ ("", SystemExit),
+ ("foobar --foo", Exception),
+ ("barfoo --bar", Exception),
+ ],
+)
+def test_validation_subparsers(
+ subparsers_args_cls: Parser, string: str, result: dict[str, Any] | type[BaseException]
+) -> None:
+ if isinstance(result, type) and issubclass(result, BaseException):
+ with pytest.raises(result):
+ subparsers_args_cls.parse(string.split())
+ else:
+ args = subparsers_args_cls.parse(string.split())
+ for k, v in result.items():
+ assert getattr(args.arg, k) == v