vwxyzjn commited on
Commit
3ca4ec6
1 Parent(s): ca45544

pushing model

Browse files
README.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - CartPole-v1
4
+ - deep-reinforcement-learning
5
+ - reinforcement-learning
6
+ - custom-implementation
7
+ model-index:
8
+ - name: DQN
9
+ results:
10
+ - task:
11
+ type: reinforcement-learning
12
+ name: reinforcement-learning
13
+ dataset:
14
+ name: CartPole-v1
15
+ type: CartPole-v1
16
+ metrics:
17
+ - type: mean_reward
18
+ value: 76.10 +/- 28.07
19
+ name: mean_reward
20
+ verified: false
21
+ ---
22
+
23
+ # (CleanRL) **DQN** Agent Playing **CartPole-v1**
24
+
25
+ This is a trained model of a DQN agent playing CartPole-v1.
26
+ The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
27
+ found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py).
28
+
29
+ ## Command to reproduce the training
30
+
31
+ ```bash
32
+ curl -OL https://huggingface.co/cleanrl/CartPole-v1-dqn-seed1/raw/main/dqn.py
33
+ curl -OL https://huggingface.co/cleanrl/CartPole-v1-dqn-seed1/raw/main/pyproject.toml
34
+ curl -OL https://huggingface.co/cleanrl/CartPole-v1-dqn-seed1/raw/main/poetry.lock
35
+ poetry install --all-extras
36
+ python dqn.py --save-model --upload-model --hf-entity cleanrl --cuda False --total-timesteps 1000
37
+ ```
38
+
39
+ # Hyperparameters
40
+ ```python
41
+ {'batch_size': 128,
42
+ 'buffer_size': 10000,
43
+ 'capture_video': False,
44
+ 'cuda': False,
45
+ 'end_e': 0.05,
46
+ 'env_id': 'CartPole-v1',
47
+ 'exp_name': 'dqn',
48
+ 'exploration_fraction': 0.5,
49
+ 'gamma': 0.99,
50
+ 'hf_entity': 'cleanrl',
51
+ 'learning_rate': 0.00025,
52
+ 'learning_starts': 10000,
53
+ 'save_model': True,
54
+ 'seed': 1,
55
+ 'start_e': 1,
56
+ 'target_network_frequency': 500,
57
+ 'torch_deterministic': True,
58
+ 'total_timesteps': 1000,
59
+ 'track': False,
60
+ 'train_frequency': 10,
61
+ 'upload_model': True,
62
+ 'wandb_entity': None,
63
+ 'wandb_project_name': 'cleanRL'}
64
+ ```
65
+
dqn.cleanrl_model ADDED
Binary file (45.8 kB). View file
 
