masterdezign commited on
Commit
6c123e0
1 Parent(s): 962a215

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -5
README.md CHANGED
@@ -25,12 +25,48 @@ This is a trained model of a **DQN** agent playing **SpaceInvadersNoFrameskip-v4
25
  using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
26
 
27
  ## Usage (with Stable-baselines3)
28
- TODO: Add your code
29
-
30
 
31
  ```python
32
- from stable_baselines3 import ...
33
- from huggingface_sb3 import load_from_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- ...
 
36
  ```
 
25
  using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
26
 
27
  ## Usage (with Stable-baselines3)
 
 
28
 
29
  ```python
30
+ from stable_baselines3.common.env_util import make_atari_env
31
+ from stable_baselines3.common.vec_env import VecFrameStack
32
+ from stable_baselines3 import DQN
33
+ from stable_baselines3.common.evaluation import evaluate_policy
34
+ from huggingface_sb3 import load_from_hub, package_to_hub
35
+ from stable_baselines3.common.utils import set_random_seed
36
+
37
+ env_id = "SpaceInvadersNoFrameskip-v4"
38
+
39
+ env = make_atari_env(env_id,
40
+ n_envs=12,
41
+ # Improving reproducibility
42
+ seed=1)
43
+ env = VecFrameStack(env, n_stack=4) # Stack last four images
44
+
45
+ # Improving reproducibility
46
+ set_random_seed(42)
47
+
48
+ # Using these parameters as default: https://huggingface.co/micheljperez/dqn-SpaceInvadersNoFrameskip-v4
49
+ model = DQN(policy = "CnnPolicy",
50
+ env = env,
51
+ batch_size = 32,
52
+ buffer_size = 100_000,
53
+ exploration_final_eps = 0.01,
54
+ exploration_fraction = 0.025,
55
+ gradient_steps = 1,
56
+ learning_rate = 1e-4,
57
+ learning_starts = 100_000,
58
+ optimize_memory_usage = True,
59
+ replay_buffer_kwargs = {"handle_timeout_termination": False},
60
+ target_update_interval = 1000,
61
+ train_freq = 4,
62
+ # normalize = False,
63
+ tensorboard_log = "./tensorboard",
64
+ verbose=1
65
+ )
66
+
67
+ f = load_from_hub('masterdezign/dqn-SpaceInvadersNoFrameskip-v4', 'dqn-SpaceInvadersNoFrameskip-v4.zip')
68
+ model = model.load(f)
69
 
70
+ mean_reward, std_reward = evaluate_policy(model, env)
71
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
72
  ```