Edward Beeching committed on
Commit d7c6a2a
1 Parent(s): 7978894

Updated README

Files changed (1)
  1. README.md +22 -3
README.md CHANGED
@@ -13,6 +13,7 @@ We share models trained for one seed (123), whereas the paper contained weights
 ### Usage
 
 ```
+git clone https://huggingface.co/edbeeching/decision_transformer_atari
 conda env create -f conda_env.yml
 ```
 
@@ -20,7 +21,25 @@ Then, you can use the model like this:
 
 ```python
 
-
-
-
+from decision_transform_atari import GPTConfig, GPT
+
+vocab_size = 4
+block_size = 90
+model_type = "reward_conditioned"
+timesteps = 2654
+
+mconf = GPTConfig(
+    vocab_size,
+    block_size,
+    n_layer=6,
+    n_head=8,
+    n_embd=128,
+    model_type=model_type,
+    max_timestep=timesteps,
+)
+model = GPT(mconf)
+
+checkpoint_path = "checkpoints/Breakout_123.pth" # or Pong, Qbert, Seaquest
+checkpoint = torch.load(checkpoint_path)
+model.load_state_dict(checkpoint)
 ```
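
The checkpoint comment in the added README snippet implies one file per game, all for seed 123. Below is a minimal sketch of building those paths, assuming the other games follow the same `checkpoints/<Game>_<seed>.pth` naming as Breakout. The helper `checkpoint_path` and the `GAMES` tuple are hypothetical and not part of the repository. As written, the README snippet also calls `torch.load` without importing `torch`, so an `import torch` is needed before running it.

```python
# Hypothetical helper, not part of the repository.
# Assumption: every game uses the same naming scheme as "checkpoints/Breakout_123.pth".
GAMES = ("Breakout", "Pong", "Qbert", "Seaquest")  # games named in the README comment
SEED = 123  # the single seed for which models are shared


def checkpoint_path(game: str, seed: int = SEED) -> str:
    """Return the assumed checkpoint path for one game."""
    if game not in GAMES:
        raise ValueError(f"unknown game {game!r}; expected one of {GAMES}")
    return f"checkpoints/{game}_{seed}.pth"
```

The returned string can be passed to `torch.load` in place of the hard-coded `checkpoint_path` in the README snippet.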