dqn.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
2
+ import argparse
3
+ import os
4
+ import random
5
+ import time
6
+ from distutils.util import strtobool
7
+
8
+ import gym
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.optim as optim
14
+ from stable_baselines3.common.buffers import ReplayBuffer
15
+ from torch.utils.tensorboard import SummaryWriter
16
+
17
+
18
+ def parse_args():
19
+ # fmt: off
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
22
+ help="the name of this experiment")
23
+ parser.add_argument("--seed", type=int, default=1,
24
+ help="seed of the experiment")
25
+ parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
26
+ help="if toggled, `torch.backends.cudnn.deterministic=False`")
27
+ parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
28
+ help="if toggled, cuda will be enabled by default")
29
+ parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
30
+ help="if toggled, this experiment will be tracked with Weights and Biases")
31
+ parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
32
+ help="the wandb's project name")
33
+ parser.add_argument("--wandb-entity", type=str, default=None,
34
+ help="the entity (team) of wandb's project")
35
+ parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
36
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
37
+ parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
38
+ help="whether to save model into the `runs/{run_name}` folder")
39
+ parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
40
+ help="whether to upload the saved model to huggingface")
41
+ parser.add_argument("--hf-entity", type=str, default="",
42
+ help="the user or org name of the model repository from the Hugging Face Hub")
43
+
44
+ # Algorithm specific arguments
45
+ parser.add_argument("--env-id", type=str, default="CartPole-v1",
46
+ help="the id of the environment")
47
+ parser.add_argument("--total-timesteps", type=int, default=500000,
48
+ help="total timesteps of the experiments")
49
+ parser.add_argument("--learning-rate", type=float, default=2.5e-4,
50
+ help="the learning rate of the optimizer")
51
+ parser.add_argument("--buffer-size", type=int, default=10000,
52
+ help="the replay memory buffer size")
53
+ parser.add_argument("--gamma", type=float, default=0.99,
54
+ help="the discount factor gamma")
55
+ parser.add_argument("--target-network-frequency", type=int, default=500,
56
+ help="the timesteps it takes to update the target network")
57
+ parser.add_argument("--batch-size", type=int, default=128,
58
+ help="the batch size of sample from the reply memory")
59
+ parser.add_argument("--start-e", type=float, default=1,
60
+ help="the starting epsilon for exploration")
61
+ parser.add_argument("--end-e", type=float, default=0.05,
62
+ help="the ending epsilon for exploration")
63
+ parser.add_argument("--exploration-fraction", type=float, default=0.5,
64
+ help="the fraction of `total-timesteps` it takes from start-e to go end-e")
65
+ parser.add_argument("--learning-starts", type=int, default=10000,
66
+ help="timestep to start learning")
67
+ parser.add_argument("--train-frequency", type=int, default=10,
68
+ help="the frequency of training")
69
+ args = parser.parse_args()
70
+ # fmt: on
71
+ return args
72
+
73
+
74
+ def make_env(env_id, seed, idx, capture_video, run_name):
75
+ def thunk():
76
+ env = gym.make(env_id)
77
+ env = gym.wrappers.RecordEpisodeStatistics(env)
78
+ if capture_video:
79
+ if idx == 0:
80
+ env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
81
+ env.seed(seed)
82
+ env.action_space.seed(seed)
83
+ env.observation_space.seed(seed)
84
+ return env
85
+
86
+ return thunk
87
+
88
+
89
+ # ALGO LOGIC: initialize agent here:
90
+ class QNetwork(nn.Module):
91
+ def __init__(self, env):
92
+ super().__init__()
93
+ self.network = nn.Sequential(
94
+ nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
95
+ nn.ReLU(),
96
+ nn.Linear(120, 84),
97
+ nn.ReLU(),
98
+ nn.Linear(84, env.single_action_space.n),
99
+ )
100
+
101
+ def forward(self, x):
102
+ return self.network(x)
103
+
104
+
105
+ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
106
+ slope = (end_e - start_e) / duration
107
+ return max(slope * t + start_e, end_e)
108
+
109
+
110
+ if __name__ == "__main__":
111
+ args = parse_args()
112
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
113
+ if args.track:
114
+ import wandb
115
+
116
+ wandb.init(
117
+ project=args.wandb_project_name,
118
+ entity=args.wandb_entity,
119
+ sync_tensorboard=True,
120
+ config=vars(args),
121
+ name=run_name,
122
+ monitor_gym=True,
123
+ save_code=True,
124
+ )
125
+ writer = SummaryWriter(f"runs/{run_name}")
126
+ writer.add_text(
127
+ "hyperparameters",
128
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
129
+ )
130
+
131
+ # TRY NOT TO MODIFY: seeding
132
+ random.seed(args.seed)
133
+ np.random.seed(args.seed)
134
+ torch.manual_seed(args.seed)
135
+ torch.backends.cudnn.deterministic = args.torch_deterministic
136
+
137
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
138
+
139
+ # env setup
140
+ envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
141
+ assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
142
+
143
+ q_network = QNetwork(envs).to(device)
144
+ optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
145
+ target_network = QNetwork(envs).to(device)
146
+ target_network.load_state_dict(q_network.state_dict())
147
+
148
+ rb = ReplayBuffer(
149
+ args.buffer_size,
150
+ envs.single_observation_space,
151
+ envs.single_action_space,
152
+ device,
153
+ handle_timeout_termination=True,
154
+ )
155
+ start_time = time.time()
156
+
157
+ # TRY NOT TO MODIFY: start the game
158
+ obs = envs.reset()
159
+ for global_step in range(args.total_timesteps):
160
+ # ALGO LOGIC: put action logic here
161
+ epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
162
+ if random.random() < epsilon:
163
+ actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
164
+ else:
165
+ q_values = q_network(torch.Tensor(obs).to(device))
166
+ actions = torch.argmax(q_values, dim=1).cpu().numpy()
167
+
168
+ # TRY NOT TO MODIFY: execute the game and log data.
169
+ next_obs, rewards, dones, infos = envs.step(actions)
170
+
171
+ # TRY NOT TO MODIFY: record rewards for plotting purposes
172
+ for info in infos:
173
+ if "episode" in info.keys():
174
+ print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
175
+ writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
176
+ writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
177
+ writer.add_scalar("charts/epsilon", epsilon, global_step)
178
+ break
179
+
180
+ # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
181
+ real_next_obs = next_obs.copy()
182
+ for idx, d in enumerate(dones):
183
+ if d:
184
+ real_next_obs[idx] = infos[idx]["terminal_observation"]
185
+ rb.add(obs, real_next_obs, actions, rewards, dones, infos)
186
+
187
+ # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
188
+ obs = next_obs
189
+
190
+ # ALGO LOGIC: training.
191
+ if global_step > args.learning_starts and global_step % args.train_frequency == 0:
192
+ data = rb.sample(args.batch_size)
193
+ with torch.no_grad():
194
+ target_max, _ = target_network(data.next_observations).max(dim=1)
195
+ td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
196
+ old_val = q_network(data.observations).gather(1, data.actions).squeeze()
197
+ loss = F.mse_loss(td_target, old_val)
198
+
199
+ if global_step % 100 == 0:
200
+ writer.add_scalar("losses/td_loss", loss, global_step)
201
+ writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
202
+ print("SPS:", int(global_step / (time.time() - start_time)))
203
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
204
+
205
+ # optimize the model
206
+ optimizer.zero_grad()
207
+ loss.backward()
208
+ optimizer.step()
209
+
210
+ # update the target network
211
+ if global_step % args.target_network_frequency == 0:
212
+ target_network.load_state_dict(q_network.state_dict())
213
+
214
+ if args.save_model:
215
+ model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
216
+ torch.save(q_network.state_dict(), model_path)
217
+ print(f"model saved to {model_path}")
218
+ from cleanrl_utils.evals.dqn_eval import evaluate
219
+
220
+ episodic_returns = evaluate(
221
+ model_path,
222
+ make_env,
223
+ args.env_id,
224
+ eval_episodes=10,
225
+ run_name=f"{run_name}-eval",
226
+ Model=QNetwork,
227
+ device=device,
228
+ epsilon=0.05,
229
+ )
230
+ for idx, episodic_return in enumerate(episodic_returns):
231
+ writer.add_scalar("eval/episodic_return", episodic_return, idx)
232
+
233
+ if args.upload_model:
234
+ from cleanrl_utils.huggingface import push_to_hub
235
+
236
+ repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
237
+ repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
238
+ push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")
239
+
240
+ envs.close()
241
+ writer.close()
events.out.tfevents.1668715831.pop-os.2878860.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56e216780ef8b110afd3e68efe3546dc2b33ce840f1a4da347daed064dd9b280
3
+ size 4310
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "cleanrl"
3
+ version = "1.0.0"
4
+ description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
5
+ authors = ["Costa Huang <costa.huang@outlook.com>"]
6
+ include = ["cleanrl_utils"]
7
+ keywords = ["reinforcement", "machine", "learning", "research"]
8
+ license="MIT"
9
+ readme = "README.md"
10
+
11
+ [tool.poetry.dependencies]
12
+ python = ">=3.7.1,<3.10"
13
+ tensorboard = "^2.10.0"
14
+ wandb = "^0.13.3"
15
+ gym = {version = "0.23.1", extras = ["classic_control"]}
16
+ torch = "^1.12.1"
17
+ stable-baselines3 = "1.2.0"
18
+
19
+ [tool.poetry.group.dev.dependencies]
20
+ pre-commit = "^2.20.0"
21
+
22
+ [tool.poetry.group.atari]
23
+ optional = true
24
+ [tool.poetry.group.atari.dependencies]
25
+ ale-py = "0.7.4"
26
+ AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
27
+ opencv-python = "^4.6.0.66"
28
+
29
+ [tool.poetry.group.pybullet]
30
+ optional = true
31
+ [tool.poetry.group.pybullet.dependencies]
32
+ pybullet = "3.1.8"
33
+
34
+ [tool.poetry.group.procgen]
35
+ optional = true
36
+ [tool.poetry.group.procgen.dependencies]
37
+ procgen = "^0.10.7"
38
+
39
+ [tool.poetry.group.pytest]
40
+ optional = true
41
+ [tool.poetry.group.pytest.dependencies]
42
+ pytest = "^7.1.3"
43
+
44
+ [tool.poetry.group.mujoco]
45
+ optional = true
46
+ [tool.poetry.group.mujoco.dependencies]
47
+ free-mujoco-py = "^2.1.6"
48
+
49
+ [tool.poetry.group.docs]
50
+ optional = true
51
+ [tool.poetry.group.docs.dependencies]
52
+ mkdocs-material = "^8.4.3"
53
+ markdown-include = "^0.7.0"
54
+
55
+ [tool.poetry.group.jax]
56
+ optional = true
57
+ [tool.poetry.group.jax.dependencies]
58
+ jax = "^0.3.17"
59
+ jaxlib = "^0.3.15"
60
+ flax = "^0.6.0"
61
+
62
+ [tool.poetry.group.optuna]
63
+ optional = true
64
+ [tool.poetry.group.optuna.dependencies]
65
+ optuna = "^3.0.1"
66
+ optuna-dashboard = "^0.7.2"
67
+ rich = "<12.0"
68
+
69
+ [tool.poetry.group.envpool]
70
+ optional = true
71
+ [tool.poetry.group.envpool.dependencies]
72
+ envpool = "^0.6.4"
73
+
74
+ [tool.poetry.group.pettingzoo]
75
+ optional = true
76
+ [tool.poetry.group.pettingzoo.dependencies]
77
+ PettingZoo = "1.18.1"
78
+ SuperSuit = "3.4.0"
79
+ multi-agent-ale-py = "0.1.11"
80
+
81
+
82
+ [tool.poetry.group.cloud]
83
+ optional = true
84
+ [tool.poetry.group.cloud.dependencies]
85
+ boto3 = "^1.24.70"
86
+ awscli = "^1.25.71"
87
+
88
+ [tool.poetry.group.isaacgym]
89
+ optional = true
90
+ [tool.poetry.group.isaacgym.dependencies]
91
+ isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry"}
92
+ isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}
93
+
94
+ [build-system]
95
+ requires = ["poetry-core"]
96
+ build-backend = "poetry.core.masonry.api"
replay.mp4 ADDED
Binary file (9.14 kB). View file
 
videos/CartPole-v1__dqn__1__1668715831-eval/rl-video-episode-0.mp4 ADDED
Binary file (17.7 kB). View file
 
videos/CartPole-v1__dqn__1__1668715831-eval/rl-video-episode-1.mp4 ADDED
Binary file (12.3 kB). View file
 
videos/CartPole-v1__dqn__1__1668715831-eval/rl-video-episode-8.mp4 ADDED
Binary file (9.14 kB). View file