This is the official code for the paper
"Bridging State and History Representations: Understanding Self-Predictive RL", ICLR 2024
by Tianwei Ni, Benjamin Eysenbach, Erfan Seyedsalehi, Michel Ma, Clement Gehring, Aditya Mahajan, and Pierre-Luc Bacon.
In this work, we demonstrate a principled, minimal, and effective design, as reflected in the following pseudocode:
def total_loss(hist, act, next_obs, rew):
    """
    Compute the total loss for learning one of the three abstractions.
    Args: Batch of transition data (h, a, o', r).
        hist h: (B, T, O+A), act a: (B, A), next_obs o': (B, O), rew r: (B, 1)
    """
    # Encode current history into a latent state
    h_enc = Encoder(hist)  # z: (B, Z)
    next_hist = torch.cat([hist, torch.cat([act, next_obs], dim=-1)], dim=1)  # h': (B, T+1, O+A)
    # Encode next history into a latent state using an EMA encoder
    next_h_enc_tar = Encoder_Target(next_hist)  # z': (B, Z)
    # Model-free RL loss in the latent state space (e.g., TD3, R2D2)
    rl_loss = RL_loss(h_enc, act, next_h_enc_tar, rew)  # (z, a, z', r)

    if [learning Q^*-irrelevance representations]:  # model-free RL
        return rl_loss
    elif [learning self-predictive representations]:  # l2 loss with EMA ZP target
        zp_loss = ((Latent_Model(h_enc, act) - next_h_enc_tar)**2).sum(-1).mean()
        return rl_loss + coef * zp_loss
    elif [learning observation-predictive representations]:  # l2 loss
        op_loss = ((Observ_Model(h_enc, act) - next_obs)**2).sum(-1).mean()
        return rl_loss + coef * op_loss
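The pseudocode above relies on an EMA target encoder and on an l2 loss that is summed over the latent dimension and averaged over the batch, but it does not show how the target is updated. The following is a minimal sketch in plain Python, using lists in place of tensors. The names `ema_update` and `l2_loss` and the rate `tau` are illustrative assumptions, not identifiers from this repository.

```python
from typing import List


def ema_update(target: List[float], online: List[float], tau: float = 0.005) -> List[float]:
    """Move each target parameter a fraction `tau` toward the online parameter.

    `tau` is a hypothetical hyperparameter name; the value 0.005 is only a placeholder.
    """
    if len(target) != len(online):
        raise ValueError("parameter lists must have the same length")
    return [(1.0 - tau) * t + tau * o for t, o in zip(target, online)]


def l2_loss(pred: List[List[float]], target: List[List[float]]) -> float:
    """Squared error summed over the last dimension, then averaged over the batch.

    Mirrors the expression ((pred - target)**2).sum(-1).mean() in the pseudocode.
    """
    per_sample = [
        sum((p - t) ** 2 for p, t in zip(p_row, t_row))
        for p_row, t_row in zip(pred, target)
    ]
    return sum(per_sample) / len(per_sample)


# Toy usage: a batch of two latent predictions with latent dimension 2.
pred = [[1.0, 2.0], [0.0, 0.0]]
tar = [[0.0, 2.0], [0.0, 2.0]]   # stands in for the EMA encoder's output
zp_loss = l2_loss(pred, tar)     # (1 + 4) / 2 = 2.5

# After each gradient step on the online encoder, refresh the target copy.
new_target = ema_update([0.0, 1.0], [1.0, 1.0], tau=0.5)  # [0.5, 1.0]
```

In a tensor framework the target output would be treated as a constant when computing `zp_loss`, so that gradients only flow through the online encoder and the latent model.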
In deep RL, numerous representation learning methods have been proposed, ranging from state representations for MDPs to history representations for POMDPs. However, these methods often involve different learning objectives and training techniques, making it challenging for RL practitioners to select the most suitable approach for their specific problems.
This work unifies various representation learning methods by analyzing their objectives and ideal abstractions. Surprisingly, these methods are connected by a self-predictive condition, termed the ZP condition: the latent state generated by the encoder can be used to predict the next latent state. We summarize three abstractions learned by these methods and provide examples of popular instances:
- $Q^*$-irrelevance abstraction: purely maximizes returns. Examples: model-free RL (cleanrl), recurrent model-free RL.
- Self-predictive abstraction: involves the self-predictive (ZP) and reward-prediction (RP) conditions. Examples: SPR, DBC, TD-MPC, EfficientZero.
- Observation-predictive abstraction: involves the observation-predictive (OP) and reward-prediction (RP) conditions. Examples: Dreamer, SLAC, SAC-AE.
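To make the ZP and RP conditions concrete, here is a small worked example under a simplifying assumption: a deterministic, fully observed tabular MDP, where "can be predicted from the latent state" reduces to "takes the same value for all states that share a latent". The 4-state chain, the two encoders, and the helper `depends_only_on_latent` are all hypothetical and chosen only for illustration; the paper treats the general stochastic, history-based case.

```python
from typing import Callable, Dict, Hashable, Iterable, Tuple


def depends_only_on_latent(
    states: Iterable[int],
    actions: Iterable[int],
    encode: Callable[[int], Hashable],
    quantity: Callable[[int, int], Hashable],
) -> bool:
    """Return True if quantity(s, a) is determined by (encode(s), a) alone."""
    actions = tuple(actions)
    seen: Dict[Tuple[Hashable, int], Hashable] = {}
    for s in states:
        for a in actions:
            value = quantity(s, a)
            # setdefault stores the first value seen for this (latent, action) pair
            if seen.setdefault((encode(s), a), value) != value:
                return False
    return True


STATES = range(4)
ACTIONS = (0, 1)


def step(s: int, a: int) -> int:
    """Action 0 stays in place, action 1 moves one state to the right (wrapping)."""
    return (s + a) % 4


def reward(s: int, a: int) -> float:
    """Reward 1 in odd states, 0 in even states."""
    return float(s % 2 == 1)


def parity(s: int) -> int:
    """Encoder that merges states {0, 2} and {1, 3}."""
    return s % 2


def coarse(s: int) -> int:
    """Encoder that merges states {1, 2, 3} and keeps state 0 separate."""
    return min(s, 1)


# ZP: the next latent must be a function of the current latent and the action.
zp_parity = depends_only_on_latent(STATES, ACTIONS, parity, lambda s, a: parity(step(s, a)))
zp_coarse = depends_only_on_latent(STATES, ACTIONS, coarse, lambda s, a: coarse(step(s, a)))
# RP: the reward must be a function of the current latent and the action.
rp_parity = depends_only_on_latent(STATES, ACTIONS, parity, reward)
rp_coarse = depends_only_on_latent(STATES, ACTIONS, coarse, reward)

print(zp_parity, rp_parity)  # True True
print(zp_coarse, rp_coarse)  # False False
```

The `parity` encoder satisfies both conditions because the next parity and the reward depend only on the current parity and the action. The `coarse` encoder fails both: states 1 and 3 share a latent but move to different latents under action 1, and states 1 and 2 share a latent but receive different rewards.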
In our paper, we establish how the ZP condition connects the three abstractions. Crucially, we investigate the training objectives for learning ZP, including widely-used
These analyses lead to the development of our minimalist algorithm for learning self-predictive abstraction. We provide the code as a baseline for future research, believing it to be:
- Principled in representation learning: targets each of the three abstractions.
- Minimal in algorithmic design: uses a single auxiliary task for representation learning (just one extra loss), and model-free policy optimization (no planning).
- Effective in practice: our implementation of self-predictive representations outperforms $Q^*$-irrelevance abstraction (the model-free baseline), and is more robust to distractions than observation-predictive representations.
- mujoco_code/: contains the code on standard MDPs (Section 5.1) and distracting MDPs (Section 5.2) using MuJoCo simulators.
- minigrid_code/: contains the code on sparse-reward POMDPs (Section 5.3) using MiniGrid environments.
- linear_code/: contains the code for illustrating our theorem on stop-gradient to prevent collapse (Section 4.2).
Here we restate our preliminary recommendations from our paper (Section 6):
- Analyze your task first. For example, in noisy or distracting tasks, consider using self-predictive representations. In sparse-reward tasks, consider using observation-predictive representations. In deterministic tasks, choose the deterministic $\ell_2$ objectives for representation learning.
- Use our minimalist algorithm as your baseline. Our algorithm allows for an independent evaluation of representation learning and policy optimization effects. Start with end-to-end learning and model-free RL for policy optimization.
- Implementation tips. For our minimalist algorithm, we recommend adopting the $\ell_2$ objective with EMA ZP targets first. When tackling POMDPs, start with recurrent networks as the encoder.
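The recommendations above can be read as a simple lookup from task properties to a starting configuration. The sketch below encodes them that way. `TaskProfile`, `suggest_setup`, and all dictionary keys are hypothetical names, not part of this repository, and the `"feedforward"` encoder for fully observed tasks is an assumption, since the list only gives an encoder recommendation for POMDPs.

```python
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass(frozen=True)
class TaskProfile:
    """Coarse description of a task (hypothetical helper type)."""
    noisy_or_distracting: bool = False
    sparse_reward: bool = False
    partially_observable: bool = False


def suggest_setup(task: TaskProfile) -> Dict[str, Union[str, List[str]]]:
    """Map task properties to the starting points listed above."""
    candidates: List[str] = []
    if task.noisy_or_distracting:
        candidates.append("self-predictive")
    if task.sparse_reward:
        candidates.append("observation-predictive")
    if not candidates:
        # Default taken from the implementation tip: start with EMA ZP targets.
        candidates.append("self-predictive")
    return {
        # When both properties hold, both candidates are returned; the list
        # above does not say which to prefer, so no tie-break is applied here.
        "representations": candidates,
        "objective": "l2 with EMA ZP targets",
        "policy_optimization": "model-free, end-to-end",
        # "feedforward" is an assumption for fully observed tasks.
        "encoder": "recurrent" if task.partially_observable else "feedforward",
    }


setup = suggest_setup(TaskProfile(sparse_reward=True, partially_observable=True))
print(setup["representations"], setup["encoder"])
# ['observation-predictive'] recurrent
```

This is only a restatement of the list as code. It is no substitute for analyzing the task, which the first recommendation asks for.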
If you have any questions, please raise an issue (preferred) or send an email to Tianwei (tianwei.ni@mila.quebec).
@inproceedings{ni2024bridging,
  title={Bridging State and History Representations: Understanding Self-Predictive RL},
  author={Ni, Tianwei and Eysenbach, Benjamin and Seyedsalehi, Erfan and Ma, Michel and Gehring, Clement and Mahajan, Aditya and Bacon, Pierre-Luc},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}