# Author: sgoodfriend
# VPG (vanilla policy gradient) agent playing CartPole-v1, from
# https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# Revision: ec6152b
from typing import Optional, Sequence, Tuple
import numpy as np
import torch
import torch.nn as nn
from rl_algo_impls.shared.actor import Actor, PiForward, actor_head
from rl_algo_impls.shared.encoder import Encoder
from rl_algo_impls.shared.policy.actor_critic import OnPolicy, Step, clamp_actions
from rl_algo_impls.shared.policy.actor_critic_network import default_hidden_sizes
from rl_algo_impls.shared.policy.critic import CriticHead
from rl_algo_impls.shared.policy.policy import ACTIVATION
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
VecEnvObs,
single_action_space,
single_observation_space,
)
# File names for the policy ("pi") and value ("v") networks. They are not
# referenced in this part of the file; presumably used when saving/loading
# weights elsewhere — confirm against callers.
PI_FILE_NAME, V_FILE_NAME = "pi.pt", "v.pt"
class VPGActor(Actor):
    """Policy network made of an observation encoder followed by an actor head.

    ``forward`` encodes the observations and passes the features (plus the
    optional actions) to the head; weight sampling and ``action_shape`` are
    delegated to the head.
    """

    def __init__(self, feature_extractor: Encoder, head: Actor) -> None:
        super().__init__()
        # Registered as submodules in this order: encoder first, then head.
        self.feature_extractor = feature_extractor
        self.head = head

    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        """Return the head's ``PiForward`` for the encoded ``obs`` and actions ``a``."""
        return self.head(self.feature_extractor(obs), a)

    def sample_weights(self, batch_size: int = 1) -> None:
        """Forward a weight-sampling request for ``batch_size`` to the head."""
        self.head.sample_weights(batch_size=batch_size)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Action shape as reported by the head."""
        return self.head.action_shape
class VPGActorCritic(OnPolicy):
    """VPG agent with independent policy (``self.pi``) and value (``self.v``) nets.

    Each network builds its own ``Encoder`` over the single-env observation
    space, so no parameters are shared between the policy and the critic.
    """

    def __init__(
        self,
        env: VecEnv,
        hidden_sizes: Optional[Sequence[int]] = None,
        init_layers_orthogonal: bool = True,
        activation_fn: str = "tanh",
        log_std_init: float = -0.5,
        use_sde: bool = False,
        full_std: bool = True,
        squash_output: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(env, **kwargs)
        act_fn = ACTIVATION[activation_fn]
        observation_space = single_observation_space(env)
        self.action_space = single_action_space(env)
        self.use_sde = use_sde
        self.squash_output = squash_output

        if hidden_sizes is None:
            hidden_sizes = default_hidden_sizes(observation_space)
        layer_sizes = tuple(hidden_sizes)

        # Policy branch. The construction order (pi encoder, actor head, v
        # encoder, critic head) must stay fixed so that any RNG-based weight
        # initialisation draws the same random numbers.
        pi_encoder = Encoder(
            observation_space, act_fn, init_layers_orthogonal=init_layers_orthogonal
        )
        pi_head = actor_head(
            self.action_space,
            pi_encoder.out_dim,
            layer_sizes,
            init_layers_orthogonal,
            act_fn,
            log_std_init=log_std_init,
            use_sde=use_sde,
            full_std=full_std,
            squash_output=squash_output,
        )
        self.pi = VPGActor(pi_encoder, pi_head)

        # Value branch: a separate encoder feeding a critic head.
        v_encoder = Encoder(
            observation_space, act_fn, init_layers_orthogonal=init_layers_orthogonal
        )
        self.v = nn.Sequential(
            v_encoder,
            CriticHead(
                v_encoder.out_dim,
                layer_sizes,
                activation=act_fn,
                init_layers_orthogonal=init_layers_orthogonal,
            ),
        )

    def value(self, obs: VecEnvObs) -> np.ndarray:
        """Return the critic's output for ``obs`` as a NumPy array (no autograd)."""
        obs_tensor = self._as_tensor(obs)
        with torch.no_grad():
            values = self.v(obs_tensor)
        return values.cpu().numpy()

    def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
        """Sample actions for ``obs`` and package rollout data into a ``Step``.

        The ``Step`` holds the raw sampled actions, the critic values, the
        log-probabilities of the sampled actions and the clamped actions.
        Action masks are unsupported; passing one trips the assertion.
        """
        assert action_masks is None, (
            f"action_masks not currently supported in {self.__class__.__name__}"
        )
        obs_tensor = self._as_tensor(obs)
        with torch.no_grad():
            dist, _, _ = self.pi(obs_tensor)
            sampled = dist.sample()
            log_prob = dist.log_prob(sampled)
            values = self.v(obs_tensor)
        actions = sampled.cpu().numpy()
        clamped = clamp_actions(actions, self.action_space, self.squash_output)
        return Step(actions, values.cpu().numpy(), log_prob.cpu().numpy(), clamped)

    def act(
        self,
        obs: np.ndarray,
        deterministic: bool = True,
        action_masks: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return clamped actions for ``obs``.

        When ``deterministic`` is true, the action is the distribution's
        ``mode``. Otherwise, the action is sampled via ``step``. Action masks
        are unsupported.
        """
        assert action_masks is None, (
            f"action_masks not currently supported in {self.__class__.__name__}"
        )
        if deterministic:
            obs_tensor = self._as_tensor(obs)
            with torch.no_grad():
                dist, _, _ = self.pi(obs_tensor)
                greedy = dist.mode
            return clamp_actions(
                greedy.cpu().numpy(), self.action_space, self.squash_output
            )
        return self.step(obs).clamped_a

    def load(self, path: str) -> None:
        """Load through the parent class, then resample the policy weights."""
        super().load(path)
        self.reset_noise()

    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Call ``sample_weights`` on the policy.

        A falsy ``batch_size`` (``None`` or 0) falls back to
        ``self.env.num_envs``. This is presumably for state-dependent
        exploration noise (``use_sde``) — confirm in the head implementation.
        """
        self.pi.sample_weights(batch_size=batch_size or self.env.num_envs)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Action shape as reported by the policy network."""
        return self.pi.action_shape