Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 142 additions & 61 deletions configargparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,18 +572,70 @@ def parse_known_args(

# prepare for reading config file(s)
known_config_keys = {config_key: action for action in self._actions
for config_key in self.get_possible_config_keys(action)}
for config_key in self.get_possible_config_keys(action, allow_config_file_args=True)}

# parse all config files
args = self.collate_all_config_args(
args, config_file_contents,
known_config_keys,
skip_config_file_parsing
)

# save default settings for use by print_values()
default_settings = OrderedDict()
for action in self._actions:
cares_about_default_value = (not action.is_positional_arg or
action.nargs in [OPTIONAL, ZERO_OR_MORE])
if (already_on_command_line(args, action.option_strings, self.prefix_chars) or
not cares_about_default_value or
action.default is None or
action.default == SUPPRESS or
isinstance(action, ACTION_TYPES_THAT_DONT_NEED_A_VALUE)):
continue
else:
if action.option_strings:
key = action.option_strings[-1]
else:
key = action.dest
default_settings[key] = (action, str(action.default))

if default_settings:
self._source_to_settings[_DEFAULTS_SOURCE_KEY] = default_settings

# parse all args (including commandline, config file, and env var)
namespace, unknown_args = argparse.ArgumentParser.parse_known_args(
self, args=args, namespace=namespace)
# handle any args that have is_write_out_config_file_arg set to true
# check if the user specified this arg on the commandline
output_file_paths = [getattr(namespace, a.dest, None) for a in self._actions
if getattr(a, "is_write_out_config_file_arg", False)]
output_file_paths = [a for a in output_file_paths if a is not None]
self.write_config_file(namespace, output_file_paths, exit_after=True)
return namespace, unknown_args

def collate_all_config_args(
self,
args,
config_file_contents,
known_config_keys,
skip_config_file_parsing,
parsed_config_files = []
):
"""Parse all config files. If a config option is found in a config file,
it will further be parsed in a recursive fashion.
"""
# open the config file(s)
config_streams = []
if config_file_contents is not None:
stream = StringIO(config_file_contents)
stream.name = "method arg"
config_streams = [stream]
elif not skip_config_file_parsing:
config_streams = self._open_config_files(args)
config_streams, args = self._open_config_files(args)

