diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c9d601e4..401515994 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ### New Features - Add `--eval-env-kwargs` to `train.py` (@Quentin18) +- Add `ppo_lstm` to `hyperparams_opt.py` (@technocrat13) ### Bug fixes diff --git a/rl_zoo3/hyperparams_opt.py b/rl_zoo3/hyperparams_opt.py index 0bbd701f9..360734cd9 100644 --- a/rl_zoo3/hyperparams_opt.py +++ b/rl_zoo3/hyperparams_opt.py @@ -28,7 +28,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]) max_grad_norm = trial.suggest_categorical("max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5]) vf_coef = trial.suggest_float("vf_coef", 0, 1) - net_arch = trial.suggest_categorical("net_arch", ["small", "medium"]) + net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"]) # Uncomment for gSDE (continuous actions) # log_std_init = trial.suggest_float("log_std_init", -4, 1) # Uncomment for gSDE (continuous action) @@ -49,6 +49,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: # Independent networks usually work best # when not working with images net_arch = { + "tiny": dict(pi=[64], vf=[64]), "small": dict(pi=[64, 64], vf=[64, 64]), "medium": dict(pi=[256, 256], vf=[256, 256]), }[net_arch] @@ -76,6 +77,28 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: } +def sample_ppo_lstm_params(trial: optuna.Trial) -> Dict[str, Any]: + """ + Sampler for RecurrentPPO hyperparams. 
+ Reuses sample_ppo_params() and additionally samples the LSTM entries of policy_kwargs + :param trial: + :return: + """ + hyperparams = sample_ppo_params(trial) + + enable_critic_lstm = trial.suggest_categorical("enable_critic_lstm", [False, True]) + lstm_hidden_size = trial.suggest_categorical("lstm_hidden_size", [16, 32, 64, 128, 256, 512]) + + hyperparams["policy_kwargs"].update( + { + "enable_critic_lstm": enable_critic_lstm, + "lstm_hidden_size": lstm_hidden_size, + } + ) + + return hyperparams + + def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]: """ Sampler for TRPO hyperparams. @@ -527,6 +550,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]: "sac": sample_sac_params, "tqc": sample_tqc_params, "ppo": sample_ppo_params, + "ppo_lstm": sample_ppo_lstm_params, "td3": sample_td3_params, "trpo": sample_trpo_params, } diff --git a/tests/test_hyperparams_opt.py b/tests/test_hyperparams_opt.py index 1fe82f7d2..c9d5e44c0 100644 --- a/tests/test_hyperparams_opt.py +++ b/tests/test_hyperparams_opt.py @@ -33,6 +33,8 @@ def _assert_eq(left, right): experiments["tqc-parking-v0"] = ("tqc", "parking-v0") # Test for TQC experiments["tqc-Pendulum-v1"] = ("tqc", "Pendulum-v1") +# Test for RecurrentPPO (ppo_lstm) +experiments["ppo_lstm-CartPoleNoVel-v1"] = ("ppo_lstm", "CartPoleNoVel-v1") @pytest.mark.parametrize("sampler", ["random", "tpe"])