This is the official code (Sections 5.1 & 5.2: discrete control) for the paper
"When Do Transformers Shine in RL? Decoupling Memory from Credit Assignment" (NeurIPS 2023, oral)
by Tianwei Ni, Michel Ma, Benjamin Eysenbach, and Pierre-Luc Bacon.
Please switch to the corresponding branch for the code for Section 5.3 (PyBullet continuous control).
The code has a modular design built around three configuration files. We hope that such a design can facilitate future research on different environments, RL algorithms, and sequence models. (A sketch of what one of these config files may look like follows the list below.)
- `config_env`: specifies the environment, with `config_env.env_name` specifying the exact (memory / credit assignment) length of the task
    - Passive T-Maze (this work)
    - Active T-Maze (this work)
    - Passive Visual Match (based on [Hung et al., 2018])
    - Key-to-Door (based on [Raposo et al., 2021])
- `config_rl`: specifies the RL algorithm and its hyperparameters
    - DQN (with epsilon-greedy exploration)
    - SAC-Discrete (we find that `--freeze_critic` can prevent gradient explosion; see the discussion in Appendix C.1 of the latest arXiv version of the paper)
- `config_seq`: specifies the sequence model and its hyperparameters, including the training sequence length `config_seq.sampled_seq_len` and the number of layers `--config_seq.model.seq_model_config.n_layer`
    - LSTM [Hochreiter and Schmidhuber, 1997]
    - Transformer (GPT-2) [Radford et al., 2019]
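As referenced above, here is a minimal sketch of what such a config file can look like, assuming the `--config_env.env_name`-style flags follow the common `ml_collections` `config_flags` convention (the convention, and any field other than `env_name`, is our assumption rather than the repo's documented schema):

```python
# Minimal sketch of an environment config file, assuming the
# ml_collections config_flags convention implied by flags such as
# --config_env configs/envs/tmaze_passive.py --config_env.env_name 50.
# Everything except env_name is illustrative, not the repo's actual schema.
from ml_collections import ConfigDict


def get_config():
    config = ConfigDict()
    config.env_name = 50  # memory / credit-assignment length of the task
    return config
```

Under this convention, a command-line override such as `--config_env.env_name 1500` replaces the corresponding field of the returned `ConfigDict` before training starts.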
We use Python 3.7+ and list the basic requirements in `requirements.txt`; you can install them with `pip install -r requirements.txt`.
Below are example commands to reproduce the main results shown in Figures 3 and 6. For the ablation results, please adjust the corresponding hyperparameters.
To run Passive T-Maze with a memory length of 50 with an LSTM-based agent:
```bash
python main.py \
    --config_env configs/envs/tmaze_passive.py \
    --config_env.env_name 50 \
    --config_rl configs/rl/dqn_default.py \
    --train_episodes 20000 \
    --config_seq configs/seq_models/lstm_default.py \
    --config_seq.sampled_seq_len -1
```
To run Passive T-Maze with a memory length of 1500 with a Transformer-based agent:
```bash
python main.py \
    --config_env configs/envs/tmaze_passive.py \
    --config_env.env_name 1500 \
    --config_rl configs/rl/dqn_default.py \
    --train_episodes 6700 \
    --config_seq configs/seq_models/gpt_default.py \
    --config_seq.sampled_seq_len -1
```
To run Active T-Maze with a memory length of 20 with a Transformer-based agent:
```bash
python main.py \
    --config_env configs/envs/tmaze_active.py \
    --config_env.env_name 20 \
    --config_rl configs/rl/dqn_default.py \
    --train_episodes 40000 \
    --config_seq configs/seq_models/gpt_default.py \
    --config_seq.sampled_seq_len -1 \
    --config_seq.model.seq_model_config.n_layer 2 \
    --config_seq.model.seq_model_config.n_head 2
```
To run Passive Visual Match with a memory length of 60 with a Transformer-based agent:
```bash
python main.py \
    --config_env configs/envs/visual_match.py \
    --config_env.env_name 60 \
    --config_rl configs/rl/sacd_default.py \
    --shared_encoder --freeze_critic \
    --train_episodes 40000 \
    --config_seq configs/seq_models/gpt_cnn.py \
    --config_seq.sampled_seq_len -1
```
To run Key-to-Door with a memory length of 120 with an LSTM-based agent:
```bash
python main.py \
    --config_env configs/envs/keytodoor.py \
    --config_env.env_name 120 \
    --config_rl configs/rl/sacd_default.py \
    --shared_encoder --freeze_critic \
    --train_episodes 40000 \
    --config_seq configs/seq_models/lstm_cnn.py \
    --config_seq.sampled_seq_len -1 \
    --config_seq.model.seq_model_config.n_layer 2
```
To run Key-to-Door with a memory length of 250 with a Transformer-based agent:
```bash
python main.py \
    --config_env configs/envs/keytodoor.py \
    --config_env.env_name 250 \
    --config_rl configs/rl/sacd_default.py \
    --shared_encoder --freeze_critic \
    --train_episodes 30000 \
    --config_seq configs/seq_models/gpt_cnn.py \
    --config_seq.sampled_seq_len -1 \
    --config_seq.model.seq_model_config.n_layer 2 \
    --config_seq.model.seq_model_config.n_head 2
```
The `train_episodes` of each task is specified in `budget.py`.
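For reference, here is a hypothetical sketch of what such a budget table can look like; the name and key structure are illustrative, and the values are simply copied from the example commands above rather than from the repo's actual `budget.py`:

```python
# Hypothetical sketch of a per-task training budget table.
# The name and key structure are illustrative; the values come from the
# example commands above, not from the repo's actual budget.py.
TRAIN_EPISODES = {
    ("tmaze_passive", 50): 20000,   # Passive T-Maze, LSTM command
    ("tmaze_passive", 1500): 6700,  # Passive T-Maze, Transformer command
    ("tmaze_active", 20): 40000,    # Active T-Maze
    ("visual_match", 60): 40000,    # Passive Visual Match
    ("keytodoor", 120): 40000,      # Key-to-Door, LSTM command
    ("keytodoor", 250): 30000,      # Key-to-Door, Transformer command
}
```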
By default, the logging data are stored in the `logs/` folder in CSV format. If you use the `--debug` flag, they are stored in the `debug/` folder instead.
Once the logging data are stored, you can plot the learning curves and aggregation plots (e.g., Figures 3 and 6) using the `vis.ipynb` Jupyter notebook.
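If you want a quick look at the results outside the notebook, the following minimal sketch loads the CSV logs and plots learning curves; the glob pattern and the column names `env_steps` and `return` are assumptions about the log layout, so adjust them to your actual CSV headers:

```python
# Minimal sketch: load all CSV logs and plot one learning curve per run.
# The glob pattern and the column names ("env_steps", "return") are
# assumptions about the log layout; adjust them to the actual headers.
import glob

import matplotlib.pyplot as plt
import pandas as pd

for path in glob.glob("logs/**/*.csv", recursive=True):
    df = pd.read_csv(path)
    plt.plot(df["env_steps"], df["return"], label=path)

plt.xlabel("environment steps")
plt.ylabel("episode return")
plt.legend(fontsize=6)
plt.show()
```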
We also provide the logging data used in the paper, shared via Google Drive (< 400 MB).
The code is largely based on prior works:
If you have any questions, please raise an issue (preferred) or send an email to Tianwei (tianwei.ni@mila.quebec).