# parse each config file
config_keys = self._get_config_file_args_keys()
new_config_files = []
for stream in reversed(config_streams):
try:
config_items = self._config_file_parser.parse(stream)
Expand All @@ -599,8 +651,22 @@ def parse_known_args(
for key, value in config_items.items():
if key in known_config_keys:
action = known_config_keys[key]
discard_this_key = already_on_command_line(
args, action.option_strings, self.prefix_chars)
if key not in config_keys:
# if the key is not that of a config file,
# check if already on command line
discard_this_key = already_on_command_line(
args, action.option_strings, self.prefix_chars)
else:
# if the key corresponds to a config argument, we remove
# the already visited config files to avoid circular dependencies
if isinstance(value, list):
for c in value:
if c not in parsed_config_files and c not in new_config_files:
new_config_files.append(c)
else:
if value not in parsed_config_files and value not in new_config_files:
new_config_files.append(value)
discard_this_key= True
else:
action = None
discard_this_key = self._ignore_unknown_config_file_keys or \
Expand All @@ -625,37 +691,26 @@ def parse_known_args(
else:
args = config_args + args

# save default settings for use by print_values()
default_settings = OrderedDict()
for action in self._actions:
cares_about_default_value = (not action.is_positional_arg or
action.nargs in [OPTIONAL, ZERO_OR_MORE])
if (already_on_command_line(args, action.option_strings, self.prefix_chars) or
not cares_about_default_value or
action.default is None or
action.default == SUPPRESS or
isinstance(action, ACTION_TYPES_THAT_DONT_NEED_A_VALUE)):
continue
else:
if action.option_strings:
key = action.option_strings[-1]
else:
key = action.dest
default_settings[key] = (action, str(action.default))

if default_settings:
self._source_to_settings[_DEFAULTS_SOURCE_KEY] = default_settings
if len(new_config_files) > 0:
# Add the newly found config files to
# the list of files to not parse again
parsed_config_files.extend(new_config_files)

# Set the config file arguments to the
# newly discovered ones
args.append(config_keys[0])
args.extend(new_config_files)

# recursively call the collating function
args = self.collate_all_config_args(
args,
config_file_contents,
known_config_keys,
skip_config_file_parsing,
parsed_config_files
)

# parse all args (including commandline, config file, and env var)
namespace, unknown_args = argparse.ArgumentParser.parse_known_args(
self, args=args, namespace=namespace)
# handle any args that have is_write_out_config_file_arg set to true
# check if the user specified this arg on the commandline
output_file_paths = [getattr(namespace, a.dest, None) for a in self._actions
if getattr(a, "is_write_out_config_file_arg", False)]
output_file_paths = [a for a in output_file_paths if a is not None]
self.write_config_file(namespace, output_file_paths, exit_after=True)
return namespace, unknown_args
return args

def get_source_to_settings_dict(self):
"""
Expand Down Expand Up @@ -843,7 +898,7 @@ def convert_item_to_command_line_arg(self, action, key, value):

return args

def get_possible_config_keys(self, action):
def get_possible_config_keys(self, action, allow_config_file_args=False):
"""This method decides which actions can be set in a config file and
what their keys will be. It returns a list of 0 or more config keys that
can be used to set the given action's value in a config file.
Expand All @@ -860,9 +915,29 @@ def get_possible_config_keys(self, action):
for arg in action.option_strings:
if any(arg.startswith(2*c) for c in self.prefix_chars):
keys += [arg[2:], arg] # eg. for '--bla' return ['bla', '--bla']
elif getattr(action, 'is_config_file_arg', False) and allow_config_file_args:
if any(arg.startswith(c) for c in self.prefix_chars):
keys += [arg[1:], arg] # eg. for '-bla' return ['bla', '-bla']
return keys

def _get_config_file_args_keys(self):
keys = []

for a in self._actions:
if getattr(a, 'is_config_file_arg', False):
for arg in a.option_strings:
keys.append(arg)

for key in reversed(keys):
cleaned_key = key
while cleaned_key[0] == '-':
cleaned_key = cleaned_key[1:]
if cleaned_key != key:
keys.append(cleaned_key)

return keys


def _open_config_files(self, command_line_args):
"""Tries to parse config file path(s) from within command_line_args.
Returns a list of opened config files, including files specified on the
Expand All @@ -887,7 +962,7 @@ def _open_config_files(self, command_line_args):
a for a in self._actions if getattr(a, "is_config_file_arg", False)]

if not user_config_file_arg_actions:
return config_files
return config_files, command_line_args

for action in user_config_file_arg_actions:
# try to parse out the config file path by using a clean new
Expand All @@ -909,34 +984,40 @@ def error_method(self, message):
parsed_arg = arg_parser.parse_known_args(args=command_line_args)
if not parsed_arg:
continue
namespace, _ = parsed_arg
user_config_file = getattr(namespace, action.dest, None)
namespace, command_line_args = parsed_arg
user_config_files = getattr(namespace, action.dest, None)

if not user_config_file:
if user_config_files is None:
continue

# open user-provided config file
user_config_file = os.path.expanduser(user_config_file)
try:
stream = self._config_file_open_func(user_config_file)
except Exception as e:
if len(e.args) == 2: # OSError
errno, msg = e.args
else:
msg = str(e)
# close previously opened config files
for config_file in config_files:
try:
config_file.close()
except Exception:
pass
self.error("Unable to open config file: %s. Error: %s" % (
user_config_file, msg
))

config_files += [stream]

return config_files
if not isinstance(user_config_files, list):
user_config_files = [user_config_files]

for user_config_file in user_config_files:
# open user-provided config file
user_config_file = os.path.expanduser(user_config_file)
try:
stream = self._config_file_open_func(user_config_file)
except Exception as e:
if len(e.args) == 2: # OSError
errno, msg = e.args
else:
msg = str(e)
# close previously opened config files
for config_file in config_files:
try:
config_file.close()
except Exception:
pass
self.error("Unable to open config file: %s. Error: %s" % (
user_config_file, msg
))

config_files += [stream]

# we return both the config files and the remaining args
# to make sure the config files don't get parsed multiple times
return config_files, command_line_args

def format_values(self):
"""Returns a string with all args and settings and where they came from
Expand Down
2 changes: 2 additions & 0 deletions tests/test_conf_0.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
config=[test_conf_1.ini, test_conf_2.ini]
arg_0="conf_0"
2 changes: 2 additions & 0 deletions tests/test_conf_1.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
config=[test_conf_3.ini]
arg_1="conf_1"
2 changes: 2 additions & 0 deletions tests/test_conf_2.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
arg_0="conf_2"
config=test_conf_0.ini
1 change: 1 addition & 0 deletions tests/test_conf_3.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
arg_2="conf_3"
3 changes: 3 additions & 0 deletions tests/test_conf_4.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
arg_0="conf_0"
arg_1="conf_1"
arg_2="conf_3"
2 changes: 2 additions & 0 deletions tests/test_conf_5.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
arg_0="conf_0"
arg_1="conf_1"
40 changes: 40 additions & 0 deletions tests/test_multiconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from pdb import set_trace
import sys
import unittest
from configargparse import ArgumentParser


def get_abs_path(file):
return os.path.abspath(os.path.join(os.path.dirname(__file__), file))

class TestMulticonf(unittest.TestCase):
def setUp(self):
self.parser = ArgumentParser()
self.parser.add_argument('-config', type=get_abs_path, nargs='*', is_config_file=True)
self.parser.add_argument('-config_2', type=get_abs_path, nargs='*', is_config_file=True)
self.parser.add_argument('--arg_0', type=str)
self.parser.add_argument('--arg_1', type=str)
self.parser.add_argument('--arg_2', type=str)

self.ref = self.parse_command_line('-config test_conf_4.ini')

def parse_command_line(self, cmd):
sys.argv = [sys.argv[0]] + cmd.split(' ')
args = self.parser.parse_args()
return args

def test_multiconf(self):
args = self.parse_command_line('-config test_conf_3.ini test_conf_5.ini')
self.assertEqual(self.ref, args)

def test_multiconf_2(self):
args = self.parse_command_line('-config test_conf_3.ini -config_2 test_conf_5.ini')
self.assertEqual(self.ref, args)

def test_conf_inheritance(self):
args = self.parse_command_line('-config test_conf_0.ini')
self.assertEqual(self.ref, args)

if __name__ == '__main__':
unittest.main()