---
tags:
- deep-reinforcement-learning
- reinforcement-learning
---
This repository provides pretrained model weights for the [Decision Transformer](https://github.com/kzl/decision-transformer).
Weights are available for 4 Atari games: Breakout, Pong, Qbert and Seaquest. They are stored in the `checkpoints` directory.
We share models trained with a single seed (123), whereas the paper used 3 random seeds.
### Usage
```
git clone https://huggingface.co/edbeeching/decision_transformer_atari
cd decision_transformer_atari
conda env create -f conda_env.yml
```
Then, you can use the model like this:
```python
import torch
from decision_transformer_atari import GPTConfig, GPT

vocab_size = 4  # number of discrete actions (4 for Breakout; other games may differ)
block_size = 90  # maximum sequence length in tokens
model_type = "reward_conditioned"
timesteps = 2654  # value passed as max_timestep

mconf = GPTConfig(
    vocab_size,
    block_size,
    n_layer=6,
    n_head=8,
    n_embd=128,
    model_type=model_type,
    max_timestep=timesteps,
)
model = GPT(mconf)

checkpoint_path = "checkpoints/Breakout_123.pth"  # or Pong, Qbert, Seaquest
# map_location="cpu" lets the weights load on machines without a GPU
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()  # switch to inference mode (disables dropout)
```
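
The checkpoint for each game appears to follow the `checkpoints/<Game>_<seed>.pth` pattern used in the Breakout example, so a small helper can build and validate the path before loading. This is a minimal sketch: `get_checkpoint_path` and `GAMES` are hypothetical names, not part of the repository, and the file names for games other than Breakout are assumed from the comment in the usage snippet.

```python
from pathlib import Path

# Games and seed listed in this model card. The file naming pattern is an
# assumption based on the Breakout example (checkpoints/Breakout_123.pth).
GAMES = ("Breakout", "Pong", "Qbert", "Seaquest")
SEED = 123


def get_checkpoint_path(game: str, seed: int = SEED, root: str = "checkpoints") -> Path:
    """Return the expected checkpoint path for a game (hypothetical helper)."""
    if game not in GAMES:
        raise ValueError(f"Unknown game {game!r}; expected one of {GAMES}")
    return Path(root) / f"{game}_{seed}.pth"


if __name__ == "__main__":
    # Report which of the expected checkpoint files exist locally.
    for game in GAMES:
        path = get_checkpoint_path(game)
        status = "found" if path.exists() else "missing"
        print(f"{path}: {status}")
```

The returned path can be passed to `torch.load` in place of the hard-coded string. Model settings such as `vocab_size` may still need to be adjusted per game.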