Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added hyperparameter tuning for RecurrentPPO #415

Merged
merged 9 commits into from
Oct 28, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

### New Features
- Add `--eval-env-kwargs` to `train.py` (@Quentin18)
- Add `ppo_lstm` (RecurrentPPO) hyperparameter sampler to `hyperparams_opt.py` (@technocrat13)

### Bug fixes

Expand Down
26 changes: 25 additions & 1 deletion rl_zoo3/hyperparams_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
max_grad_norm = trial.suggest_categorical("max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5])
vf_coef = trial.suggest_float("vf_coef", 0, 1)
net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])
net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"])
# Uncomment for gSDE (continuous actions)
# log_std_init = trial.suggest_float("log_std_init", -4, 1)
# Uncomment for gSDE (continuous action)
Expand All @@ -49,6 +49,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
# Independent networks usually work best
# when not working with images
net_arch = {
"tiny": dict(pi=[64], vf=[64]),
"small": dict(pi=[64, 64], vf=[64, 64]),
"medium": dict(pi=[256, 256], vf=[256, 256]),
}[net_arch]
Expand Down Expand Up @@ -76,6 +77,28 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
}


def sample_ppo_lstm_params(trial: optuna.Trial) -> Dict[str, Any]:
    """
    Sampler for RecurrentPPO hyperparams.

    Starts from the dict produced by ``sample_ppo_params()`` and adds two
    extra categorical choices to its ``policy_kwargs`` entry.

    :param trial: Optuna trial used to draw the categorical values.
    :return: The PPO hyperparameter dict, with ``enable_critic_lstm`` and
        ``lstm_hidden_size`` set inside ``policy_kwargs``.
    """
    params = sample_ppo_params(trial)

    # Draw both values before touching the dict; the suggestion order
    # (critic flag first, then hidden size) is kept as-is.
    critic_lstm = trial.suggest_categorical("enable_critic_lstm", [False, True])
    hidden_size = trial.suggest_categorical("lstm_hidden_size", [16, 32, 64, 128, 256, 512])

    # The nested policy_kwargs mapping is extended in place.
    params["policy_kwargs"].update(
        enable_critic_lstm=critic_lstm,
        lstm_hidden_size=hidden_size,
    )
    return params


def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]:
"""
Sampler for TRPO hyperparams.
Expand Down Expand Up @@ -527,6 +550,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
"sac": sample_sac_params,
"tqc": sample_tqc_params,
"ppo": sample_ppo_params,
"ppo_lstm": sample_ppo_lstm_params,
"td3": sample_td3_params,
"trpo": sample_trpo_params,
}
2 changes: 2 additions & 0 deletions tests/test_hyperparams_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def _assert_eq(left, right):
experiments["tqc-parking-v0"] = ("tqc", "parking-v0")
# Test for TQC
experiments["tqc-Pendulum-v1"] = ("tqc", "Pendulum-v1")
# Test for RecurrentPPO (ppo_lstm)
experiments["ppo_lstm-CartPoleNoVel-v1"] = ("ppo_lstm", "CartPoleNoVel-v1")


@pytest.mark.parametrize("sampler", ["random", "tpe"])
Expand Down
Loading