Allow python script for hyperparameter configuration (#318)
* Allow loading configuration from python package in addition to yaml files.

* Add an example python configuration file.

* Add a test for training with the example python config file.

* Formatting fixes.

* Flake8 6 doesn't support inline comments

* Allow python scripts

* Remove debug print

* Add python config file usage example to README.md

* Update CHANGELOG.md

* Fix paths in python config example and add variant where we specify the python file instead of package.

* Update README

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
3 people authored Nov 30, 2022
1 parent 168c34e commit 1f9cfcf
Showing 9 changed files with 119 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -1,11 +1,13 @@
## Release 1.7.0a2 (WIP)

### Breaking Changes
- `--yaml-file` argument was renamed to `-conf` (`--conf-file`) as python files are now supported too

### New Features
- Specifying custom policies in yaml file is now supported (@Rick-v-E)
- Added ``monitor_kwargs`` parameter
- Handle the `env_kwargs` of `render:True` under the hood for panda-gym v1 envs in `enjoy` replay to match the visualization behavior of other envs
- Added support for python config file

### Bug fixes
- Allow `python -m rl_zoo3.cli` to be called directly
2 changes: 1 addition & 1 deletion Makefile
@@ -1,4 +1,4 @@
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py

# Run pytest and coverage report
pytest:
18 changes: 15 additions & 3 deletions README.md
@@ -61,11 +61,23 @@ python train.py --algo algo_name --env env_id
```
You can use `-P` (`--progress`) option to display a progress bar.

Using a custom yaml file (which contains a `env_id` entry):
Using a custom config file (e.g. a yaml file which contains an `env_id` entry):
```
python train.py --algo algo_name --env env_id --yaml-file my_yaml.yml
python train.py --algo algo_name --env env_id --conf-file my_yaml.yml
```

You can also use a python file that contains a dictionary called `hyperparams` with an entry for each `env_id`.
(see `hyperparams/python/ppo_config_example.py` for an example)
```bash
# You can pass a path to a python file
python train.py --algo ppo --env MountainCarContinuous-v0 --conf-file hyperparams/python/ppo_config_example.py
# Or pass a module path (for instance my_package.my_file)
python train.py --algo ppo --env MountainCarContinuous-v0 --conf-file hyperparams.python.ppo_config_example
```
The advantage of this approach is that you can specify arbitrary python dictionaries
and ensure that all their dependencies are imported in the config file itself.
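A minimal sketch of such a file (hypothetical values; see the bundled `hyperparams/python/ppo_config_example.py` for a complete, tuned config):
```python
# my_conf.py - a hypothetical minimal config; values are placeholders, not tuned
hyperparams = {
    "MountainCarContinuous-v0": dict(
        policy="MlpPolicy",
        n_timesteps=20000,
        learning_rate=7.77e-5,
    ),
}
```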


For example (with tensorboard support):
```
python train.py --algo ppo --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/
@@ -139,7 +151,7 @@ Remark: plotting with the `--rliable` option is usually slow as confidence interval

## Custom Environment

The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--yaml-file` argument).
The easiest way to add support for a custom environment is to edit `rl_zoo3/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml` or a custom yaml file that you can specify using `--conf-file` argument).

## Enjoy a Trained Agent

29 changes: 29 additions & 0 deletions hyperparams/python/ppo_config_example.py
@@ -0,0 +1,29 @@
"""This file just serves as an example on how to configure the zoo
using python scripts instead of yaml files."""
import torch

hyperparams = {
    "MountainCarContinuous-v0": dict(
        env_wrapper=[{"gym.wrappers.TimeLimit": {"max_episode_steps": 100}}],
        normalize=True,
        n_envs=1,
        n_timesteps=20000.0,
        policy="MlpPolicy",
        batch_size=8,
        n_steps=8,
        gamma=0.9999,
        learning_rate=7.77e-05,
        ent_coef=0.00429,
        clip_range=0.1,
        n_epochs=2,
        gae_lambda=0.9,
        max_grad_norm=5,
        vf_coef=0.19,
        use_sde=True,
        policy_kwargs=dict(
            log_std_init=-3.29,
            ortho_init=False,
            activation_fn=torch.nn.ReLU,
        ),
    )
}
41 changes: 29 additions & 12 deletions rl_zoo3/exp_manager.py
@@ -1,4 +1,5 @@
import argparse
import importlib
import os
import pickle as pkl
import time
@@ -93,7 +94,7 @@ def __init__(
n_eval_envs: int = 1,
no_optim_plots: bool = False,
device: Union[th.device, str] = "auto",
yaml_file: Optional[str] = None,
config: Optional[str] = None,
show_progress: bool = False,
):
super().__init__()
@@ -108,7 +109,7 @@ def __init__(
# Take the root folder
default_path = Path(__file__).parent.parent

self.yaml_file = yaml_file or str(default_path / f"hyperparams/{self.algo}.yml")
self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
self.env_kwargs = {} if env_kwargs is None else env_kwargs
self.n_timesteps = n_timesteps
self.normalize = False
@@ -281,16 +282,28 @@ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
print(f"Log path: {self.save_path}")

def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Load hyperparameters from yaml file
print(f"Loading hyperparameters from: {self.yaml_file}")
with open(self.yaml_file) as f:
hyperparams_dict = yaml.safe_load(f)
if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")
print(f"Loading hyperparameters from: {self.config}")

if self.config.endswith(".yml") or self.config.endswith(".yaml"):
# Load hyperparameters from yaml file
with open(self.config) as f:
hyperparams_dict = yaml.safe_load(f)
elif self.config.endswith(".py"):
global_variables = {}
# Load hyperparameters from python file
exec(Path(self.config).read_text(), global_variables)
hyperparams_dict = global_variables["hyperparams"]
else:
# Load hyperparameters from python package
hyperparams_dict = importlib.import_module(self.config).hyperparams
# raise ValueError(f"Unsupported config file format: {self.config}")

if self.env_name.gym_id in list(hyperparams_dict.keys()):
hyperparams = hyperparams_dict[self.env_name.gym_id]
elif self._is_atari:
hyperparams = hyperparams_dict["atari"]
else:
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id} in {self.config}")

if self.custom_hyperparams is not None:
# Overwrite hyperparams if needed
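A simplified standalone sketch of the branching above (not the exact `ExperimentManager` code; error handling omitted):
```python
import importlib
from pathlib import Path
from typing import Any, Dict

import yaml


def load_hyperparams_dict(config: str) -> Dict[str, Any]:
    """Simplified mirror of read_hyperparameters: yaml file, python file, or module path."""
    if config.endswith((".yml", ".yaml")):
        # yaml config file
        with open(config) as f:
            return yaml.safe_load(f)
    if config.endswith(".py"):
        # python script: execute it and pull out its `hyperparams` dictionary
        module_globals: Dict[str, Any] = {}
        exec(Path(config).read_text(), module_globals)
        return module_globals["hyperparams"]
    # otherwise treat the string as an importable module path,
    # e.g. "hyperparams.python.ppo_config_example"
    return importlib.import_module(config).hyperparams
```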
@@ -336,6 +349,10 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An
self.normalize_kwargs = eval(self.normalize)
self.normalize = True

if isinstance(self.normalize, dict):
self.normalize_kwargs = self.normalize
self.normalize = True

# Use the same discount factor as for the algorithm
if "gamma" in hyperparams:
self.normalize_kwargs["gamma"] = hyperparams["gamma"]
21 changes: 19 additions & 2 deletions rl_zoo3/train.py
@@ -122,7 +122,19 @@ def train():
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument(
"-yaml", "--yaml-file", type=str, default=None, help="Custom yaml file from which the hyperparameters will be loaded"
"-conf",
"--conf-file",
type=str,
default=None,
help="Custom yaml file or python package from which the hyperparameters will be loaded."
"We expect that python packages contain a dictionary called 'hyperparams' which contains a key for each environment.",
)
parser.add_argument(
"-yaml",
"--yaml-file",
type=str,
default=None,
help="This parameter is deprecated, please use `--conf-file` instead",
)
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
parser.add_argument(
@@ -150,6 +162,11 @@ def train():
env_id = args.env
registered_envs = set(gym.envs.registry.env_specs.keys()) # pytype: disable=module-attr

if args.yaml_file is not None:
raise ValueError(
"The`--yaml-file` parameter is deprecated and will be removed in RL Zoo3 v1.8, please use `--conf-file` instead",
)

# If the environment is not found, suggest the closest match
if env_id not in registered_envs:
try:
@@ -234,7 +251,7 @@ def train():
n_eval_envs=args.n_eval_envs,
no_optim_plots=args.no_optim_plots,
device=args.device,
yaml_file=args.yaml_file,
config=args.conf_file,
show_progress=args.progress,
)

2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
@@ -1 +1 @@
1.7.0a2
1.7.0a3
3 changes: 2 additions & 1 deletion setup.cfg
@@ -16,7 +16,8 @@ markers =
inputs = .

[flake8]
ignore = W503,W504,E203,E231 # line breaks before and after binary operators
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
# Ignore import not used when aliases are defined
per-file-ignores =
./rl_zoo3/__init__.py:F401
22 changes: 21 additions & 1 deletion tests/test_train.py
@@ -116,7 +116,7 @@ def test_custom_yaml(tmp_path):
"CartPole-v1",
"--log-folder",
tmp_path,
"-yaml",
"-conf",
"hyperparams/a2c.yml",
"-params",
"n_envs:2",
@@ -129,3 +129,23 @@

return_code = subprocess.call(["python", "train.py"] + args)
_assert_eq(return_code, 0)


@pytest.mark.parametrize("config_file", ["hyperparams.python.ppo_config_example", "hyperparams/python/ppo_config_example.py"])
def test_python_config_file(tmp_path, config_file):
    # Use the example python config file for training
    args = [
        "-n",
        str(N_STEPS),
        "--algo",
        "ppo",
        "--env",
        "MountainCarContinuous-v0",
        "--log-folder",
        tmp_path,
        "-conf",
        config_file,
    ]

    return_code = subprocess.call(["python", "train.py"] + args)
    _assert_eq(return_code, 0)
