Push model
Browse files- README.md +1 -1
- a2c.py +11 -6
- events.out.tfevents.1757726955.Alperens-MBP.local.88268.0 → events.out.tfevents.1757936575.Alperens-MBP.local.22547.0 +2 -2
- hyperparameters.json +1 -1
- pyproject.toml +1 -1
- replay.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-0.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-1.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-2.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-3.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-4.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-5.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-6.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-7.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-8.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-9.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-0.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-1.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-2.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-3.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-4.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-5.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-6.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-7.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-8.mp4 +0 -0
- videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-9.mp4 +0 -0
README.md
CHANGED
|
@@ -51,6 +51,6 @@ uv run a2c.py
|
|
| 51 |
'push_model': True,
|
| 52 |
'seed': 1,
|
| 53 |
'total_timesteps': 150_000,
|
| 54 |
-
'video_capture_frequency':
|
| 55 |
```
|
| 56 |
|
|
|
|
| 51 |
'push_model': True,
|
| 52 |
'seed': 1,
|
| 53 |
'total_timesteps': 150_000,
|
| 54 |
+
'video_capture_frequency': 50}
|
| 55 |
```
|
| 56 |
|
a2c.py
CHANGED
|
@@ -37,7 +37,7 @@ class HyperParams:
|
|
| 37 |
"""The number of parallel environments to run."""
|
| 38 |
seed: int = 1
|
| 39 |
"""The random seed for reproducibility."""
|
| 40 |
-
video_capture_frequency: int =
|
| 41 |
"""The interval (in episodes) to record videos of the agent's performance."""
|
| 42 |
|
| 43 |
total_timesteps: int = 150_000
|
|
@@ -125,12 +125,15 @@ class ActorCritic(nn.Module):
|
|
| 125 |
logits = self.actor(states)
|
| 126 |
return values, logits
|
| 127 |
|
| 128 |
-
def act(
|
|
|
|
|
|
|
| 129 |
values, logits = self.forward(states)
|
| 130 |
pd = torch.distributions.Categorical(logits=logits)
|
| 131 |
actions = pd.sample()
|
| 132 |
logprobs = pd.log_prob(actions)
|
| 133 |
-
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
def main() -> None:
|
|
@@ -173,6 +176,7 @@ def main() -> None:
|
|
| 173 |
values = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 174 |
rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 175 |
logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
|
|
|
| 176 |
masks = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 177 |
|
| 178 |
for t in range(args.num_steps):
|
|
@@ -196,6 +200,7 @@ def main() -> None:
|
|
| 196 |
values[t] = value.squeeze()
|
| 197 |
rewards[t] = torch.from_numpy(reward)
|
| 198 |
logprobs[t] = logprob
|
|
|
|
| 199 |
masks[t] = torch.from_numpy(~terminations)
|
| 200 |
|
| 201 |
advantages = torch.zeros_like(rewards).to(device)
|
|
@@ -207,7 +212,7 @@ def main() -> None:
|
|
| 207 |
|
| 208 |
critic_loss = advantages.pow(2).mean()
|
| 209 |
actor_loss = (
|
| 210 |
-
-(logprobs * advantages.detach()).mean() - args.ent_coef *
|
| 211 |
)
|
| 212 |
|
| 213 |
loss = actor_loss + critic_loss
|
|
@@ -219,7 +224,7 @@ def main() -> None:
|
|
| 219 |
if step % args.log_interval < envs.num_envs:
|
| 220 |
writer.add_scalar("losses/actor_loss", actor_loss, step)
|
| 221 |
writer.add_scalar("losses/critic_loss", critic_loss, step)
|
| 222 |
-
writer.add_scalar("losses/entropy",
|
| 223 |
writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
|
| 224 |
writer.add_scalar("losses/total_loss", loss, step)
|
| 225 |
writer.add_scalar("losses/value_estimate", values.mean().item(), step)
|
|
@@ -228,7 +233,7 @@ def main() -> None:
|
|
| 228 |
actor_loss=actor_loss.item(),
|
| 229 |
critic_loss=critic_loss.item(),
|
| 230 |
# total_loss=loss.item(),
|
| 231 |
-
# entropy=
|
| 232 |
# value_estimate=values.mean().item(),
|
| 233 |
advantage=advantages.mean().item(),
|
| 234 |
sps=step // (time.time() - start_time),
|
|
|
|
| 37 |
"""The number of parallel environments to run."""
|
| 38 |
seed: int = 1
|
| 39 |
"""The random seed for reproducibility."""
|
| 40 |
+
video_capture_frequency: int = 50
|
| 41 |
"""The interval (in episodes) to record videos of the agent's performance."""
|
| 42 |
|
| 43 |
total_timesteps: int = 150_000
|
|
|
|
| 125 |
logits = self.actor(states)
|
| 126 |
return values, logits
|
| 127 |
|
| 128 |
+
def act(
|
| 129 |
+
self, states: torch.Tensor
|
| 130 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 131 |
values, logits = self.forward(states)
|
| 132 |
pd = torch.distributions.Categorical(logits=logits)
|
| 133 |
actions = pd.sample()
|
| 134 |
logprobs = pd.log_prob(actions)
|
| 135 |
+
entropy = pd.entropy()
|
| 136 |
+
return actions, logprobs, entropy, values
|
| 137 |
|
| 138 |
|
| 139 |
def main() -> None:
|
|
|
|
| 176 |
values = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 177 |
rewards = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 178 |
logprobs = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 179 |
+
entropies = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 180 |
masks = torch.zeros(args.num_steps, envs.num_envs, device=device)
|
| 181 |
|
| 182 |
for t in range(args.num_steps):
|
|
|
|
| 200 |
values[t] = value.squeeze()
|
| 201 |
rewards[t] = torch.from_numpy(reward)
|
| 202 |
logprobs[t] = logprob
|
| 203 |
+
entropies[t] = entropy
|
| 204 |
masks[t] = torch.from_numpy(~terminations)
|
| 205 |
|
| 206 |
advantages = torch.zeros_like(rewards).to(device)
|
|
|
|
| 212 |
|
| 213 |
critic_loss = advantages.pow(2).mean()
|
| 214 |
actor_loss = (
|
| 215 |
+
-(logprobs * advantages.detach()).mean() - args.ent_coef * entropies.mean()
|
| 216 |
)
|
| 217 |
|
| 218 |
loss = actor_loss + critic_loss
|
|
|
|
| 224 |
if step % args.log_interval < envs.num_envs:
|
| 225 |
writer.add_scalar("losses/actor_loss", actor_loss, step)
|
| 226 |
writer.add_scalar("losses/critic_loss", critic_loss, step)
|
| 227 |
+
writer.add_scalar("losses/entropy", entropies.mean(), step)
|
| 228 |
writer.add_scalar("charts/SPS", step // (time.time() - start_time), step)
|
| 229 |
writer.add_scalar("losses/total_loss", loss, step)
|
| 230 |
writer.add_scalar("losses/value_estimate", values.mean().item(), step)
|
|
|
|
| 233 |
actor_loss=actor_loss.item(),
|
| 234 |
critic_loss=critic_loss.item(),
|
| 235 |
# total_loss=loss.item(),
|
| 236 |
+
# entropy=entropies.mean().item(),
|
| 237 |
# value_estimate=values.mean().item(),
|
| 238 |
advantage=advantages.mean().item(),
|
| 239 |
sps=step // (time.time() - start_time),
|
events.out.tfevents.1757726955.Alperens-MBP.local.88268.0 → events.out.tfevents.1757936575.Alperens-MBP.local.22547.0
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:150f0c68315d117958a752772a5d9b7569a9e9e03c3de4ffdab6ac2d2bc78bf0
|
| 3 |
+
size 1928849
|
hyperparameters.json
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
"exp_name": "a2c",
|
| 4 |
"n_envs": 32,
|
| 5 |
"seed": 1,
|
| 6 |
-
"video_capture_frequency":
|
| 7 |
"total_timesteps": 150000,
|
| 8 |
"num_steps": 20,
|
| 9 |
"gamma": 0.99,
|
|
|
|
| 3 |
"exp_name": "a2c",
|
| 4 |
"n_envs": 32,
|
| 5 |
"seed": 1,
|
| 6 |
+
"video_capture_frequency": 50,
|
| 7 |
"total_timesteps": 150000,
|
| 8 |
"num_steps": 20,
|
| 9 |
"gamma": 0.99,
|
pyproject.toml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
[project]
|
| 2 |
name = "hellrl"
|
| 3 |
version = "0.1.0"
|
| 4 |
-
description = "
|
| 5 |
readme = "README.md"
|
| 6 |
authors = [
|
| 7 |
{ name = "Alperen ÜNLÜ"}
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "hellrl"
|
| 3 |
version = "0.1.0"
|
| 4 |
+
description = "RL Implementations with Cutting Edge Versions and Vectorized Training"
|
| 5 |
readme = "README.md"
|
| 6 |
authors = [
|
| 7 |
{ name = "Alperen ÜNLÜ"}
|
replay.mp4
CHANGED
|
Binary files a/replay.mp4 and b/replay.mp4 differ
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-0.mp4
DELETED
|
Binary file (38.6 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-1.mp4
DELETED
|
Binary file (40.9 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-2.mp4
DELETED
|
Binary file (37.2 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-3.mp4
DELETED
|
Binary file (38 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-4.mp4
DELETED
|
Binary file (38.6 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-5.mp4
DELETED
|
Binary file (38.5 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-6.mp4
DELETED
|
Binary file (38.8 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-7.mp4
DELETED
|
Binary file (37.1 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-8.mp4
DELETED
|
Binary file (35.3 kB)
|
|
|
videos/CartPole-v1_a2c_1_250913_042915_eval/CartPole-v1-episode-9.mp4
DELETED
|
Binary file (38.2 kB)
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-0.mp4
ADDED
|
Binary file (38.6 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-1.mp4
ADDED
|
Binary file (38 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-2.mp4
ADDED
|
Binary file (37.5 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-3.mp4
ADDED
|
Binary file (39.8 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-4.mp4
ADDED
|
Binary file (37.4 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-5.mp4
ADDED
|
Binary file (37.7 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-6.mp4
ADDED
|
Binary file (37.4 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-7.mp4
ADDED
|
Binary file (36.2 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-8.mp4
ADDED
|
Binary file (36.1 kB). View file
|
|
|
videos/CartPole-v1_a2c_1_250915_144255_eval/CartPole-v1-episode-9.mp4
ADDED
|
Binary file (39.8 kB). View file
|
|
|