File size: 4,611 Bytes
79943a9
 
 
48f24b9
 
 
 
 
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e63d2a
 
 
 
 
 
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48f24b9
 
 
 
 
 
 
 
 
 
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional, Type, TypeVar, Union

import numpy as np
import torch
import torch.nn as nn
from stable_baselines3.common.vec_env import unwrap_vec_normalize
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize

from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs, find_wrapper

# Maps activation names (as used in configs) to torch.nn module classes.
ACTIVATION: Dict[str, Type[nn.Module]] = dict(
    tanh=nn.Tanh,
    relu=nn.ReLU,
)

# File names written into / read from a policy's save directory.
VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"  # written via VecNormalize.save
MODEL_FILENAME = "model.pth"  # torch.save of the policy's state_dict
NORMALIZE_OBSERVATION_FILENAME = "norm_obs.npz"  # written via NormalizeObservation.save
NORMALIZE_REWARD_FILENAME = "norm_reward.npz"  # written via NormalizeReward.save

# Self-type so methods like Policy.to / Policy.load_from return the subclass type.
PolicySelf = TypeVar("PolicySelf", bound="Policy")


class Policy(nn.Module, ABC):
    """Abstract base for policies attached to a (possibly wrapped) vectorized env.

    At construction the policy records the normalization wrappers present on
    its env (SB3 ``VecNormalize``, ``NormalizeObservation``,
    ``NormalizeReward``). Their statistics are then saved, loaded, copied
    between policies, and pushed onto other envs together with the weights.
    """

    @abstractmethod
    def __init__(self, env: VecEnv, **kwargs) -> None:
        """Record ``env`` and the normalization wrappers found on it.

        Subclasses are expected to call this via ``super().__init__(env, ...)``.
        """
        super().__init__()
        self.env = env
        # SB3 VecNormalize wrapper within env, or None if there is none.
        self.vec_normalize = unwrap_vec_normalize(env)
        # Project wrappers; presumably None when absent from env's wrapper stack.
        self.norm_observation = find_wrapper(env, NormalizeObservation)
        self.norm_reward = find_wrapper(env, NormalizeReward)
        # Device observations are moved to in ``_as_tensor``; set by ``to``.
        self.device = None

    def to(
        self: PolicySelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> PolicySelf:
        """Move/cast parameters like ``nn.Module.to`` and remember the device.

        ``self.device`` is only updated when a device is actually passed. A
        dtype-only call such as ``policy.to(dtype=torch.float16)`` therefore
        keeps the previously selected device instead of resetting it to None.

        Returns:
            ``self``, for chaining.
        """
        super().to(device, dtype, non_blocking)
        if device is not None:
            self.device = device
        return self

    @abstractmethod
    def act(
        self,
        obs: VecEnvObs,
        deterministic: bool = True,
        action_masks: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return actions for the batch of observations ``obs``.

        Args:
            obs: Observations from the vectorized env.
            deterministic: Whether to act deterministically (default True).
            action_masks: Optional mask array; its exact semantics are defined
                by subclasses.
        """
        ...

    def save(self, path: str) -> None:
        """Write normalization statistics and model weights into directory ``path``.

        The directory is created if needed. Each normalization wrapper writes
        its own file only when present; the weights always go to MODEL_FILENAME.
        """
        os.makedirs(path, exist_ok=True)

        if self.vec_normalize:
            self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
        if self.norm_observation:
            self.norm_observation.save(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward:
            self.norm_reward.save(os.path.join(path, NORMALIZE_REWARD_FILENAME))
        torch.save(
            self.state_dict(),
            os.path.join(path, MODEL_FILENAME),
        )

    def load(self, path: str) -> None:
        """Load model weights and normalization statistics saved by ``save``.

        Weights are mapped onto ``self.device``. VecNormalize statistics are
        not loaded here.
        """
        # VecNormalize load occurs in env.py
        self.load_state_dict(
            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
        )
        if self.norm_observation:
            self.norm_observation.load(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward:
            self.norm_reward.load(os.path.join(path, NORMALIZE_REWARD_FILENAME))

    def load_from(self: PolicySelf, policy: PolicySelf) -> PolicySelf:
        """Copy weights and normalization statistics from another policy.

        Raises:
            AssertionError: if this policy has a normalization wrapper that
                ``policy`` lacks.

        Returns:
            ``self``, for chaining.
        """
        self.load_state_dict(policy.state_dict())
        if self.norm_observation:
            assert policy.norm_observation
            self.norm_observation.load_from(policy.norm_observation)
        if self.norm_reward:
            assert policy.norm_reward
            self.norm_reward.load_from(policy.norm_reward)
        return self

    def reset_noise(self) -> None:
        """Hook for subclasses to reset any internal noise; no-op by default."""
        pass

    def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
        """Convert an ndarray observation batch to a tensor on ``self.device``.

        When ``self.device`` is None the tensor stays wherever
        ``torch.as_tensor`` puts it.

        Raises:
            AssertionError: if ``obs`` is not a numpy array.
        """
        assert isinstance(obs, np.ndarray)
        o = torch.as_tensor(obs)
        if self.device is not None:
            o = o.to(self.device)
        return o

    def num_trainable_parameters(self) -> int:
        """Return the total element count of parameters with ``requires_grad``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_parameters(self) -> int:
        """Return the total element count of all parameters."""
        return sum(p.numel() for p in self.parameters())

    def sync_normalization(self, destination_env) -> None:
        """Copy this policy's normalization statistics onto ``destination_env``.

        Walks the wrapper chain of ``destination_env`` from the outermost
        wrapper inward. Each VecNormalize / NormalizeObservation /
        NormalizeReward wrapper found has its running statistics replaced by a
        deep copy of the matching statistics held by this policy.

        Raises:
            AssertionError: if ``destination_env`` has a normalization wrapper
                this policy does not have.
            AttributeError: if a wrapper exposes neither ``venv`` nor ``env``
                to descend into.
        """
        current = destination_env
        # Identity check: stop once the innermost (unwrapped) env is reached.
        while current is not current.unwrapped:
            if isinstance(current, VecNormalize):
                assert self.vec_normalize
                current.ret_rms = deepcopy(self.vec_normalize.ret_rms)
                if hasattr(self.vec_normalize, "obs_rms"):
                    current.obs_rms = deepcopy(self.vec_normalize.obs_rms)
            elif isinstance(current, NormalizeObservation):
                assert self.norm_observation
                current.rms = deepcopy(self.norm_observation.rms)
            elif isinstance(current, NormalizeReward):
                assert self.norm_reward
                current.rms = deepcopy(self.norm_reward.rms)
            # VecEnv wrappers expose the inner env as ``venv``, gym-style
            # wrappers as ``env``.
            inner = getattr(current, "venv", getattr(current, "env", None))
            if inner is None:
                # Falling back to ``current`` itself would never advance the
                # loop, so fail loudly and name the offending wrapper.
                raise AttributeError(
                    f"{type(current)} doesn't include env or venv attribute"
                )
            current = inner