# A2C playing Walker2DBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
import numpy as np
import os
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from copy import deepcopy
from stable_baselines3.common.vec_env import unwrap_vec_normalize
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from typing import Dict, Optional, Type, TypeVar, Union
from wrappers.normalize import NormalizeObservation, NormalizeReward
from wrappers.vectorable_wrapper import VecEnv, VecEnvObs, find_wrapper
ACTIVATION: Dict[str, Type[nn.Module]] = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
}

VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
MODEL_FILENAME = "model.pth"
NORMALIZE_OBSERVATION_FILENAME = "norm_obs.npz"
NORMALIZE_REWARD_FILENAME = "norm_reward.npz"
PolicySelf = TypeVar("PolicySelf", bound="Policy")


class Policy(nn.Module, ABC):
    """Abstract base policy: an nn.Module tied to a VecEnv that can act,
    save/load its weights, and keep normalization wrappers in sync."""

    @abstractmethod
    def __init__(self, env: VecEnv, **kwargs) -> None:
        super().__init__()
        self.env = env
        # Locate normalization wrappers (if any) so their running statistics
        # can be saved, loaded, and synced alongside the model weights.
        self.vec_normalize = unwrap_vec_normalize(env)
        self.norm_observation = find_wrapper(env, NormalizeObservation)
        self.norm_reward = find_wrapper(env, NormalizeReward)
        self.device = None

    def to(
        self: PolicySelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> PolicySelf:
        super().to(device, dtype, non_blocking)
        # Remember the device so observations can be moved to it in _as_tensor.
        self.device = device
        return self

    @abstractmethod
    def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
        ...

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        # Persist normalization statistics next to the model weights.
        if self.vec_normalize:
            self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
        if self.norm_observation:
            self.norm_observation.save(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward:
            self.norm_reward.save(os.path.join(path, NORMALIZE_REWARD_FILENAME))
        torch.save(
            self.state_dict(),
            os.path.join(path, MODEL_FILENAME),
        )

    def load(self, path: str) -> None:
        # VecNormalize load occurs in env.py
        self.load_state_dict(
            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
        )
        if self.norm_observation:
            self.norm_observation.load(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward:
            self.norm_reward.load(os.path.join(path, NORMALIZE_REWARD_FILENAME))

    def reset_noise(self) -> None:
        # No-op hook; subclasses that sample exploration noise override this.
        pass

    def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
        assert isinstance(obs, np.ndarray)
        o = torch.as_tensor(obs)
        if self.device is not None:
            o = o.to(self.device)
        return o

    def num_trainable_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def sync_normalization(self, destination_env) -> None:
        # Copy this policy's running normalization statistics into the
        # matching wrappers of destination_env (e.g. an evaluation env).
        current = destination_env
        while current != current.unwrapped:
            if isinstance(current, VecNormalize):
                assert self.vec_normalize
                current.ret_rms = deepcopy(self.vec_normalize.ret_rms)
                if hasattr(self.vec_normalize, "obs_rms"):
                    current.obs_rms = deepcopy(self.vec_normalize.obs_rms)
            elif isinstance(current, NormalizeObservation):
                assert self.norm_observation
                current.rms = deepcopy(self.norm_observation.rms)
            elif isinstance(current, NormalizeReward):
                assert self.norm_reward
                current.rms = deepcopy(self.norm_reward.rms)
            # Step down to the next wrapper in the chain.
            current = getattr(current, "venv", getattr(current, "env", current))
            if not current:
                raise AttributeError(
                    f"{type(current)} doesn't include env or venv attribute"
                )
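

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: a minimal concrete
# Policy subclass showing how the interface above is typically filled in.
# `MlpPolicy`, `hidden_dim`, and the network layout are hypothetical, and the
# VecEnv is assumed to expose `single_observation_space` /
# `single_action_space` with flat Box shapes.
# ---------------------------------------------------------------------------
class MlpPolicy(Policy):
    def __init__(self, env: VecEnv, hidden_dim: int = 64, **kwargs) -> None:
        super().__init__(env, **kwargs)
        obs_dim = int(np.prod(env.single_observation_space.shape))
        act_dim = int(np.prod(env.single_action_space.shape))
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, act_dim),
        )

    def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
        # Deterministic forward pass only; a real policy would sample from a
        # distribution when deterministic=False.
        with torch.no_grad():
            return self.net(self._as_tensor(obs).float()).cpu().numpy()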