alperenunlu committed on
Commit
0ce1450
·
verified ·
1 Parent(s): 27f1113

Push model

Browse files
Files changed (26) hide show
  1. README.md +1 -1
  2. a2c.py +11 -6
  3. events.out.tfevents.1757726955.Alperens-MBP.local.88268.0 → events.out.tfevents.1757936575.Alperens-MBP.local.22547.0 +2 -2
  4. hyperparameters.json +1 -1
  5. pyproject.toml +1 -1
  6. replay.mp4 +0 -0
  7. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-0.mp4 +0 -0
  8. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-1.mp4 +0 -0
  9. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-2.mp4 +0 -0
  10. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-3.mp4 +0 -0
  11. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-4.mp4 +0 -0
  12. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-5.mp4 +0 -0
  13. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-6.mp4 +0 -0
  14. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-7.mp4 +0 -0
  15. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-8.mp4 +0 -0
  16. videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-9.mp4 +0 -0
  17. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-0.mp4 +0 -0
  18. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-1.mp4 +0 -0
  19. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-2.mp4 +0 -0
  20. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-3.mp4 +0 -0
  21. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-4.mp4 +0 -0
  22. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-5.mp4 +0 -0
  23. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-6.mp4 +0 -0
  24. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-7.mp4 +0 -0
  25. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-8.mp4 +0 -0
  26. videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-9.mp4 +0 -0
README.md CHANGED
@@ -51,6 +51,6 @@ uv run a2c.py
51
  'push_model': True,
52
  'seed': 1,
53
  'total_timesteps': 150_000,
54
- 'video_capture_frequency': 100}
55
  ```
56
 
 
51
  'push_model': True,
52
  'seed': 1,
53
  'total_timesteps': 150_000,
54
+ 'video_capture_frequency': 50}
55
  ```
56
 
a2c.py CHANGED
@@ -37,7 +37,7 @@ class HyperParams:
37
  """The number of parallel environments to run."""
38
  seed: int = 1
39
  """The random seed for reproducibility."""
40
- video_capture_frequency: int = 100
41
  """The interval (in episodes) to record videos of the agent's performance."""
42
 
43
  total_timesteps: int = 150_000
@@ -125,12 +125,15 @@ class ActorCritic(nn.Module):
125
  logits = self.actor(states)
126
  return values, logits
127
 
128
- def act(self, states: torch.Tensor):
 
 
129
  values, logits = self.forward(states)
130
  pd = torch.distributions.Categorical(logits=logits)
131
  actions = pd.sample()
132
  logprobs = pd.log_prob(actions)
133
- return actions, logprobs, pd.entropy(), values
 
134
 
135
 
136
  def main() -> None:
@@ -173,6 +176,7 @@ def main() -> None:
173
  values = torch.zeros(args.num_steps, envs.num_envs, device=device)
174
  rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
175
  logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
 
176
  masks = torch.zeros(args.num_steps, envs.num_envs, device=device)
177
 
178
  for t in range(args.num_steps):
@@ -196,6 +200,7 @@ def main() -> None:
196
  values[t] = value.squeeze()
197
  rewards[t] = torch.from_numpy(reward)
198
  logprobs[t] = logprob
 
199
  masks[t] = torch.from_numpy(~terminations)
200
 
201
  advantages = torch.zeros_like(rewards).to(device)
@@ -207,7 +212,7 @@ def main() -> None:
207
 
208
  critic_loss = advantages.pow(2).mean()
209
  actor_loss = (
210
- -(logprobs * advantages.detach()).mean() - args.ent_coef * entropy.mean()
211
  )
212
 
213
  loss = actor_loss + critic_loss
@@ -219,7 +224,7 @@ def main() -> None:
219
  if step % args.log_interval < envs.num_envs:
220
  writer.add_scalar("losses/actor_loss", actor_loss, step)
221
  writer.add_scalar("losses/critic_loss", critic_loss, step)
222
- writer.add_scalar("losses/entropy", entropy.mean(), step)
223
  writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
224
  writer.add_scalar("losses/total_loss", loss, step)
225
  writer.add_scalar("losses/value_estimate", values.mean().item(), step)
@@ -228,7 +233,7 @@ def main() -> None:
228
  actor_loss=actor_loss.item(),
229
  critic_loss=critic_loss.item(),
230
  # total_loss=loss.item(),
231
- # entropy=entropy.mean().item(),
232
  # value_estimate=values.mean().item(),
233
  advantage=advantages.mean().item(),
234
  sps=step // (time.time() - start_time),
 
37
  """The number of parallel environments to run."""
38
  seed: int = 1
39
  """The random seed for reproducibility."""
40
+ video_capture_frequency: int = 50
41
  """The interval (in episodes) to record videos of the agent's performance."""
42
 
43
  total_timesteps: int = 150_000
 
125
  logits = self.actor(states)
