import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, NamedTuple, Optional, TypedDict, Union


@dataclass
class RunArgs:
    """Arguments that identify a run.

    Attributes:
        algo: Algorithm name; used as the first component of model names.
        env: Environment name as given by the caller; used in model names even
            when ``hyperparams["env_id"]`` overrides the actual environment id.
        seed: Optional RNG seed; ``None`` means no seed is applied.
        use_deterministic_algorithms: Flag carried for consumers elsewhere;
            this module does not read it.
    """

    algo: str
    env: str
    seed: Optional[int] = None
    use_deterministic_algorithms: bool = True


class EnvHyperparams(NamedTuple):
    """Typed view of environment settings with their default values.

    The fields are interpreted by environment-building code outside this module.
    ``Config`` reads the raw ``env_hyperparams`` dict rather than this tuple, but
    its fallback for ``n_envs`` (1) matches the default declared here.
    """

    is_procgen: bool = False
    n_envs: int = 1
    frame_stack: int = 1
    make_kwargs: Optional[Dict[str, Any]] = None
    no_reward_timeout_steps: Optional[int] = None
    no_reward_fire_steps: Optional[int] = None
    vec_env_class: str = "dummy"
    normalize: bool = False
    normalize_kwargs: Optional[Dict[str, Any]] = None
    rolling_length: int = 100
    train_record_video: bool = False
    video_step_interval: Union[int, float] = 1_000_000
    initial_steps_to_truncate: Optional[int] = None


class Hyperparams(TypedDict, total=False):
    """Raw hyperparameter mapping; every key is optional (``total=False``)."""

    device: str
    n_timesteps: Union[int, float]
    # Optional override of the environment id. When present, Config.env_id
    # returns it and Config.model_name skips the make_kwargs suffixes.
    env_id: str
    env_hyperparams: Dict[str, Any]
    policy_hyperparams: Dict[str, Any]
    algo_hyperparams: Dict[str, Any]
    eval_params: Dict[str, Any]


