A2C playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
05b94c0
import copy | |
import logging | |
import random | |
from collections import deque | |
from typing import List, NamedTuple, Optional, TypeVar | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.optim import Adam | |
from torch.utils.tensorboard.writer import SummaryWriter | |
from rl_algo_impls.dqn.policy import DQNPolicy | |
from rl_algo_impls.shared.algorithm import Algorithm | |
from rl_algo_impls.shared.callbacks import Callback | |
from rl_algo_impls.shared.schedule import linear_schedule | |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs | |
class Transition(NamedTuple):
    """One environment step for a single environment, as kept in the replay buffer."""

    # Observation the action was chosen from.
    obs: np.ndarray
    # Action taken for ``obs``.
    action: np.ndarray
    # Reward returned by the environment for this step.
    reward: float
    # Done flag returned by the environment for this step.
    done: bool
    # Observation returned by the environment after this step.
    next_obs: np.ndarray
class Batch(NamedTuple):
    """A mini-batch of transitions, each field stacked along a leading batch axis.

    Field order is part of the contract: consumers unpack a ``Batch``
    positionally as ``obs, actions, rewards, dones, next_obs``.
    """

    # Stacked observations the actions were chosen from.
    obs: np.ndarray
    # Stacked actions.
    actions: np.ndarray
    # Stacked rewards.
    rewards: np.ndarray
    # Stacked done flags.
    dones: np.ndarray
    # Stacked observations that followed each action.
    next_obs: np.ndarray
