# DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 3fd02ed
from typing import Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn

from rl_algo_impls.shared.actor import Actor, PiForward, actor_head
from rl_algo_impls.shared.encoder import Encoder
from rl_algo_impls.shared.policy.actor_critic import OnPolicy, Step, clamp_actions
from rl_algo_impls.shared.policy.actor_critic_network import default_hidden_sizes
from rl_algo_impls.shared.policy.critic import CriticHead
from rl_algo_impls.shared.policy.policy import ACTIVATION
from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnv,
    VecEnvObs,
    single_action_space,
    single_observation_space,
)
# File names for serialized weights. Presumably "pi" refers to the policy
# (actor) network and "v" to the value (critic) network; neither constant is
# referenced in the visible part of this file -- confirm against the
# save/load code.
PI_FILE_NAME = "pi.pt"
V_FILE_NAME = "v.pt"
class VPGActor(Actor):
    """Actor that chains a feature extractor with an actor head.

    Observations are first passed through ``feature_extractor``; the result is
    then handed, together with the optional actions, to ``head``.
    """

    def __init__(self, feature_extractor: Encoder, head: Actor) -> None:
        """Store the two components.

        Args:
            feature_extractor: Callable applied to the observations first.
            head: Actor applied to the extracted features.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        self.head = head

    def forward(
        self, obs: torch.Tensor, a: Optional[torch.Tensor] = None
    ) -> PiForward:
        """Run the feature extractor on ``obs`` and forward the result to the head.

        Args:
            obs: Observation tensor given to ``feature_extractor``.
            a: Optional actions, passed through to the head unchanged.

        Returns:
            Whatever ``head(features, a)`` returns.
        """
        features = self.feature_extractor(obs)
        return self.head(features, a)

    def sample_weights(self, batch_size: int = 1) -> None:
        """Delegate to ``head.sample_weights`` with the same ``batch_size``."""
        self.head.sample_weights(batch_size=batch_size)

    def action_shape(self) -> Tuple[int, ...]:
        """Return the action shape reported by the head.

        In this file ``action_shape`` is defined as a plain method, so reading
        ``head.action_shape`` may yield a bound method rather than the shape
        itself. It is called when callable so that the annotated tuple is what
        comes back; a non-callable value (e.g. if the head exposes it as a
        property) is returned as is.
        NOTE(review): the ``Actor`` base class is not visible here -- confirm
        whether ``action_shape`` is meant to be a property.
        """
        shape = self.head.action_shape
        return shape() if callable(shape) else shape
class VPGActorCritic(OnPolicy):
    """On-policy actor-critic with separate actor (``pi``) and critic (``v``) networks.

    The actor and the critic each get their own ``Encoder`` instance, so the
    two encoder objects are distinct.
    """

    def __init__(
        self,
        env: VecEnv,
        hidden_sizes: Optional[Sequence[int]] = None,
        init_layers_orthogonal: bool = True,
        activation_fn: str = "tanh",
        log_std_init: float = -0.5,
        use_sde: bool = False,
        full_std: bool = True,
        squash_output: bool = False,
        **kwargs,
    ) -> None:
        """Build the actor and critic networks for ``env``.

        Args:
            env: Vectorized environment; its single observation and action
                spaces size the networks.
            hidden_sizes: Hidden layer sizes used by both heads. When ``None``,
                ``default_hidden_sizes(obs_space)`` is used.
            init_layers_orthogonal: Forwarded to both encoders and both heads.
            activation_fn: Key looked up in ``ACTIVATION``.
            log_std_init: Forwarded to ``actor_head``.
            use_sde: Forwarded to ``actor_head`` and stored on the instance.
            full_std: Forwarded to ``actor_head``.
            squash_output: Forwarded to ``actor_head``; also stored and later
                passed to ``clamp_actions``.
            **kwargs: Forwarded to ``OnPolicy.__init__``.
        """
        super().__init__(env, **kwargs)
        activation = ACTIVATION[activation_fn]
        obs_space = single_observation_space(env)
        self.action_space = single_action_space(env)
        self.use_sde = use_sde
        self.squash_output = squash_output

        hidden_sizes = (
            hidden_sizes
            if hidden_sizes is not None
            else default_hidden_sizes(obs_space)
        )

        # Actor is built (and assigned) before the critic; keep this order so
        # construction and attribute assignment happen exactly as before.
        pi_feature_extractor = Encoder(
            obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
        )
        pi_head = actor_head(
            self.action_space,
            pi_feature_extractor.out_dim,
            tuple(hidden_sizes),
            init_layers_orthogonal,
            activation,
            log_std_init=log_std_init,
            use_sde=use_sde,
            full_std=full_std,
            squash_output=squash_output,
        )
        self.pi = VPGActor(pi_feature_extractor, pi_head)

        v_feature_extractor = Encoder(
            obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
        )
        v_head = CriticHead(
            v_feature_extractor.out_dim,
            tuple(hidden_sizes),
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )
        self.v = nn.Sequential(v_feature_extractor, v_head)

    def value(self, obs: VecEnvObs) -> np.ndarray:
        """Return the critic output for ``obs`` as a NumPy array.

        The critic is evaluated under ``torch.no_grad()``. ``_as_tensor`` is
        defined outside this view.
        """
        o = self._as_tensor(obs)
        with torch.no_grad():
            v = self.v(o)
        return v.cpu().numpy()

    def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
        """Sample actions for ``obs`` and evaluate the critic.

        Args:
            obs: Observations, converted with ``self._as_tensor``.
            action_masks: Must be ``None``; masks are not supported.

        Returns:
            ``Step`` built positionally from the sampled actions, the critic
            output, the log-probabilities of the sampled actions, and the
            actions after ``clamp_actions`` -- all as NumPy arrays.

        Raises:
            AssertionError: If ``action_masks`` is not ``None``.
        """
        assert (
            action_masks is None
        ), f"action_masks not currently supported in {self.__class__.__name__}"
        o = self._as_tensor(obs)
        with torch.no_grad():
            pi, _, _ = self.pi(o)
            a = pi.sample()
            logp_a = pi.log_prob(a)

            v = self.v(o)

        a_np = a.cpu().numpy()
        clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
        return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)

    def act(
        self,
        obs: np.ndarray,
        deterministic: bool = True,
        action_masks: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return clamped actions for ``obs``.

        Args:
            obs: Observations to act on.
            deterministic: When true, use the distribution's ``mode``;
                otherwise sample through ``step`` and return its ``clamped_a``.
            action_masks: Must be ``None``; masks are not supported.

        Returns:
            Actions after ``clamp_actions``.

        Raises:
            AssertionError: If ``action_masks`` is not ``None``.
        """
        assert (
            action_masks is None
        ), f"action_masks not currently supported in {self.__class__.__name__}"
        if not deterministic:
            return self.step(obs).clamped_a
        else:
            o = self._as_tensor(obs)
            with torch.no_grad():
                pi, _, _ = self.pi(o)
                a = pi.mode
            return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)

    def load(self, path: str) -> None:
        """Load via the parent class, then call ``reset_noise``."""
        super().load(path)
        self.reset_noise()

    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Call ``pi.sample_weights`` with the given batch size.

        A falsy ``batch_size`` (``None`` or ``0``) falls back to
        ``self.env.num_envs``. ``self.env`` is not assigned in this class --
        presumably set by the parent constructor; confirm.
        """
        self.pi.sample_weights(
            batch_size=batch_size if batch_size else self.env.num_envs
        )

    def action_shape(self) -> Tuple[int, ...]:
        """Return the action shape reported by the actor.

        ``VPGActor.action_shape`` is a plain method in this file, so reading
        ``self.pi.action_shape`` yields a bound method rather than a tuple.
        It is called when callable so the annotated tuple is returned; a
        non-callable value (e.g. if it is exposed as a property) is returned
        as is.
        """
        shape = self.pi.action_shape
        return shape() if callable(shape) else shape