# sgoodfriend's picture
# DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 3fd02ed
import logging
from dataclasses import asdict, dataclass
from time import perf_counter
from typing import List, NamedTuple, Optional, TypeVar
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.algorithm import Algorithm
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.gae import compute_advantages
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
from rl_algo_impls.shared.schedule import (
constant_schedule,
linear_schedule,
schedule,
update_learning_rate,
)
from rl_algo_impls.shared.stats import log_scalars
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
single_action_space,
single_observation_space,
)
class TrainStepStats(NamedTuple):
    """Scalar statistics recorded for a single minibatch optimizer step.

    One instance is produced per minibatch inside ``PPO.learn``; a list of
    them is averaged into a ``TrainStats`` for logging.
    """

    # Total objective: pi_loss + ent_coef * entropy_loss + vf_coef * v_loss.
    loss: float
    # Clipped-surrogate policy loss.
    pi_loss: float
    # Squared-error value loss (optionally clipped and/or halved).
    v_loss: float
    # Negative mean policy entropy.
    entropy_loss: float
    # Mean of (ratio - 1) - log(ratio) between new and rollout-time policies.
    approx_kl: float
    # Fraction of samples whose |ratio - 1| exceeded the policy clip range.
    clipped_frac: float
    # Fraction of samples whose value change exceeded the value clip range
    # (0 when value clipping is disabled).
    val_clipped_frac: float
@dataclass
class TrainStats:
    """Per-update training statistics, averaged over minibatch steps.

    The class is a dataclass so that ``asdict`` can enumerate its fields for
    logging, but it supplies its own ``__init__`` and ``__repr__`` (the
    ``dataclass`` decorator keeps explicitly defined ones).
    """

    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float
    val_clipped_frac: float
    explained_var: float

    def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
        """Average each per-step metric and store the explained variance.

        Args:
            step_stats: Statistics from the individual minibatch steps.
            explained_var: Explained variance of the value predictions,
                stored unchanged.
        """

        def averaged(metric: str) -> float:
            # Mean of one named metric across every recorded step, as a
            # plain Python float.
            return np.mean([getattr(step, metric) for step in step_stats]).item()

        self.loss = averaged("loss")
        self.pi_loss = averaged("pi_loss")
        self.v_loss = averaged("v_loss")
        self.entropy_loss = averaged("entropy_loss")
        self.approx_kl = averaged("approx_kl")
        self.clipped_frac = averaged("clipped_frac")
        self.val_clipped_frac = averaged("val_clipped_frac")
        self.explained_var = explained_var

    def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
        """Log every field as a scalar under the ``losses/`` tag prefix.

        Args:
            tb_writer: Writer that receives one ``add_scalar`` call per field.
            global_step: Step value attached to each scalar.
        """
        scalars = asdict(self)
        for name in scalars:
            tb_writer.add_scalar(
                f"losses/{name}", scalars[name], global_step=global_step
            )

    def __repr__(self) -> str:
        """Return a one-line summary with each metric rounded to 2 places.

        ``explained_var`` is not part of the summary.
        """
        labelled = (
            ("Loss", self.loss),
            ("Pi L", self.pi_loss),
            ("V L", self.v_loss),
            ("E L", self.entropy_loss),
            ("Apx KL Div", self.approx_kl),
            ("Clip Frac", self.clipped_frac),
            ("Val Clip Frac", self.val_clipped_frac),
        )
        return " | ".join(f"{label}: {round(value, 2)}" for label, value in labelled)
