# sgoodfriend's picture
# A2C playing HalfCheetahBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
# 3d6ce6f
# raw
# history blame
# 4.09 kB
import numpy as np
import os
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from copy import deepcopy
from stable_baselines3.common.vec_env import unwrap_vec_normalize
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from typing import Dict, Optional, Type, TypeVar, Union
from wrappers.normalize import NormalizeObservation, NormalizeReward
from wrappers.vectorable_wrapper import VecEnv, VecEnvObs, find_wrapper
# Lookup from an activation name to the torch module class implementing it.
ACTIVATION: Dict[str, Type[nn.Module]] = dict(tanh=nn.Tanh, relu=nn.ReLU)

# File names used inside a policy's save directory.
VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
MODEL_FILENAME = "model.pth"
NORMALIZE_OBSERVATION_FILENAME = "norm_obs.npz"
NORMALIZE_REWARD_FILENAME = "norm_reward.npz"

# Type variable bound to Policy so that fluent methods can be annotated as
# returning the caller's own subclass type.
PolicySelf = TypeVar("PolicySelf", bound="Policy")
class Policy(nn.Module, ABC):
    """Abstract base class for policies attached to a vectorized environment.

    Keeps references to the normalization wrappers found on ``env`` so that
    their statistics can be saved, loaded, and copied to another environment
    together with the model weights.
    """

    @abstractmethod
    def __init__(self, env: VecEnv, **kwargs) -> None:
        """Store ``env`` and look up its normalization wrappers.

        Args:
            env: Environment this policy acts in.
            **kwargs: Accepted for subclasses; unused here.
        """
        super().__init__()
        self.env = env
        # Each lookup is treated as falsy below when the wrapper is absent.
        self.vec_normalize = unwrap_vec_normalize(env)
        self.norm_observation = find_wrapper(env, NormalizeObservation)
        self.norm_reward = find_wrapper(env, NormalizeReward)
        # Last device explicitly requested through ``to``; None until then.
        self.device = None

    def to(
        self: PolicySelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> PolicySelf:
        """Move/cast the module and remember the target device.

        Args:
            device: Device to move parameters and buffers to, if any.
            dtype: Dtype to cast floating point parameters to, if any.
            non_blocking: Forwarded to ``nn.Module.to``.

        Returns:
            This policy, to allow call chaining.
        """
        super().to(device, dtype, non_blocking)
        # Only record an explicitly requested device. A dtype-only call such
        # as ``policy.to(dtype=...)`` must not reset the remembered device to
        # None, otherwise ``_as_tensor`` and ``load`` stop targeting it.
        if device is not None:
            self.device = device
        return self

    @abstractmethod
    def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
        """Return actions for ``obs``; implemented by subclasses."""
        ...

    def save(self, path: str) -> None:
        """Write normalization statistics and model weights into ``path``.

        The directory is created if it does not exist. Normalization files
        are only written for wrappers that were found on the environment.
        """
        os.makedirs(path, exist_ok=True)

        if self.vec_normalize:
            self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
        if self.norm_observation:
            self.norm_observation.save(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward:
            self.norm_reward.save(os.path.join(path, NORMALIZE_REWARD_FILENAME))
        torch.save(
            self.state_dict(),
            os.path.join(path, MODEL_FILENAME),
        )

    def load(self, path: str) -> None:
        """Load model weights and normalization statistics from ``path``."""
        # VecNormalize load occurs in env.py
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        self.load_state_dict(
            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
        )
        if self.norm_observation:
            self.norm_observation.load(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward:
            self.norm_reward.load(os.path.join(path, NORMALIZE_REWARD_FILENAME))

    def reset_noise(self) -> None:
        """Hook for subclasses; does nothing in the base class."""
        pass

    def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
        """Convert a numpy observation to a tensor on the remembered device."""
        assert isinstance(obs, np.ndarray)
        o = torch.as_tensor(obs)
        if self.device is not None:
            o = o.to(self.device)
        return o

    def num_trainable_parameters(self) -> int:
        """Return the number of parameter elements with ``requires_grad``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_parameters(self) -> int:
        """Return the total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())

    def sync_normalization(self, destination_env) -> None:
        """Copy this policy's normalization statistics onto ``destination_env``.

        Walks the wrapper chain of ``destination_env`` (following ``venv`` or,
        failing that, ``env``) until the unwrapped environment is reached, and
        overwrites the running statistics of every normalization wrapper with
        a deep copy of the matching statistics held by this policy.

        Raises:
            AttributeError: If a wrapper in the chain exposes neither a
                ``venv`` nor an ``env`` attribute.
        """
        current = destination_env
        while current != current.unwrapped:
            if isinstance(current, VecNormalize):
                assert self.vec_normalize
                current.ret_rms = deepcopy(self.vec_normalize.ret_rms)
                if hasattr(self.vec_normalize, "obs_rms"):
                    current.obs_rms = deepcopy(self.vec_normalize.obs_rms)
            elif isinstance(current, NormalizeObservation):
                assert self.norm_observation
                current.rms = deepcopy(self.norm_observation.rms)
            elif isinstance(current, NormalizeReward):
                assert self.norm_reward
                current.rms = deepcopy(self.norm_reward.rms)

            # Default to None rather than ``current``: falling back to the
            # wrapper itself would make the loop spin forever instead of
            # reporting the missing attribute.
            inner = getattr(current, "venv", getattr(current, "env", None))
            if inner is None:
                raise AttributeError(
                    f"{type(current)} doesn't include env or venv attribute"
                )
            current = inner