@dataclass
class Config:
    """Run configuration plus the names and filesystem paths derived from it.

    Attributes:
        args: Identifying run arguments (algo, env, seed).
        hyperparams: Raw hyperparameter mapping; missing keys fall back to the
            defaults in the accessor properties below.
        root_dir: Base directory for saved/downloaded models, runs and videos.
        run_id: Suffix appended to the model name to form ``run_name``.
    """

    args: RunArgs
    hyperparams: Hyperparams
    root_dir: str
    # NOTE(review): this default is evaluated once, when the class body runs at
    # import time, so every Config created in the same process without an
    # explicit run_id shares the same timestamp. Pass run_id explicitly if
    # separate runs in one process need distinct run names.
    run_id: str = datetime.now().isoformat()

    def seed(self, training: bool = True) -> Optional[int]:
        """Return the seed to use for training or evaluation.

        Training returns ``args.seed`` unchanged. Evaluation offsets it by
        ``env_hyperparams["n_envs"]`` (default 1), presumably to keep evaluation
        seeds apart from the per-env training seeds. ``None`` is returned as-is.
        """
        base_seed = self.args.seed
        if training or base_seed is None:
            return base_seed
        return base_seed + self.env_hyperparams.get("n_envs", 1)

    @property
    def device(self) -> str:
        """Requested device string; defaults to ``"auto"``."""
        return self.hyperparams.get("device", "auto")

    @property
    def n_timesteps(self) -> int:
        """Total timesteps as an int (floats such as ``1e6`` are truncated)."""
        return int(self.hyperparams.get("n_timesteps", 100_000))

    @property
    def env_hyperparams(self) -> Dict[str, Any]:
        """Environment hyperparameters dict; empty if not configured."""
        return self.hyperparams.get("env_hyperparams", {})

    @property
    def policy_hyperparams(self) -> Dict[str, Any]:
        """Policy hyperparameters dict; empty if not configured."""
        return self.hyperparams.get("policy_hyperparams", {})

    @property
    def algo_hyperparams(self) -> Dict[str, Any]:
        """Algorithm hyperparameters dict; empty if not configured."""
        return self.hyperparams.get("algo_hyperparams", {})

    @property
    def eval_params(self) -> Dict[str, Any]:
        """Evaluation parameters dict; empty if not configured."""
        return self.hyperparams.get("eval_params", {})

    @property
    def algo(self) -> str:
        """Algorithm name from the run arguments."""
        return self.args.algo

    @property
    def env_id(self) -> str:
        """Actual environment id: ``hyperparams["env_id"]`` if set and non-empty, else ``args.env``."""
        return self.hyperparams.get("env_id") or self.args.env

    def model_name(self, include_seed: bool = True) -> str:
        """Build the model identifier ``{algo}-{args.env}[-S{seed}][-suffix...]``.

        The CLI env name (``args.env``) is used rather than :attr:`env_id`. When
        ``hyperparams`` has no ``env_id`` override, one suffix is appended per
        ``make_kwargs`` entry: the key for ``True`` booleans, ``{key}{value}``
        for non-zero ints, and ``str(value)`` for everything else (including
        ``False`` and ``0``).

        Args:
            include_seed: Append ``S{seed}`` when ``args.seed`` is not None.
        """
        parts = [self.algo, self.args.env]
        if include_seed and self.args.seed is not None:
            parts.append(f"S{self.args.seed}")

        # A custom env_id override is assumed to already encode its settings in
        # the arg env name, so no make_kwargs suffixes are added.
        if not self.hyperparams.get("env_id"):
            make_kwargs = self.env_hyperparams.get("make_kwargs") or {}
            for k, v in make_kwargs.items():
                # Exact type checks (not isinstance) keep bools out of the int
                # branch and send int subclasses to str(v). This is preserved
                # because model names feed saved-model and video paths.
                if type(v) is bool and v:
                    parts.append(k)
                elif type(v) is int and v:
                    parts.append(f"{k}{v}")
                else:
                    parts.append(str(v))
        return "-".join(parts)

    @property
    def run_name(self) -> str:
        """Model name (with seed) joined with ``run_id``."""
        return "-".join([self.model_name(), self.run_id])

    @property
    def saved_models_dir(self) -> str:
        """``<root_dir>/saved_models``."""
        return os.path.join(self.root_dir, "saved_models")

    @property
    def downloaded_models_dir(self) -> str:
        """``<root_dir>/downloaded_models``."""
        return os.path.join(self.root_dir, "downloaded_models")

    def model_dir_name(
        self,
        best: bool = False,
        extension: str = "",
    ) -> str:
        """Model directory/file name: model name, optional ``-best``, then ``extension``."""
        return self.model_name() + ("-best" if best else "") + extension

    def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
        """Path of the model directory under the saved or downloaded models dir."""
        return os.path.join(
            self.saved_models_dir if not downloaded else self.downloaded_models_dir,
            self.model_dir_name(best=best),
        )

    @property
    def runs_dir(self) -> str:
        """``<root_dir>/runs``."""
        return os.path.join(self.root_dir, "runs")

    @property
    def tensorboard_summary_path(self) -> str:
        """Per-run summary directory: ``<runs_dir>/<run_name>``."""
        return os.path.join(self.runs_dir, self.run_name)

    @property
    def logs_path(self) -> str:
        """Log file path ``<runs_dir>/log.yml``; the same file for every run."""
        return os.path.join(self.runs_dir, "log.yml")

    @property
    def videos_dir(self) -> str:
        """``<root_dir>/videos``."""
        return os.path.join(self.root_dir, "videos")

    @property
    def video_prefix(self) -> str:
        """Path prefix for videos of this model: ``<videos_dir>/<model_name>``."""
        return os.path.join(self.videos_dir, self.model_name())

    @property
    def best_videos_dir(self) -> str:
        """Directory for videos of the best model: ``<videos_dir>/<model_name>-best``."""
        return os.path.join(self.videos_dir, f"{self.model_name()}-best")