class ReplayBuffer:
    """Fixed-capacity FIFO store of per-environment transitions.

    Storage is a plain list used as a ring buffer: once ``maxlen`` transitions
    are held, each new transition overwrites the oldest one. A list is used
    rather than ``collections.deque`` because ``sample`` needs random indexed
    access, which is O(1) for a list but O(n) toward the middle of a deque.
    """

    def __init__(self, num_envs: int, maxlen: int) -> None:
        """
        Args:
            num_envs: Number of parallel environments; ``add`` splits each
                vectorized step into this many transitions.
            maxlen: Maximum number of transitions retained (must be >= 0).

        Raises:
            ValueError: If ``maxlen`` is negative (same as ``deque``).
        """
        if maxlen < 0:
            raise ValueError("maxlen must be non-negative")
        self.num_envs = num_envs
        self.maxlen = maxlen
        # Physical storage. While not full it is in insertion order; once full,
        # the oldest transition sits at ``_oldest`` and order wraps from there.
        self.buffer: List[Transition] = []
        self._oldest = 0

    def _append(self, transition: Transition) -> None:
        """Store one transition, evicting the oldest when at capacity."""
        if len(self.buffer) < self.maxlen:
            self.buffer.append(transition)
        elif self.maxlen:
            self.buffer[self._oldest] = transition
            self._oldest = (self._oldest + 1) % self.maxlen
        # maxlen == 0: the transition is discarded, as deque(maxlen=0) does.

    def add(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        next_obs: VecEnvObs,
    ) -> None:
        """Split one vectorized env step into ``num_envs`` transitions and store them.

        Element ``i`` of every argument is stored together as the transition of
        environment ``i``.
        """
        # Only ndarray observations are supported; the asserts also narrow the
        # (presumably union) VecEnvObs type for type checkers.
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        for i in range(self.num_envs):
            self._append(
                Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
            )

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` distinct transitions uniformly, stacked into a Batch.

        Raises:
            ValueError: If ``batch_size`` is negative or exceeds ``len(self)``
                (propagated from ``random.sample``).
        """
        size = len(self.buffer)
        # Sample logical positions (0 = oldest) and map them onto the ring.
        # random.sample chooses positions from the population length alone, so
        # this selects the same transitions as sampling the items directly.
        positions = random.sample(range(size), batch_size)
        ts = [self.buffer[(self._oldest + p) % size] for p in positions]
        return Batch(
            obs=np.array([t.obs for t in ts]),
            actions=np.array([t.action for t in ts]),
            rewards=np.array([t.reward for t in ts]),
            dones=np.array([t.done for t in ts]),
            next_obs=np.array([t.next_obs for t in ts]),
        )

    def __len__(self) -> int:
        """Number of transitions currently stored (at most ``maxlen``)."""
        return len(self.buffer)
# Lets ``learn`` be typed as returning the caller's own (sub)class.
DQNSelf = TypeVar("DQNSelf", bound="DQN")


class DQN(Algorithm):
    """Deep Q-Network trainer with a replay buffer and a target network.

    Experience from the vectorized ``env`` is stored in a ``ReplayBuffer``.
    ``policy.q_net`` is regressed toward one-step bootstrapped targets computed
    with ``target_q_net``. That target copy is synced every
    ``target_update_interval`` timesteps: a hard copy when ``tau == 1``,
    Polyak averaging otherwise.
    """

    def __init__(
        self,
        policy: DQNPolicy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 50_000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        target_update_interval: int = 10_000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10.0,
    ) -> None:
        """
        Args:
            policy: Policy owning the online ``q_net``; its ``act`` is used to
                pick actions.
            env: Vectorized environment to collect transitions from.
            device: Device that sampled batches and the target network are
                placed on.
            tb_writer: Passed through to the ``Algorithm`` base class.
            learning_rate: Adam learning rate for ``q_net``.
            buffer_size: Replay buffer capacity, in transitions.
            learning_starts: Transitions collected with ``eps=1`` before any
                gradient update.
            batch_size: Mini-batch size per gradient step.
            tau: Target update coefficient (1.0 = hard copy).
            gamma: Discount factor for the bootstrapped target.
            train_freq: Transitions collected between training phases.
            gradient_steps: Gradient steps per training phase; a value <= 0
                means ``train_freq`` steps.
            target_update_interval: Timesteps between target-network updates.
            exploration_fraction: Passed as ``end_fraction`` to
                ``linear_schedule`` for the epsilon schedule.
            exploration_initial_eps: Epsilon at the start of learning.
            exploration_final_eps: Epsilon at the end of the schedule.
            max_grad_norm: Gradient-norm clip; a falsy value disables clipping.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)

        # Separate copy of q_net, used only to compute targets in ``train``.
        self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
        self.target_q_net.train(False)
        self.tau = tau
        self.target_update_interval = target_update_interval

        self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
        self.batch_size = batch_size

        self.learning_starts = learning_starts
        self.train_freq = train_freq
        self.gradient_steps = gradient_steps

        self.gamma = gamma
        # Maps learning progress in [0, 1] to epsilon; presumably a linear
        # anneal that finishes after ``exploration_fraction`` of learning.
        self.exploration_eps_schedule = linear_schedule(
            exploration_initial_eps,
            exploration_final_eps,
            end_fraction=exploration_fraction,
        )

        self.max_grad_norm = max_grad_norm

    def learn(
        self: DQNSelf, total_timesteps: int, callbacks: Optional[List[Callback]] = None
    ) -> DQNSelf:
        """Train for about ``total_timesteps`` environment transitions.

        The first ``learning_starts`` transitions are collected with ``eps=1``
        and no gradient updates. After that the loop alternates between
        collecting ``train_freq`` transitions and running gradient steps.

        All timestep counts are the number of transitions actually recorded.
        ``_collect_rollout`` steps all ``num_envs`` environments together, so a
        request that is not a multiple of ``num_envs`` is rounded up.

        Args:
            total_timesteps: Transition budget, including the warm-up phase.
            callbacks: Optional callbacks. Each ``on_step`` receives the number
                of transitions collected since the previous call. Training stops
                when one returns a falsy value; callbacks after it in the list
                are not called for that step.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If training steps remain but ``train_freq`` would
                collect no transitions (the loop could never finish).
        """
        self.policy.train(True)
        obs = self.env.reset()
        obs = self._collect_rollout(self.learning_starts, obs, 1)

        warmup_steps = self._rollout_timesteps(self.learning_starts)
        learning_steps = total_timesteps - warmup_steps
        # Real transition count per training iteration (train_freq rounded up
        # to a multiple of num_envs); used for every counter below.
        rollout_steps = self._rollout_timesteps(self.train_freq)
        if learning_steps > 0 and rollout_steps <= 0:
            raise ValueError(
                f"train_freq must be positive to make progress, got {self.train_freq}"
            )
        n_gradient_steps = (
            self.gradient_steps if self.gradient_steps > 0 else self.train_freq
        )

        timesteps_elapsed = 0
        steps_since_target_update = 0
        while timesteps_elapsed < learning_steps:
            progress = timesteps_elapsed / learning_steps
            eps = self.exploration_eps_schedule(progress)
            obs = self._collect_rollout(self.train_freq, obs, eps)
            timesteps_elapsed += rollout_steps

            for _ in range(n_gradient_steps):
                self.train()

            steps_since_target_update += rollout_steps
            if steps_since_target_update >= self.target_update_interval:
                self._update_target()
                steps_since_target_update = 0

            if callbacks:
                if not all(
                    c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
                ):
                    logging.info(
                        f"Callback terminated training at {timesteps_elapsed} timesteps"
                    )
                    break
        return self

    def train(self) -> None:
        """Run one gradient step on a mini-batch sampled from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        The target is ``r + gamma * (1 - done) * max_a' target_q_net(s')[a']``.
        ``q_net`` is fit to it with smooth-L1 (Huber) loss, with optional
        gradient-norm clipping.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
        o = torch.as_tensor(o, device=self.device)
        # gather requires an index of the same rank as the Q-values; assumes
        # q_net returns (batch, n_actions) — TODO confirm.
        a = torch.as_tensor(a, device=self.device).unsqueeze(1)
        r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
        d = torch.as_tensor(d, dtype=torch.long, device=self.device)
        next_o = torch.as_tensor(next_o, device=self.device)

        with torch.no_grad():
            # (1 - d) zeroes the bootstrap term on done transitions.
            target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
        current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
        loss = F.smooth_l1_loss(current, target)

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm:
            nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def _rollout_timesteps(self, timesteps: int) -> int:
        """Return how many transitions ``_collect_rollout(timesteps, ...)`` records.

        This mirrors that method's loop: each iteration steps all ``num_envs``
        environments, and each step adds ``num_envs`` transitions.
        """
        num_envs = self.env.num_envs
        return len(range(0, timesteps, num_envs)) * num_envs

    def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
        """Step the env at least ``timesteps`` transitions, storing each in the buffer.

        Actions come from ``policy.act`` with exploration parameter ``eps``.
        Returns the latest observation so the next rollout continues from it.
        """
        for _ in range(0, timesteps, self.env.num_envs):
            action = self.policy.act(obs, eps, deterministic=False)
            next_obs, reward, done, _ = self.env.step(action)
            self.replay_buffer.add(obs, action, reward, done, next_obs)
            obs = next_obs
        return obs

    def _update_target(self) -> None:
        """Blend online weights into the target: ``t <- tau * p + (1 - tau) * t``.

        With the default ``tau=1.0`` this is effectively a hard copy.
        """
        for target_param, param in zip(
            self.target_q_net.parameters(), self.policy.q_net.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )