File size: 5,752 Bytes
b18ddcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import dataclasses
import inspect
import itertools
import os

from datetime import datetime
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, TypeVar, Union


RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")


@dataclass
class RunArgs:
    """Arguments identifying a single run: algorithm, environment and seed."""

    algo: str
    env: str
    seed: Optional[int] = None
    use_deterministic_algorithms: bool = True

    @classmethod
    def expand_from_dict(
        cls: Type[RunArgsSelf], d: Dict[str, Any]
    ) -> List[RunArgsSelf]:
        """Build one instance per (algo, env, seed) combination described by ``d``.

        ``d["algo"]``, ``d["env"]`` and ``d["seed"]`` may each be a single value
        or an iterable of values; the cartesian product of the three is
        expanded. Every other key in ``d`` is forwarded unchanged to the
        constructor of each instance.

        Args:
            d: Mapping of constructor keyword arguments. ``"algo"`` and
                ``"env"`` are required. ``"seed"`` may be omitted or ``None``,
                in which case every produced instance has ``seed=None``.

        Returns:
            The expanded instances, ordered as ``itertools.product`` yields
            them (seed varies fastest, algo slowest).

        Raises:
            KeyError: If ``"algo"`` or ``"env"`` is missing from ``d``.
        """

        def as_list(value: Any) -> Any:
            # Scalars become a one-element list. None counts as a scalar so an
            # unset seed yields a single run instead of failing inside product().
            if value is None or isinstance(value, (str, int)):
                return [value]
            return value

        algos = as_list(d["algo"])
        envs = as_list(d["env"])
        # seed is optional on the dataclass, so it is optional here as well.
        seeds = as_list(d.get("seed"))
        return [
            cls(**{**d, "algo": algo, "env": env, "seed": seed})
            for algo, env, seed in itertools.product(algos, envs, seeds)
        ]


@dataclass
class EnvHyperparams:
    """Environment-related hyperparameters and their default values.

    Every field has a default, so ``EnvHyperparams()`` is a valid instance.
    The code that consumes these fields is not visible here.
    NOTE(review): field meanings are only suggested by their names — confirm
    against the environment-construction code before relying on them.
    """

    env_type: str = "gymvec"
    n_envs: int = 1
    frame_stack: int = 1
    # None means no keyword arguments were supplied.
    make_kwargs: Optional[Dict[str, Any]] = None
    no_reward_timeout_steps: Optional[int] = None
    no_reward_fire_steps: Optional[int] = None
    vec_env_class: str = "sync"
    normalize: bool = False
    normalize_kwargs: Optional[Dict[str, Any]] = None
    rolling_length: int = 100
    train_record_video: bool = False
    # Typed as int or float; presumably so values like 1e6 are accepted — TODO confirm.
    video_step_interval: Union[int, float] = 1_000_000
    initial_steps_to_truncate: Optional[int] = None
    clip_atari_rewards: bool = True


HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")


@dataclass
class Hyperparams:
    """Top-level hyperparameters for a run, grouped into per-area dictionaries."""

    device: str = "auto"
    # Typed as int or float; no conversion to int happens in this class.
    n_timesteps: Union[int, float] = 100_000
    # default_factory gives each instance its own dict instead of a shared one.
    env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
    policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
    algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
    eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
    env_id: Optional[str] = None

    @classmethod
    def from_dict_with_extra_fields(
        cls: Type[HyperparamsSelf], d: Dict[str, Any]
    ) -> HyperparamsSelf:
        """Construct an instance from ``d``, ignoring keys the constructor rejects.

        Args:
            d: Mapping that may contain keys which are not constructor
                parameters of ``cls``; those keys are silently dropped.

        Returns:
            An instance built from the recognised keys only. Parameters missing
            from ``d`` take their declared defaults.
        """
        # The signature does not change per key, so look it up once.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in d.items() if k in accepted})