# Self-type for PPO: ``PPO.learn`` annotates ``self`` and its return value with
# this TypeVar so that calling it on a subclass is typed as returning that
# subclass.
PPOSelf = TypeVar("PPOSelf", bound="PPO")
class PPO(Algorithm):
    """Proximal Policy Optimization trainer for an ``ActorCritic`` policy.

    ``learn`` alternates between collecting a fixed-length rollout from the
    vectorized environment and running several epochs of shuffled minibatch
    updates. Each update minimizes the clipped-surrogate policy loss plus a
    weighted entropy loss and a weighted (optionally clipped) squared-error
    value loss.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 3e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_decay: str = "none",
        clip_range_vf: Optional[float] = None,
        clip_range_vf_decay: str = "none",
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        ppo2_vf_coef_halving: bool = False,
        max_grad_norm: float = 0.5,
        sde_sample_freq: int = -1,
        update_advantage_between_epochs: bool = True,
        update_returns_between_epochs: bool = False,
        gamma_end: Optional[float] = None,
    ) -> None:
        """Store hyperparameters and build the optimizer and schedules.

        Args:
            policy: Actor-critic policy that is stepped during rollouts and
                optimized during updates.
            env: Vectorized environment. If it has a ``get_action_mask``
                attribute, that callable is used to fetch action masks which
                are passed to the policy.
            device: Device that rollout tensors are moved to for updates.
            tb_writer: TensorBoard writer for chart and loss scalars.
            learning_rate: Initial Adam learning rate.
            learning_rate_decay: Decay spec handed to ``schedule`` together
                with ``learning_rate``.
            n_steps: Number of environment steps collected per rollout (per
                environment).
            batch_size: Minibatch size used for the updates.
            n_epochs: Number of passes over each rollout.
            gamma: Discount factor (the start value when ``gamma_end`` is
                given).
            gae_lambda: Lambda parameter forwarded to ``compute_advantages``.
            clip_range: Clip range for the policy probability ratio.
            clip_range_decay: Decay spec for ``clip_range``.
            clip_range_vf: Clip range for value updates. A falsy value
                (``None`` or ``0``) disables value clipping.
            clip_range_vf_decay: Decay spec for ``clip_range_vf``.
            normalize_advantage: Whether to standardize advantages within
                each minibatch.
            ent_coef: Weight of the entropy loss term.
            ent_coef_decay: Decay spec for ``ent_coef``.
            vf_coef: Weight of the value loss term.
            ppo2_vf_coef_halving: If true, the value loss is multiplied by 0.5
                before weighting.
            max_grad_norm: Maximum gradient norm passed to
                ``clip_grad_norm_``.
            sde_sample_freq: If positive, the policy noise is reset every
                this many steps within a rollout (in addition to the reset at
                the start of each rollout).
            update_advantage_between_epochs: Recompute advantages at the
                start of every epoch instead of only the first.
            update_returns_between_epochs: Recompute returns at the start of
                every epoch instead of only the first.
            gamma_end: If given, gamma follows
                ``linear_schedule(gamma, gamma_end)`` instead of staying
                constant.

        Raises:
            AssertionError: If ``normalize_advantage`` is set and either the
                rollout size or ``batch_size`` is not greater than 1.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy
        # None when the environment does not provide action masks.
        self.get_action_mask = getattr(env, "get_action_mask", None)

        self.gamma_schedule = (
            linear_schedule(gamma, gamma_end)
            if gamma_end is not None
            else constant_schedule(gamma)
        )
        self.gae_lambda = gae_lambda
        self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
        self.lr_schedule = schedule(learning_rate_decay, learning_rate)
        self.max_grad_norm = max_grad_norm
        self.clip_range_schedule = schedule(clip_range_decay, clip_range)
        # Stays None (value clipping disabled) unless clip_range_vf is truthy.
        self.clip_range_vf_schedule = None
        if clip_range_vf:
            self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)

        if normalize_advantage:
            assert (
                env.num_envs * n_steps > 1 and batch_size > 1
            ), "Each minibatch must be larger than 1 to support normalization"
        self.normalize_advantage = normalize_advantage

        self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
        self.vf_coef = vf_coef
        self.ppo2_vf_coef_halving = ppo2_vf_coef_halving

        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.sde_sample_freq = sde_sample_freq

        self.update_advantage_between_epochs = update_advantage_between_epochs
        self.update_returns_between_epochs = update_returns_between_epochs

    def learn(
        self: PPOSelf,
        train_timesteps: int,
        callbacks: Optional[List[Callback]] = None,
        total_timesteps: Optional[int] = None,
        start_timesteps: int = 0,
    ) -> PPOSelf:
        """Train the policy by alternating rollouts and minibatch updates.

        Whole rollouts are always collected, so the final timestep count may
        exceed ``start_timesteps + train_timesteps`` by up to one rollout.

        Args:
            train_timesteps: Number of environment timesteps to train for in
                this call.
            callbacks: Optional callbacks. After every rollout/update cycle
                each callback's ``on_step`` is called with the number of
                timesteps in that rollout; training stops as soon as one
                returns a falsy value.
            total_timesteps: Denominator for schedule progress. Defaults to
                ``train_timesteps``.
            start_timesteps: Initial value of the timestep counter, used for
                schedule progress and logging.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            AssertionError: If ``start_timesteps + train_timesteps`` exceeds
                ``total_timesteps``.
        """
        if total_timesteps is None:
            total_timesteps = train_timesteps
        assert start_timesteps + train_timesteps <= total_timesteps

        # NOTE(review): self.env, self.device and self.tb_writer are not
        # assigned in this class; presumably Algorithm.__init__ sets them.
        epoch_dim = (self.n_steps, self.env.num_envs)
        step_dim = (self.env.num_envs,)
        obs_space = single_observation_space(self.env)
        act_space = single_action_space(self.env)
        act_shape = self.policy.action_shape

        next_obs = self.env.reset()
        next_action_masks = self.get_action_mask() if self.get_action_mask else None
        next_episode_starts = np.full(step_dim, True, dtype=np.bool_)

        # Rollout buffers, reused (overwritten) on every iteration.
        obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)  # type: ignore
        actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype)  # type: ignore
        rewards = np.zeros(epoch_dim, dtype=np.float32)
        episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
        values = np.zeros(epoch_dim, dtype=np.float32)
        logprobs = np.zeros(epoch_dim, dtype=np.float32)
        action_masks = (
            np.zeros(
                (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
            )
            if next_action_masks is not None
            else None
        )

        timesteps_elapsed = start_timesteps
        while timesteps_elapsed < start_timesteps + train_timesteps:
            start_time = perf_counter()

            # Evaluate every schedule at the current training progress.
            progress = timesteps_elapsed / total_timesteps
            ent_coef = self.ent_coef_schedule(progress)
            learning_rate = self.lr_schedule(progress)
            update_learning_rate(self.optimizer, learning_rate)
            pi_clip = self.clip_range_schedule(progress)
            gamma = self.gamma_schedule(progress)
            chart_scalars = {
                "learning_rate": self.optimizer.param_groups[0]["lr"],
                "ent_coef": ent_coef,
                "pi_clip": pi_clip,
                "gamma": gamma,
            }
            if self.clip_range_vf_schedule:
                v_clip = self.clip_range_vf_schedule(progress)
                chart_scalars["v_clip"] = v_clip
            else:
                v_clip = None
            log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)

            # ---- Rollout collection ----
            self.policy.eval()
            self.policy.reset_noise()
            for s in range(self.n_steps):
                timesteps_elapsed += self.env.num_envs
                if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
                    self.policy.reset_noise()

                obs[s] = next_obs
                episode_starts[s] = next_episode_starts
                if action_masks is not None:
                    action_masks[s] = next_action_masks

                (
                    actions[s],
                    values[s],
                    logprobs[s],
                    clamped_action,
                ) = self.policy.step(next_obs, action_masks=next_action_masks)
                # The environment receives clamped_action, while the action
                # stored in the buffer is the first element of the step
                # result.
                next_obs, rewards[s], next_episode_starts, _ = self.env.step(
                    clamped_action
                )
                next_action_masks = (
                    self.get_action_mask() if self.get_action_mask else None
                )

            # ---- Policy/value updates ----
            self.policy.train()

            # Flatten the (n_steps, num_envs) leading axes into one batch
            # axis and move everything to the training device.
            b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)  # type: ignore
            b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to(  # type: ignore
                self.device
            )
            b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
            # Flattening drops the mask's first axis -- presumably the
            # environment axis; confirm against get_action_mask.
            b_action_masks = (
                torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to(  # type: ignore
                    self.device
                )
                if action_masks is not None
                else None
            )

            # Rollout-time value estimates; also used for explained variance.
            y_pred = values.reshape(-1)
            b_values = torch.tensor(y_pred).to(self.device)

            step_stats = []
            # Define variables that will definitely be set through the first epoch
            advantages: np.ndarray = None  # type: ignore
            b_advantages: torch.Tensor = None  # type: ignore
            y_true: np.ndarray = None  # type: ignore
            b_returns: torch.Tensor = None  # type: ignore
            for e in range(self.n_epochs):
                if e == 0 or self.update_advantage_between_epochs:
                    # The current policy is passed in, so a recomputation in
                    # a later epoch can reflect the updated network.
                    advantages = compute_advantages(
                        rewards,
                        values,
                        episode_starts,
                        next_episode_starts,
                        next_obs,
                        self.policy,
                        gamma,
                        self.gae_lambda,
                    )
                    b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
                if e == 0 or self.update_returns_between_epochs:
                    returns = advantages + values
                    y_true = returns.reshape(-1)
                    b_returns = torch.tensor(y_true).to(self.device)

                # Fresh shuffle of the whole rollout for every epoch.
                b_idxs = torch.randperm(len(b_obs))
                # Only record last epoch's stats
                step_stats.clear()
                for i in range(0, len(b_obs), self.batch_size):
                    self.policy.reset_noise(self.batch_size)

                    mb_idxs = b_idxs[i : i + self.batch_size]

                    mb_obs = b_obs[mb_idxs]
                    mb_actions = b_actions[mb_idxs]
                    mb_values = b_values[mb_idxs]
                    mb_logprobs = b_logprobs[mb_idxs]
                    mb_action_masks = (
                        b_action_masks[mb_idxs] if b_action_masks is not None else None
                    )

                    mb_adv = b_advantages[mb_idxs]
                    # When the rollout size is not a multiple of batch_size,
                    # the trailing minibatch can hold a single sample. std()
                    # of one element is NaN, which would turn the loss, the
                    # gradients and finally the weights into NaN, so such a
                    # minibatch is left unnormalized.
                    if self.normalize_advantage and len(mb_adv) > 1:
                        mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
                    mb_returns = b_returns[mb_idxs]

                    new_logprobs, entropy, new_values = self.policy(
                        mb_obs, mb_actions, action_masks=mb_action_masks
                    )

                    # Clipped surrogate, written as a loss (negated
                    # objective), hence max rather than min.
                    logratio = new_logprobs - mb_logprobs
                    ratio = torch.exp(logratio)
                    clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
                    pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()

                    v_loss_unclipped = (new_values - mb_returns) ** 2
                    if v_clip:
                        # Limit how far the new value may move from the
                        # rollout-time value, then take the worse of the two
                        # squared errors.
                        v_loss_clipped = (
                            mb_values
                            + torch.clamp(new_values - mb_values, -v_clip, v_clip)
                            - mb_returns
                        ) ** 2
                        v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
                    else:
                        v_loss = v_loss_unclipped.mean()

                    if self.ppo2_vf_coef_halving:
                        v_loss *= 0.5

                    entropy_loss = -entropy.mean()

                    loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss

                    self.optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(
                        self.policy.parameters(), self.max_grad_norm
                    )
                    self.optimizer.step()

                    # Diagnostics use the pre-step ratio and value tensors.
                    with torch.no_grad():
                        approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
                        clipped_frac = (
                            ((ratio - 1).abs() > pi_clip)
                            .float()
                            .mean()
                            .cpu()
                            .numpy()
                            .item()
                        )
                        val_clipped_frac = (
                            ((new_values - mb_values).abs() > v_clip)
                            .float()
                            .mean()
                            .cpu()
                            .numpy()
                            .item()
                            if v_clip
                            else 0.0
                        )

                    step_stats.append(
                        TrainStepStats(
                            loss.item(),
                            pi_loss.item(),
                            v_loss.item(),
                            entropy_loss.item(),
                            approx_kl,
                            clipped_frac,
                            val_clipped_frac,
                        )
                    )

            # Explained variance of the rollout-time value estimates with
            # respect to the returns; undefined (NaN) for constant returns.
            var_y = np.var(y_true).item()
            explained_var = (
                np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
            )
            TrainStats(step_stats, explained_var).write_to_tensorboard(
                self.tb_writer, timesteps_elapsed
            )

            end_time = perf_counter()
            rollout_steps = self.n_steps * self.env.num_envs
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                rollout_steps / (end_time - start_time),
                timesteps_elapsed,
            )

            if callbacks:
                # on_step receives the size of this rollout, not the running
                # total. all() short-circuits, so callbacks after the first
                # falsy result are not invoked.
                if not all(
                    c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
                ):
                    logging.info(
                        f"Callback terminated training at {timesteps_elapsed} timesteps"
                    )
                    break

        return self