File size: 1,740 Bytes
68e589c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import os
import torch
import torch.nn as nn

from abc import ABC, abstractmethod
from stable_baselines3.common.vec_env import unwrap_vec_normalize
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from typing import Dict, Optional, Type, TypeVar, Union

# Lookup table from an activation name to the torch module class it selects.
ACTIVATION: Dict[str, Type[nn.Module]] = dict(tanh=nn.Tanh, relu=nn.ReLU)

# File names written inside a policy's save directory.
VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
MODEL_FILENAME = "model.pth"

# Type variable bound to Policy, for annotating methods that hand back the
# very instance (and therefore the concrete subclass) they were called on.
PolicySelf = TypeVar("PolicySelf", bound="Policy")


class Policy(nn.Module, ABC):
    """Abstract base class for torch policies bound to a vectorized env.

    Keeps a reference to the environment, to whatever
    ``unwrap_vec_normalize(env)`` returns for it, and to the device most
    recently passed to :meth:`to`. Concrete subclasses must implement
    ``__init__``, :meth:`act` and :meth:`load`.
    """

    @abstractmethod
    def __init__(self, env: VecEnv, **kwargs) -> None:
        """Initialise the module and record the environment.

        Args:
            env: Vectorized environment the policy acts in.
            **kwargs: Accepted for subclass signatures; not used here.
        """
        super().__init__()
        self.env = env
        self.vec_normalize = unwrap_vec_normalize(env)
        # No device is known until ``to`` is called with one.
        self.device = None

    def to(
        self: PolicySelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> PolicySelf:
        """Move/cast the module and remember the target device.

        Args:
            device: Device to move parameters and buffers to, if any.
            dtype: Floating point dtype to cast to, if any.
            non_blocking: Forwarded to ``nn.Module.to``.

        Returns:
            This policy, so calls can be chained.
        """
        super().to(device, dtype, non_blocking)
        # A dtype-only call (device=None) does not move the module, so it must
        # not erase the device recorded earlier; ``load`` relies on it for
        # ``map_location``.
        if device is not None:
            self.device = device
        return self

    @abstractmethod
    def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
        """Return actions for ``obs``; implemented by subclasses.

        Args:
            obs: Observations from the vectorized environment.
            deterministic: Flag passed by callers; its handling is defined by
                the subclass.
        """
        ...

    def save(self, path: str) -> None:
        """Write the policy to the directory ``path``.

        The directory is created if missing. The module's ``state_dict`` is
        stored as ``MODEL_FILENAME``; if a ``vec_normalize`` object is present
        its ``save`` method is called with the ``VEC_NORMALIZE_FILENAME`` path.

        Args:
            path: Directory to write into.
        """
        os.makedirs(path, exist_ok=True)

        if self.vec_normalize is not None:
            self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
        torch.save(
            self.state_dict(),
            os.path.join(path, MODEL_FILENAME),
        )

    @abstractmethod
    def load(self, path: str) -> None:
        """Load the ``state_dict`` stored as ``MODEL_FILENAME`` under ``path``.

        Declared abstract, so subclasses must override it; this body remains
        callable through ``super().load(path)``. Tensors are mapped onto
        ``self.device`` (``None`` leaves torch's default placement).

        Args:
            path: Directory previously written by :meth:`save`.
        """
        # VecNormalize load occurs in env.py
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        self.load_state_dict(
            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
        )

    def reset_noise(self) -> None:
        """Hook with no effect in the base class; subclasses may override."""
        pass