@dataclass
class Config:
    """Bundle of run arguments, hyperparameters and the root output directory.

    Exposes shortcuts into ``args`` / ``hyperparams`` and derives the names and
    filesystem paths (models, runs, videos) used for a run.
    """

    args: RunArgs
    hyperparams: Hyperparams
    root_dir: str
    # NOTE(review): this default is evaluated once, when the class is defined,
    # so every Config created in the same process without an explicit run_id
    # shares one timestamp. Confirm that is intended before changing it.
    run_id: str = datetime.now().isoformat()

    def seed(self, training: bool = True) -> Optional[int]:
        """Return the seed to use, or ``None`` if no seed was given.

        For training the configured seed is returned unchanged. Otherwise the
        seed is offset by ``n_envs`` from the env hyperparameters (1 if unset).
        """
        base_seed = self.args.seed
        if base_seed is None or training:
            return base_seed
        n_envs = self.env_hyperparams.get("n_envs", 1)
        return base_seed + n_envs

    @property
    def device(self) -> str:
        """Device string taken from the hyperparameters."""
        return self.hyperparams.device

    @property
    def n_timesteps(self) -> int:
        """Number of timesteps from the hyperparameters, converted to ``int``."""
        return int(self.hyperparams.n_timesteps)

    @property
    def env_hyperparams(self) -> Dict[str, Any]:
        """Environment hyperparameter dictionary."""
        return self.hyperparams.env_hyperparams

    @property
    def policy_hyperparams(self) -> Dict[str, Any]:
        """Policy hyperparameter dictionary."""
        return self.hyperparams.policy_hyperparams

    @property
    def algo_hyperparams(self) -> Dict[str, Any]:
        """Algorithm hyperparameter dictionary."""
        return self.hyperparams.algo_hyperparams

    @property
    def eval_params(self) -> Dict[str, Any]:
        """Evaluation parameter dictionary."""
        return self.hyperparams.eval_params

    @property
    def algo(self) -> str:
        """Algorithm name from the run arguments."""
        return self.args.algo

    @property
    def env_id(self) -> str:
        """Hyperparameter ``env_id`` if set, otherwise the ``env`` run argument."""
        return self.hyperparams.env_id or self.args.env

    def model_name(self, include_seed: bool = True) -> str:
        """Return the dash-joined model name.

        The name starts with the algorithm and the ``env`` run argument (not
        ``env_id``), optionally followed by ``S<seed>``. When the
        hyperparameters do not set ``env_id``, one part per ``make_kwargs``
        entry is appended: the key for a true bool, ``<key><value>`` for a
        non-zero int, and ``str(value)`` for anything else.
        """
        name_parts = [self.algo, self.args.env]
        run_seed = self.args.seed
        if include_seed and run_seed is not None:
            name_parts.append(f"S{run_seed}")

        # A custom env_id means the env argument is assumed to already carry
        # whatever make_kwargs would have added.
        if self.hyperparams.env_id:
            return "-".join(name_parts)

        extra_kwargs = self.env_hyperparams.get("make_kwargs", {})
        for key, value in (extra_kwargs or {}).items():
            # Exact type checks: bool must not be treated as an int here.
            if type(value) is bool and value:
                token = key
            elif type(value) is int and value:
                token = f"{key}{value}"
            else:
                token = str(value)
            name_parts.append(token)

        return "-".join(name_parts)

    def run_name(self, include_seed: bool = True) -> str:
        """Return the model name followed by the run id, dash-joined."""
        return "-".join((self.model_name(include_seed=include_seed), self.run_id))

    @property
    def saved_models_dir(self) -> str:
        """``<root_dir>/saved_models``."""
        return os.path.join(self.root_dir, "saved_models")

    @property
    def downloaded_models_dir(self) -> str:
        """``<root_dir>/downloaded_models``."""
        return os.path.join(self.root_dir, "downloaded_models")

    def model_dir_name(
        self,
        best: bool = False,
        extension: str = "",
    ) -> str:
        """Return the model name plus ``-best`` if ``best``, then ``extension``."""
        best_suffix = "-best" if best else ""
        return self.model_name() + best_suffix + extension

    def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
        """Return the model directory under the saved or downloaded models dir."""
        if downloaded:
            parent_dir = self.downloaded_models_dir
        else:
            parent_dir = self.saved_models_dir
        return os.path.join(parent_dir, self.model_dir_name(best=best))

    @property
    def runs_dir(self) -> str:
        """``<root_dir>/runs``."""
        return os.path.join(self.root_dir, "runs")

    @property
    def tensorboard_summary_path(self) -> str:
        """``<runs_dir>/<run_name>``."""
        return os.path.join(self.runs_dir, self.run_name())

    @property
    def logs_path(self) -> str:
        """``<runs_dir>/log.yml``."""
        return os.path.join(self.runs_dir, "log.yml")

    @property
    def videos_dir(self) -> str:
        """``<root_dir>/videos``."""
        return os.path.join(self.root_dir, "videos")

    @property
    def video_prefix(self) -> str:
        """``<videos_dir>/<model_name>``."""
        return os.path.join(self.videos_dir, self.model_name())

    @property
    def best_videos_dir(self) -> str:
        """``<videos_dir>/<model_name>-best``."""
        return os.path.join(self.videos_dir, f"{self.model_name()}-best")