File size: 3,965 Bytes
ff8c6a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import dataclasses
import shutil
import wandb
import yaml
from dataclasses import dataclass
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Any, Dict, Optional, Sequence
from shared.callbacks.eval_callback import EvalCallback
from runner.env import make_env, make_eval_env
from runner.config import Config, RunArgs
from runner.running_utils import (
ALGOS,
load_hyperparams,
set_seeds,
get_device,
make_policy,
plot_eval_callback,
flatten_hyperparameters,
)
from shared.stats import EpisodesStats
@dataclasses.dataclass
class TrainArgs(RunArgs):
    """Arguments for ``train``: every ``RunArgs`` field plus Weights & Biases options."""

    # W&B project to log to; ``train`` only enables W&B when this is set (non-empty).
    wandb_project_name: Optional[str] = None
    # Forwarded to ``wandb.init`` as ``entity``.
    wandb_entity: Optional[str] = None
    # Forwarded to ``wandb.init`` as ``tags``; a fresh empty list per instance.
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
def _init_wandb(args: TrainArgs, config: Config, hyperparams) -> None:
    """Start a Weights & Biases run that mirrors this run's TensorBoard logs."""
    # Patch TensorBoard before any SummaryWriter is constructed so the writer
    # created later in ``train`` is picked up by the W&B integration.
    wandb.tensorboard.patch(
        root_logdir=config.tensorboard_summary_path, pytorch=True
    )
    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        config=hyperparams,  # type: ignore
        name=config.run_name,
        monitor_gym=True,
        save_code=True,
        tags=args.wandb_tags,
    )
    wandb.config.update(args)


def _hparam_metrics(
    eval_stats: EpisodesStats, best_eval_stats: Optional[EpisodesStats]
) -> Dict[str, Any]:
    """Build the TensorBoard ``add_hparams`` metric dict.

    ``*_mean`` is the mean episode score; ``*_result`` is the mean minus one
    standard deviation. The ``best_*`` metrics are omitted when no best
    evaluation was recorded, instead of dereferencing ``None``.
    """
    metrics: Dict[str, Any] = {}
    if best_eval_stats is not None:
        metrics["hparam/best_mean"] = best_eval_stats.score.mean
        metrics["hparam/best_result"] = (
            best_eval_stats.score.mean - best_eval_stats.score.std
        )
    metrics["hparam/last_mean"] = eval_stats.score.mean
    metrics["hparam/last_result"] = eval_stats.score.mean - eval_stats.score.std
    return metrics


def _archive_models_to_wandb(config: Config) -> None:
    """Zip the final and best model directories into the active W&B run directory."""
    shutil.make_archive(
        os.path.join(wandb.run.dir, config.model_dir_name()),
        "zip",
        config.model_dir_path(),
    )
    best_dir = config.model_dir_path(best=True)
    # NOTE(review): the best-model directory presumably only exists once the
    # eval callback has saved a best model; skip archiving it otherwise.
    if os.path.isdir(best_dir):
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
            "zip",
            best_dir,
        )


def train(args: TrainArgs):
    """Train ``args.algo`` on ``args.env`` and persist models, stats and logs.

    Hyperparameters are loaded for the algo/env pair relative to the current
    working directory. The function then does the following:

    - Builds the training env, policy and algorithm.
    - Trains for ``config.n_timesteps``. An ``EvalCallback`` saves the best
      model and, unless ``record_best_videos`` is false in the eval params,
      records videos of it.
    - Saves the final policy and evaluates it over 10 episodes.
    - Appends the results (plus hyperparameters and args) as YAML to
      ``config.logs_path``.
    - Writes the results as TensorBoard hparams.

    When ``args.wandb_project_name`` is set, TensorBoard logs are synced to
    Weights & Biases. Zipped model directories are written into the W&B run
    directory before the run is finished.

    Args:
        args: Run arguments; W&B is only used when ``wandb_project_name`` is
            non-empty.
    """
    print(args)
    root_dir = os.getcwd()
    hyperparams = load_hyperparams(args.algo, args.env, root_dir)
    print(hyperparams)
    config = Config(args, hyperparams, root_dir)

    wandb_enabled = bool(args.wandb_project_name)
    if wandb_enabled:
        _init_wandb(args, config, hyperparams)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)
    # Close the writer even if training raises so buffered events are flushed.
    try:
        set_seeds(args.seed, args.use_deterministic_algorithms)
        env = make_env(config, tb_writer=tb_writer, **config.env_hyperparams)
        device = get_device(config.device, env)
        policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
        algo = ALGOS[args.algo](
            policy, env, device, tb_writer, **config.algo_hyperparams
        )

        eval_env = make_eval_env(config, **config.env_hyperparams)
        record_best_videos = config.eval_params.get("record_best_videos", True)
        callback = EvalCallback(
            policy,
            eval_env,
            tb_writer,
            best_model_path=config.model_dir_path(best=True),
            **config.eval_params,
            # A separate single-env eval env is only built when videos are wanted.
            video_env=make_eval_env(
                config, override_n_envs=1, **config.env_hyperparams
            )
            if record_best_videos
            else None,
            best_video_dir=config.best_videos_dir,
        )
        algo.learn(config.n_timesteps, callback=callback)
        policy.save(config.model_dir_path(best=False))

        eval_stats = callback.evaluate(n_episodes=10, print_returns=True)
        plot_eval_callback(callback, tb_writer, config.run_name)

        # May be None when no best evaluation was recorded during training.
        best_eval_stats: Optional[EpisodesStats] = callback.best

        log_dict: Dict[str, Any] = {"eval": eval_stats._asdict()}
        if best_eval_stats is not None:
            log_dict["best_eval"] = best_eval_stats._asdict()
        log_dict.update(hyperparams)
        log_dict.update(vars(args))
        with open(config.logs_path, "a", encoding="utf-8") as f:
            yaml.dump({config.run_name: log_dict}, f)

        tb_writer.add_hparams(
            flatten_hyperparameters(hyperparams, vars(args)),
            _hparam_metrics(eval_stats, best_eval_stats),
            None,
            config.run_name,
        )
    finally:
        tb_writer.close()

    if wandb_enabled:
        _archive_models_to_wandb(config)
        wandb.finish()
|