File size: 7,219 Bytes
d643f40 33f6fc2 d643f40 33f6fc2 5c9c091 33f6fc2 d643f40 33f6fc2 d643f40 33f6fc2 d643f40 5c9c091 d643f40 5c9c091 d643f40 5c9c091 d643f40 33f6fc2 d643f40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import logging
from time import perf_counter
from typing import List, Optional, TypeVar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.algorithm import Algorithm
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.gae import compute_advantages
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
from rl_algo_impls.shared.stats import log_scalars
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
single_action_space,
single_observation_space,
)
# Self-type variable bound to A2C; used to annotate `A2C.learn` so that the
# method is typed as returning the same (sub)class it was called on.
A2CSelf = TypeVar("A2CSelf", bound="A2C")
class A2C(Algorithm):
    """Advantage actor-critic trainer operating on a vectorized environment.

    Each training iteration collects ``n_steps`` transitions from every
    sub-environment, computes advantages with ``compute_advantages``, and
    performs a single gradient step on the combined policy, value, and
    entropy loss over the whole rollout.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 7e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 5,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        rms_prop_eps: float = 1e-5,
        use_rms_prop: bool = True,
        sde_sample_freq: int = -1,
        normalize_advantage: bool = False,
    ) -> None:
        """Configure the optimizer, schedules, and rollout hyperparameters.

        Args:
            policy: Actor-critic policy whose parameters are optimized.
            env: Vectorized environment used to collect rollouts.
            device: Device that training batches are moved to.
            tb_writer: TensorBoard writer that receives training scalars.
            learning_rate: Initial optimizer learning rate; also the base
                value handed to ``schedule``.
            learning_rate_decay: Schedule name passed to ``schedule`` for the
                learning rate.
            n_steps: Number of environment steps collected per
                sub-environment before each gradient update.
            gamma: Forwarded to ``compute_advantages`` (conventionally the
                discount factor).
            gae_lambda: Forwarded to ``compute_advantages`` (conventionally
                the GAE lambda).
            ent_coef: Base entropy-loss coefficient handed to ``schedule``.
            ent_coef_decay: Schedule name passed to ``schedule`` for the
                entropy coefficient.
            vf_coef: Weight of the value loss in the total loss.
            max_grad_norm: Maximum gradient norm used for clipping before
                each optimizer step.
            rms_prop_eps: ``eps`` given to RMSprop; ignored when Adam is used.
            use_rms_prop: Optimize with ``torch.optim.RMSprop`` when true,
                otherwise with ``torch.optim.Adam``.
            sde_sample_freq: When positive, the policy's noise is re-sampled
                every ``sde_sample_freq`` steps inside a rollout (in addition
                to the reset at the start of every rollout).
            normalize_advantage: Standardize advantages across the rollout
                batch before computing the policy loss.
        """
        # NOTE(review): `learn` reads self.env, self.device and self.tb_writer;
        # presumably Algorithm.__init__ stores them — confirm in the base class.
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.lr_schedule = schedule(learning_rate_decay, learning_rate)
        if use_rms_prop:
            self.optimizer = torch.optim.RMSprop(
                policy.parameters(), lr=learning_rate, eps=rms_prop_eps
            )
        else:
            self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)

        self.n_steps = n_steps

        self.gamma = gamma
        self.gae_lambda = gae_lambda

        self.vf_coef = vf_coef
        self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
        self.max_grad_norm = max_grad_norm

        self.sde_sample_freq = sde_sample_freq
        self.normalize_advantage = normalize_advantage

    def learn(
        self: A2CSelf,
        train_timesteps: int,
        callbacks: Optional[List[Callback]] = None,
        total_timesteps: Optional[int] = None,
        start_timesteps: int = 0,
    ) -> A2CSelf:
        """Train the policy for (at least) ``train_timesteps`` environment steps.

        Training proceeds in whole rollouts of ``n_steps * env.num_envs``
        timesteps, so the final rollout may run past the requested budget.

        Args:
            train_timesteps: Number of timesteps to train for in this call.
            callbacks: Optional callbacks invoked after every update via
                ``on_step(timesteps_elapsed=<timesteps in the rollout>)``. A
                falsy return value stops training; because the results are
                combined with ``all``, callbacks after the first falsy one
                are not invoked for that update.
            total_timesteps: Horizon used to compute schedule progress;
                defaults to ``train_timesteps``.
            start_timesteps: Timestep count to resume from; offsets both
                schedule progress and the logged step.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            AssertionError: If ``start_timesteps + train_timesteps`` exceeds
                ``total_timesteps``.
        """
        if total_timesteps is None:
            total_timesteps = train_timesteps
        assert (
            start_timesteps + train_timesteps <= total_timesteps
        ), "start_timesteps + train_timesteps must not exceed total_timesteps"

        epoch_dim = (self.n_steps, self.env.num_envs)
        step_dim = (self.env.num_envs,)
        obs_space = single_observation_space(self.env)
        act_space = single_action_space(self.env)

        # Rollout buffers, indexed [step, env] and reused across iterations.
        obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
        actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
        rewards = np.zeros(epoch_dim, dtype=np.float32)
        # np.bool_ rather than the np.bool8 alias, which NumPy 2.0 removed.
        episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
        values = np.zeros(epoch_dim, dtype=np.float32)

        next_obs = self.env.reset()
        # Every sub-environment has just been reset, so each begins an episode.
        next_episode_starts = np.full(step_dim, True, dtype=np.bool_)

        timesteps_elapsed = start_timesteps
        while timesteps_elapsed < start_timesteps + train_timesteps:
            start_time = perf_counter()

            # Schedules are evaluated once per rollout from overall progress.
            progress = timesteps_elapsed / total_timesteps
            ent_coef = self.ent_coef_schedule(progress)
            learning_rate = self.lr_schedule(progress)
            update_learning_rate(self.optimizer, learning_rate)
            log_scalars(
                self.tb_writer,
                "charts",
                {
                    "ent_coef": ent_coef,
                    "learning_rate": learning_rate,
                },
                timesteps_elapsed,
            )

            # ---- Rollout collection ----
            self.policy.eval()
            self.policy.reset_noise()
            for s in range(self.n_steps):
                timesteps_elapsed += self.env.num_envs
                if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
                    self.policy.reset_noise()

                obs[s] = next_obs
                episode_starts[s] = next_episode_starts

                # The third element returned by policy.step (log-probabilities)
                # is not needed here: they are recomputed with gradients below.
                actions[s], values[s], _, clamped_action = self.policy.step(next_obs)
                next_obs, rewards[s], next_episode_starts, _ = self.env.step(
                    clamped_action
                )

            # ---- Advantage / return targets ----
            advantages = compute_advantages(
                rewards,
                values,
                episode_starts,
                next_episode_starts,
                next_obs,
                self.policy,
                self.gamma,
                self.gae_lambda,
            )
            returns = advantages + values

            # Flatten the [step, env] axes into a single batch axis.
            b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
            b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
                self.device
            )
            b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
            b_returns = torch.tensor(returns.reshape(-1)).to(self.device)

            if self.normalize_advantage:
                # The small epsilon guards against a zero standard deviation.
                b_advantages = (b_advantages - b_advantages.mean()) / (
                    b_advantages.std() + 1e-8
                )

            # ---- Single gradient update over the whole rollout ----
            self.policy.train()
            logp_a, entropy, v = self.policy(b_obs, b_actions)

            pi_loss = -(b_advantages * logp_a).mean()
            value_loss = F.mse_loss(b_returns, v)
            # Negated so that minimizing the loss maximizes entropy.
            entropy_loss = -entropy.mean()

            loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss

            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

            # Explained variance of the rollout-time value estimates with
            # respect to the returns; undefined (NaN) for constant returns.
            y_pred = values.reshape(-1)
            y_true = returns.reshape(-1)
            var_y = np.var(y_true).item()
            explained_var = (
                np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
            )

            end_time = perf_counter()
            rollout_steps = self.n_steps * self.env.num_envs
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                (rollout_steps) / (end_time - start_time),
                timesteps_elapsed,
            )

            log_scalars(
                self.tb_writer,
                "losses",
                {
                    "loss": loss.item(),
                    "pi_loss": pi_loss.item(),
                    "v_loss": value_loss.item(),
                    "entropy_loss": entropy_loss.item(),
                    "explained_var": explained_var,
                },
                timesteps_elapsed,
            )

            if callbacks:
                # Callbacks receive the number of timesteps in this rollout,
                # not the cumulative count.
                if not all(
                    c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
                ):
                    logging.info(
                        f"Callback terminated training at {timesteps_elapsed} timesteps"
                    )
                    break

        return self
|