126
  return values, logits
127
 
128
+ def act(
129
+ self, states: torch.Tensor
130
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
131
  values, logits = self.forward(states)
132
  pd = torch.distributions.Categorical(logits=logits)
133
  actions = pd.sample()
134
  logprobs = pd.log_prob(actions)
135
+ entropy = pd.entropy()
136
+ return actions, logprobs, entropy, values
137
 
138
 
139
  def main() -> None:
 
176
  values = torch.zeros(args.num_steps, envs.num_envs, device=device)
177
  rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
178
  logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
179
+ entropies = torch.zeros(args.num_steps, envs.num_envs, device=device)
180
  masks = torch.zeros(args.num_steps, envs.num_envs, device=device)
181
 
182
  for t in range(args.num_steps):
 
200
  values[t] = value.squeeze()
201
  rewards[t] = torch.from_numpy(reward)
202
  logprobs[t] = logprob
203
+ entropies[t] = entropy
204
  masks[t] = torch.from_numpy(~terminations)
205
 
206
  advantages = torch.zeros_like(rewards).to(device)
 
212
 
213
  critic_loss = advantages.pow(2).mean()
214
  actor_loss = (
215
+ -(logprobs * advantages.detach()).mean() - args.ent_coef * entropies.mean()
216
  )
217
 
218
  loss = actor_loss + critic_loss
 
224
  if step % args.log_interval < envs.num_envs:
225
  writer.add_scalar("losses/actor_loss", actor_loss, step)
226
  writer.add_scalar("losses/critic_loss", critic_loss, step)
227
+ writer.add_scalar("losses/entropy", entropies.mean(), step)
228
  writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
229
  writer.add_scalar("losses/total_loss", loss, step)
230
  writer.add_scalar("losses/value_estimate", values.mean().item(), step)
 
233
  actor_loss=actor_loss.item(),
234
  critic_loss=critic_loss.item(),
235
  # total_loss=loss.item(),
236
+ # entropy=entropies.mean().item(),
237
  # value_estimate=values.mean().item(),
238
  advantage=advantages.mean().item(),
239
  sps=step // (time.time() - start_time),
events.out.tfevents.1757726955.Alperens-MBP.local.88268.0 → events.out.tfevents.1757936575.Alperens-MBP.local.22547.0 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d45664a03cebe436ec067c370681c75ddcd17dc635617f618d93432d88e283e5
3
- size 1928850
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:150f0c68315d117958a752772a5d9b7569a9e9e03c3de4ffdab6ac2d2bc78bf0
3
+ size 1928849
hyperparameters.json CHANGED
@@ -3,7 +3,7 @@
3
  "exp_name": "a2c",
4
  "n_envs": 32,
5
  "seed": 1,
6
- "video_capture_frequency": 100,
7
  "total_timesteps": 150000,
8
  "num_steps": 20,
9
  "gamma": 0.99,
 
3
  "exp_name": "a2c",
4
  "n_envs": 32,
5
  "seed": 1,
6
+ "video_capture_frequency": 50,
7
  "total_timesteps": 150000,
8
  "num_steps": 20,
9
  "gamma": 0.99,
pyproject.toml CHANGED
@@ -1,7 +1,7 @@
1
  [project]
2
  name = "hellrl"
3
  version = "0.1.0"
4
- description = "Add your description here"
5
  readme = "README.md"
6
  authors = [
7
  { name = "Alperen ÜNLÜ"}
 
1
  [project]
2
  name = "hellrl"
3
  version = "0.1.0"
4
+ description = "RL Implementations with Cutting Edge Versions and Vectorized Training"
5
  readme = "README.md"
6
  authors = [
7
  { name = "Alperen ÜNLÜ"}
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-0.mp4 DELETED
Binary file (38.6 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-1.mp4 DELETED
Binary file (40.9 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-2.mp4 DELETED
Binary file (37.2 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-3.mp4 DELETED
Binary file (38 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-4.mp4 DELETED
Binary file (38.6 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-5.mp4 DELETED
Binary file (38.5 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-6.mp4 DELETED
Binary file (38.8 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-7.mp4 DELETED
Binary file (37.1 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-8.mp4 DELETED
Binary file (35.3 kB)
 
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-9.mp4 DELETED
Binary file (38.2 kB)
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-0.mp4 ADDED
Binary file (38.6 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-1.mp4 ADDED
Binary file (38 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-2.mp4 ADDED
Binary file (37.5 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-3.mp4 ADDED
Binary file (39.8 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-4.mp4 ADDED
Binary file (37.4 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-5.mp4 ADDED
Binary file (37.7 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-6.mp4 ADDED
Binary file (37.4 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-7.mp4 ADDED
Binary file (36.2 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-8.mp4 ADDED
Binary file (36.1 kB). View file
 
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-9.mp4 ADDED
Binary file (39.8 kB). View file