import numpy as np
import torch
import torch.nn as nn

from collections import defaultdict
from dataclasses import dataclass, asdict
from torch.optim import Adam
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Optional, Sequence, TypeVar

from shared.algorithm import Algorithm
from shared.callbacks.callback import Callback
from shared.gae import compute_rtg_and_advantage
from shared.trajectory import Trajectory, TrajectoryAccumulator
from vpg.policy import VPGActorCritic
from wrappers.vectorable_wrapper import VecEnv


@dataclass
class TrainEpochStats:
    pi_loss: float
    entropy_loss: float
    v_loss: float
    envs_with_done: int = 0
    episodes_done: int = 0

    def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
        for name, value in asdict(self).items():
            tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)


class VPGTrajectoryAccumulator(TrajectoryAccumulator):
    def __init__(self, num_envs: int) -> None:
        super().__init__(num_envs, trajectory_class=Trajectory)
        self.completed_per_env: defaultdict[int, int] = defaultdict(int)

    def on_done(self, env_idx: int, trajectory: Trajectory) -> None:
        self.completed_per_env[env_idx] += 1


VanillaPolicyGradientSelf = TypeVar(
    "VanillaPolicyGradientSelf", bound="VanillaPolicyGradient"
)


class VanillaPolicyGradient(Algorithm):
    def __init__(
        self,
        policy: VPGActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        gamma: float = 0.99,
        pi_lr: float = 3e-4,
        val_lr: float = 1e-3,
        train_v_iters: int = 80,
        gae_lambda: float = 0.97,
        max_grad_norm: float = 10.0,
        n_steps: int = 4_000,
        sde_sample_freq: int = -1,
        update_rtg_between_v_iters: bool = False,
        ent_coef: float = 0.0,
    ) -> None:
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.pi_optim = Adam(self.policy.pi.parameters(), lr=pi_lr)
        self.val_optim = Adam(self.policy.v.parameters(), lr=val_lr)
        self.max_grad_norm = max_grad_norm

        self.n_steps = n_steps
        self.train_v_iters = train_v_iters
        self.sde_sample_freq = sde_sample_freq
        self.update_rtg_between_v_iters = update_rtg_between_v_iters
        self.ent_coef = ent_coef

    def learn(
        self: VanillaPolicyGradientSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> VanillaPolicyGradientSelf:
        timesteps_elapsed = 0
        epoch_cnt = 0
        while timesteps_elapsed < total_timesteps:
            epoch_cnt += 1
            accumulator = self._collect_trajectories()
            epoch_stats = self.train(accumulator.all_trajectories)
            epoch_stats.envs_with_done = len(accumulator.completed_per_env)
            epoch_stats.episodes_done = sum(accumulator.completed_per_env.values())
            epoch_steps = accumulator.n_timesteps()
            timesteps_elapsed += epoch_steps
            epoch_stats.write_to_tensorboard(
                self.tb_writer, global_step=timesteps_elapsed
            )
            print(
                " | ".join(
                    [
                        f"Epoch: {epoch_cnt}",
                        f"Pi Loss: {round(epoch_stats.pi_loss, 2)}",
                        f"Entropy Loss: {round(epoch_stats.entropy_loss, 2)}",
                        f"V Loss: {round(epoch_stats.v_loss, 2)}",
                        f"Total Steps: {timesteps_elapsed}",
                    ]
                )
            )
            if callback:
                callback.on_step(timesteps_elapsed=epoch_steps)
        return self

    def train(self, trajectories: Sequence[Trajectory]) -> TrainEpochStats:
        self.policy.train()
        obs = torch.as_tensor(
            np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
        )
        act = torch.as_tensor(
            np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
        )
        rtg, adv = compute_rtg_and_advantage(
            trajectories, self.policy, self.gamma, self.gae_lambda, self.device
        )

        _, logp, entropy = self.policy.pi(obs, act)
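        # Policy gradient step: minimize the negative of E[log-prob * advantage],
        # optionally subtracting an entropy bonus (scaled by ent_coef) to keep the
        # policy from collapsing too quickly.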
        pi_loss = -(logp * adv).mean()
        entropy_loss = entropy.mean()

        actor_loss = pi_loss - self.ent_coef * entropy_loss

        self.pi_optim.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.policy.pi.parameters(), self.max_grad_norm)
        self.pi_optim.step()

        # Fit the value function by regression on the rewards-to-go targets.
        v_loss = 0
        for _ in range(self.train_v_iters):
            if self.update_rtg_between_v_iters:
                # Recompute the rewards-to-go targets with the updated value function.
                rtg, _ = compute_rtg_and_advantage(
                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
                )
            v = self.policy.v(obs)
            v_loss = ((v - rtg) ** 2).mean()

            self.val_optim.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(self.policy.v.parameters(), self.max_grad_norm)
            self.val_optim.step()

        return TrainEpochStats(
            pi_loss.item(),
            entropy_loss.item(),
            v_loss.item(),  # type: ignore
        )

    def _collect_trajectories(self) -> VPGTrajectoryAccumulator:
        self.policy.eval()
        obs = self.env.reset()
        accumulator = VPGTrajectoryAccumulator(self.env.num_envs)
        self.policy.reset_noise()
        for i in range(self.n_steps):
            # Periodically resample state-dependent exploration noise, if enabled.
            if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
                self.policy.reset_noise()
            action, value, _, clamped_action = self.policy.step(obs)
            next_obs, reward, done, _ = self.env.step(clamped_action)
            accumulator.step(obs, action, next_obs, reward, done, value)
            obs = next_obs
        return accumulator
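
# Rough usage sketch (the environment and policy factories below are assumptions;
# the real entry points and configs live elsewhere in this repo):
#
#     env = make_vec_env(...)  # hypothetical VecEnv factory
#     policy = VPGActorCritic(env, ...)  # actor-critic matching the env's spaces
#     algo = VanillaPolicyGradient(policy, env, torch.device("cpu"), SummaryWriter())
#     algo.learn(total_timesteps=100_000)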