From 1f9cfcfa3b4a12f374aef12d0cbc13d44b3a7674 Mon Sep 17 00:00:00 2001 From: "M. Ernestus" Date: Wed, 30 Nov 2022 15:31:40 +0100 Subject: [PATCH] Allow python script for hyperparameter configuration (#318) * Allow loading configuration from python package in addition to yaml files. * Add an example python configuration file. * Add a test for training with the example python config file. * Formatting fixes. * Flake8 6 doesn't support inline comments * Allow python scripts * Remove debug print * Add python config file usage example to README.md * Update CHANGELOG.md * Fix paths in python config example and add variant where we specify the python file instead of package. * Update README Co-authored-by: Antonin Raffin Co-authored-by: Antonin Raffin --- CHANGELOG.md | 2 ++ Makefile | 2 +- README.md | 18 +++++++++-- hyperparams/python/ppo_config_example.py | 29 +++++++++++++++++ rl_zoo3/exp_manager.py | 41 +++++++++++++++++------- rl_zoo3/train.py | 21 ++++++++++-- rl_zoo3/version.txt | 2 +- setup.cfg | 3 +- tests/test_train.py | 22 ++++++++++++- 9 files changed, 119 insertions(+), 21 deletions(-) create mode 100644 hyperparams/python/ppo_config_example.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 41363dea6..af109d87f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,13 @@ ## Release 1.7.0a2 (WIP) ### Breaking Changes +- `--yaml-file` argument was renamed to `-conf` (`--conf-file`) as now python files are supported too ### New Features - Specifying custom policies in yaml file is now supported (@Rick-v-E) - Added ``monitor_kwargs`` parameter - Handle the `env_kwargs` of `render:True` under the hood for panda-gym v1 envs in `enjoy` replay to match visualzation behavior of other envs +- Added support for python config file ### Bug fixes - Allow `python -m rl_zoo3.cli` to be called directly diff --git a/Makefile b/Makefile index 7139a61e4..586014871 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -LINT_PATHS = *.py 
tests/ scripts/ rl_zoo3/ hyperparams/python/*.py # Run pytest and coverage report pytest: diff --git a/README.md b/README.md index bed85d6ed..e3b8739c4 100644 --- a/README.md +++ b/README.md @@ -61,11 +61,23 @@ python train.py --algo algo_name --env env_id ``` You can use `-P` (`--progress`) option to display a progress bar. -Using a custom yaml file (which contains a `env_id` entry): +Using a custom config file when it is a yaml file which contains an `env_id` entry: ``` -python train.py --algo algo_name --env env_id --yaml-file my_yaml.yml +python train.py --algo algo_name --env env_id --conf-file my_yaml.yml ``` +You can also use a python file that contains a dictionary called `hyperparams` with an entry for each `env_id`. +(see `hyperparams/python/ppo_config_example.py` for an example) +```bash +# You can pass a path to a python file +python train.py --algo ppo --env MountainCarContinuous-v0 --conf-file hyperparams/python/ppo_config_example.py +# Or pass a path to a file from a module (for instance my_package.my_file) +python train.py --algo ppo --env MountainCarContinuous-v0 --conf-file hyperparams.python.ppo_config_example +``` +The advantage of this approach is that you can specify arbitrary python dictionaries +and ensure that all their dependencies are imported in the config file itself. + + For example (with tensorboard support): ``` python train.py --algo ppo --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/ @@ -139,7 +151,7 @@ Remark: plotting with the `--rliable` option is usually slow as confidence inter ## Custom Environment -The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--yaml-file` argument). 
+The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--conf-file` argument). ## Enjoy a Trained Agent diff --git a/hyperparams/python/ppo_config_example.py b/hyperparams/python/ppo_config_example.py new file mode 100644 index 000000000..225aba6ee --- /dev/null +++ b/hyperparams/python/ppo_config_example.py @@ -0,0 +1,29 @@ +"""This file just serves as an example on how to configure the zoo +using python scripts instead of yaml files.""" +import torch + +hyperparams = { + "MountainCarContinuous-v0": dict( + env_wrapper=[{"gym.wrappers.TimeLimit": {"max_episode_steps": 100}}], + normalize=True, + n_envs=1, + n_timesteps=20000.0, + policy="MlpPolicy", + batch_size=8, + n_steps=8, + gamma=0.9999, + learning_rate=7.77e-05, + ent_coef=0.00429, + clip_range=0.1, + n_epochs=2, + gae_lambda=0.9, + max_grad_norm=5, + vf_coef=0.19, + use_sde=True, + policy_kwargs=dict( + log_std_init=-3.29, + ortho_init=False, + activation_fn=torch.nn.ReLU, + ), + ) +} diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index bfa6bd1c0..bc5e4bb0c 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -1,4 +1,5 @@ import argparse +import importlib import os import pickle as pkl import time @@ -93,7 +94,7 @@ def __init__( n_eval_envs: int = 1, no_optim_plots: bool = False, device: Union[th.device, str] = "auto", - yaml_file: Optional[str] = None, + config: Optional[str] = None, show_progress: bool = False, ): super().__init__() @@ -108,7 +109,7 @@ def __init__( # Take the root folder default_path = Path(__file__).parent.parent - self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml") + self.config = config or str(default_path / f"hyperparams/{self.algo}.yml") self.env_kwargs = {} if env_kwargs is None else env_kwargs self.n_timesteps = 
n_timesteps self.normalize = False @@ -281,16 +282,28 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None: print(f"Log path: {self.save_path}") def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: - # Load hyperparameters from yaml file - print(f"Loading hyperparameters from: {self.yaml_file}") - with open(self.yaml_file) as f: - hyperparams_dict = yaml.safe_load(f) - if self.env_name.gym_id in list(hyperparams_dict.keys()): - hyperparams = hyperparams_dict[self.env_name.gym_id] - elif self._is_atari: - hyperparams = hyperparams_dict["atari"] - else: - raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}") + print(f"Loading hyperparameters from: {self.config}") + + if self.config.endswith(".yml") or self.config.endswith(".yaml"): + # Load hyperparameters from yaml file + with open(self.config) as f: + hyperparams_dict = yaml.safe_load(f) + elif self.config.endswith(".py"): + global_variables = {} + # Load hyperparameters from python file + exec(Path(self.config).read_text(), global_variables) + hyperparams_dict = global_variables["hyperparams"] + else: + # Load hyperparameters from python package + hyperparams_dict = importlib.import_module(self.config).hyperparams + # raise ValueError(f"Unsupported config file format: {self.config}") + + if self.env_name.gym_id in list(hyperparams_dict.keys()): + hyperparams = hyperparams_dict[self.env_name.gym_id] + elif self._is_atari: + hyperparams = hyperparams_dict["atari"] + else: + raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id} in {self.config}") if self.custom_hyperparams is not None: # Overwrite hyperparams if needed @@ -336,6 +349,10 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An self.normalize_kwargs = eval(self.normalize) self.normalize = True + if isinstance(self.normalize, dict): + self.normalize_kwargs = self.normalize + self.normalize = True + # Use the same discount factor 
as for the algorithm if "gamma" in hyperparams: self.normalize_kwargs["gamma"] = hyperparams["gamma"] diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index bffd83663..b60fc3cf9 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -122,7 +122,19 @@ def train(): help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)", ) parser.add_argument( - "-yaml", "--yaml-file", type=str, default=None, help="Custom yaml file from which the hyperparameters will be loaded" + "-conf", + "--conf-file", + type=str, + default=None, + help="Custom yaml file or python package from which the hyperparameters will be loaded. " + "We expect that python packages contain a dictionary called 'hyperparams' which contains a key for each environment.", + ) + parser.add_argument( + "-yaml", + "--yaml-file", + type=str, + default=None, + help="This parameter is deprecated, please use `--conf-file` instead", ) parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID") parser.add_argument( @@ -150,6 +162,11 @@ def train(): env_id = args.env registered_envs = set(gym.envs.registry.env_specs.keys()) # pytype: disable=module-attr + if args.yaml_file is not None: + raise ValueError( + "The `--yaml-file` parameter is deprecated and will be removed in RL Zoo3 v1.8, please use `--conf-file` instead", + ) + # If the environment is not found, suggest the closest match if env_id not in registered_envs: try: @@ -234,7 +251,7 @@ def train(): n_eval_envs=args.n_eval_envs, no_optim_plots=args.no_optim_plots, device=args.device, - yaml_file=args.yaml_file, + config=args.conf_file, show_progress=args.progress, ) diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt index b89528520..08b6b37ca 100644 --- a/rl_zoo3/version.txt +++ b/rl_zoo3/version.txt @@ -1 +1 @@ -1.7.0a2 +1.7.0a3 diff --git a/setup.cfg b/setup.cfg index 0f61321a1..5bb294e9c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,8 @@ markers = inputs = . 
[flake8] -ignore = W503,W504,E203,E231 # line breaks before and after binary operators +# line breaks before and after binary operators +ignore = W503,W504,E203,E231 # Ignore import not used when aliases are defined per-file-ignores = ./rl_zoo3/__init__.py:F401 diff --git a/tests/test_train.py b/tests/test_train.py index fc4bcd545..32cbe4164 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -116,7 +116,7 @@ def test_custom_yaml(tmp_path): "CartPole-v1", "--log-folder", tmp_path, - "-yaml", + "-conf", "hyperparams/a2c.yml", "-params", "n_envs:2", @@ -129,3 +129,23 @@ def test_custom_yaml(tmp_path): return_code = subprocess.call(["python", "train.py"] + args) _assert_eq(return_code, 0) + + +@pytest.mark.parametrize("config_file", ["hyperparams.python.ppo_config_example", "hyperparams/python/ppo_config_example.py"]) +def test_python_config_file(tmp_path, config_file): + # Use the example python config file for training + args = [ + "-n", + str(N_STEPS), + "--algo", + "ppo", + "--env", + "MountainCarContinuous-v0", + "--log-folder", + tmp_path, + "-conf", + config_file, + ] + + return_code = subprocess.call(["python", "train.py"] + args) + _assert_eq(return_code, 0)