DalasNoin committed
Commit 4466677
1 Parent(s): d7c6a2a

usage example did not work because of a typo and torch not being imported

Files changed (1)
  1. README.md +45 -45
README.md CHANGED
@@ -1,45 +1,45 @@
- ---
- tags:
- - deep-reinforcement-learning
- - reinforcement-learning
-
- ---
-
- Find here pretrained model weights for the [Decision Transformer](https://github.com/kzl/decision-transformer).
- Weights are available for 4 Atari games: Breakout, Pong, Qbert and Seaquest. Found in the checkpoints directory.
- We share models trained for one seed (123), whereas the paper contained weights for 3 random seeds.
-
-
- ### Usage
-
- ```
- git clone https://huggingface.co/edbeeching/decision_transformer_atari
- conda env create -f conda_env.yml
- ```
-
- Then, you can use the model like this:
-
- ```python
-
- from decision_transform_atari import GPTConfig, GPT
-
- vocab_size = 4
- block_size = 90
- model_type = "reward_conditioned"
- timesteps = 2654
-
- mconf = GPTConfig(
-     vocab_size,
-     block_size,
-     n_layer=6,
-     n_head=8,
-     n_embd=128,
-     model_type=model_type,
-     max_timestep=timesteps,
- )
- model = GPT(mconf)
-
- checkpoint_path = "checkpoints/Breakout_123.pth" # or Pong, Qbert, Seaquest
- checkpoint = torch.load(checkpoint_path)
- model.load_state_dict(checkpoint)
- ```
+ ---
+ tags:
+ - deep-reinforcement-learning
+ - reinforcement-learning
+
+ ---
+
+ Find here pretrained model weights for the [Decision Transformer](https://github.com/kzl/decision-transformer).
+ Weights are available for 4 Atari games: Breakout, Pong, Qbert and Seaquest. Found in the checkpoints directory.
+ We share models trained for one seed (123), whereas the paper contained weights for 3 random seeds.
+
+
+ ### Usage
+
+ ```
+ git clone https://huggingface.co/edbeeching/decision_transformer_atari
+ conda env create -f conda_env.yml
+ ```
+
+ Then, you can use the model like this:
+
+ ```python
+ import torch
+ from decision_transformer_atari import GPTConfig, GPT
+
+ vocab_size = 4
+ block_size = 90
+ model_type = "reward_conditioned"
+ timesteps = 2654
+
+ mconf = GPTConfig(
+     vocab_size,
+     block_size,
+     n_layer=6,
+     n_head=8,
+     n_embd=128,
+     model_type=model_type,
+     max_timestep=timesteps,
+ )
+ model = GPT(mconf)
+
+ checkpoint_path = "checkpoints/Breakout_123.pth" # or Pong, Qbert, Seaquest
+ checkpoint = torch.load(checkpoint_path)
+ model.load_state_dict(checkpoint)
+ ```
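
For reference, once the corrected snippet has loaded a checkpoint, running the model looks roughly like the sketch below. This is not part of the commit: the call signature (`states`, `actions`, `targets`, `rtgs`, `timesteps`) and the tensor shapes are assumptions based on the Atari GPT in the upstream kzl/decision-transformer repository, so verify them against the cloned code.

```python
import torch

# Hypothetical inference sketch, not part of this commit. Shapes and the
# forward() signature are assumed from the upstream kzl/decision-transformer
# Atari model; double-check them against the cloned repository.
model.eval()
context = block_size // 3  # each timestep spans 3 tokens: return-to-go, state, action

states = torch.zeros(1, context, 4 * 84 * 84)           # flattened stack of four 84x84 frames
actions = torch.zeros(1, context, 1, dtype=torch.long)  # discrete action indices (vocab_size = 4)
rtgs = torch.ones(1, context, 1)                        # desired returns-to-go
timesteps = torch.zeros(1, 1, 1, dtype=torch.long)      # timestep of the first state in the window

with torch.no_grad():
    logits, _ = model(states, actions, targets=None, rtgs=rtgs, timesteps=timesteps)

next_action = logits[:, -1, :].argmax(dim=-1)  # greedy action for the most recent state
```

A real evaluation loop would refresh `states`, `rtgs`, and `timesteps` from the environment at every step, the way the upstream evaluation code does.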