Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #231 from delira-dev/update_from_sys_arg
Browse files Browse the repository at this point in the history
Update Config from system args
  • Loading branch information
ORippler authored Dec 3, 2019
2 parents 64ce57b + c0e5ac8 commit be0f9f0
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 0 deletions.
128 changes: 128 additions & 0 deletions delira/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import yaml
import argparse
import sys
import collections
import inspect


def non_string_warning(func):
Expand All @@ -32,6 +35,7 @@ def warning_wrapper(config, key, *args, **kwargs):
key, type(key)), RuntimeWarning)

return func(config, key, *args, **kwargs)

return warning_wrapper


Expand Down Expand Up @@ -513,6 +517,130 @@ def create_from_str(cls, data, formatter=yaml.load, decoder_cls=Decoder,
**kwargs)
return config

def create_argparser(self):
'''
Creates an argparser for all values in the config
Following the pattern: `--training.learning_rate 1234`
Returns
-------
argparse.ArgumentParser
parser for all variables in the config
'''
parser = argparse.ArgumentParser(allow_abbrev=False)

def add_val(dict_like, prefix=''):
for key, val in dict_like.items():
name = "--{}".format(prefix + key)
if val is None:
parser.add_argument(name)
else:
if isinstance(val, int):
parser.add_argument(name, type=type(val))
elif isinstance(val, collections.Mapping):
add_val(val, prefix=key + '.')
elif isinstance(val, collections.Iterable):
if len(val) > 0 and type(val[0]) != type:
parser.add_argument(name, type=type(val[0]))
else:
parser.add_argument(name)
elif issubclass(val, type) or inspect.isclass(val):
parser.add_argument(name, type=val)
else:
parser.add_argument(name, type=type(val))

add_val(self)
return parser

@staticmethod
def _add_unknown_args(unknown_args):
'''
Can add unknown args as parsed by argparsers method
`parse_unknown_args`.
Parameters
------
unknown_args : list
list of unknown args
Returns
------
Config
a config of the parsed args
'''
# first element in the list must be a key
if not isinstance(unknown_args[0], str):
unknown_args = [str(arg) for arg in unknown_args]
if not unknown_args[0].startswith('--'):
raise ValueError

args = Config()
# take first key
key = unknown_args[0][2:]
idx, done, val = 1, False, []
while not done:
try:
item = unknown_args[idx]
except IndexError:
done = True
if item.startswith('--') or done:
# save key with its value
if len(val) == 0:
# key is used as flag
args[key] = True
elif len(val) == 1:
args[key] = val[0]
else:
args[key] = val
# new key and flush data
key = item[2:]
val = []
else:
val.append(item)
idx += 1
return args

def update_from_argparse(self, parser=None, add_unknown_items=False):
'''
Updates the config with all values from the command line.
Following the pattern: `--training.learning_rate 1234`
Raises
------
TypeError
raised if another datatype than currently in the config is parsed
Returns
-------
dict
dictionary containing only updated arguments
'''

if len(sys.argv) > 1:
if not parser:
parser = self.create_argparser()

params, unknown = parser.parse_known_args()
params = vars(params)
if unknown and not add_unknown_items:
warnings.warn(
"Called with unknown arguments: {} "
"They will not be stored if you do not set "
"`add_unknown_items` to true.".format(unknown),
RuntimeWarning)

new_params = Config()
for key, val in params.items():
if val is None:
continue
new_params[key] = val

# update dict
self.update(new_params, overwrite=True)
if add_unknown_items:
additional_params = self._add_unknown_args(unknown)
self.update(additional_params)
new_params.update(additional_params)
return new_params


class LookupConfig(Config):
"""
Expand Down
45 changes: 45 additions & 0 deletions tests/utils/test_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import unittest
import os
import sys
import copy
import argparse
from unittest.mock import patch
from delira._version import get_versions

from delira.utils.config import Config, LookupConfig, DeliraConfig
from delira.logging import Logger, TensorboardBackend, make_logger, \
register_logger
import warnings

from . import check_for_no_backend

Expand Down Expand Up @@ -215,6 +218,48 @@ def test_internal_type(self):
cf = self.config_cls.create_from_dict(self.example_dict)
self.assertTrue(isinstance(cf["deep"], self.config_cls))

@unittest.skipUnless(
check_for_no_backend(),
"Test should only be executed if no backend is specified")
def test_create_argparser(self):
cf = self.config_cls.create_from_dict(self.example_dict)
testargs = [
'--shallowNum',
'10',
'--deep.deepStr',
'check',
'--testlist',
'ele1',
'ele2',
'--setflag']
parser = cf.create_argparser()
known, unknown = parser.parse_known_args(testargs)
self.assertEqual(vars(known)['shallowNum'], 10)
self.assertEqual(vars(known)['deep.deepStr'], 'check')
self.assertEqual(unknown, ['--testlist', 'ele1', 'ele2', '--setflag'])

@unittest.skipUnless(
check_for_no_backend(),
"Test should only be executed if no backend is specified")
def test_update_from_argparse(self):
cf = self.config_cls.create_from_dict(self.example_dict)
testargs = ['--shallowNum', '10',
'--deep.deepStr', 'check',
'--testlist', 'ele1', 'ele2',
'--setflag']
# placeholder pyfile because argparser omits first argument from sys
# argv
with patch.object(sys, 'argv', ['pyfile.py'] + testargs):
cf.update_from_argparse(add_unknown_items=True)
self.assertEqual(cf['shallowNum'], int(testargs[1]))
self.assertEqual(cf['deep']['deepStr'], testargs[3])
self.assertEqual(cf['testlist'], testargs[5:7])
self.assertEqual(cf['setflag'], True)
with warnings.catch_warnings(record=True) as w:
with patch.object(sys, 'argv', ['pyfile.py', '--unknown', 'arg']):
cf.update_from_argparse(add_unknown_items=False)
self.assertEqual(len(w), 1)


class LookupConfigTest(ConfigTest):
def setUp(self):
Expand Down

0 comments on commit be0f9f0

Please sign in to comment.