File size: 3,965 Bytes
ff8c6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os

# Deliberately placed ahead of the remaining imports in this file.
# NOTE(review): presumably PyTorch must see this variable before it is first
# imported, so do not let import sorting move it below them -- TODO confirm.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import dataclasses
import shutil
import wandb
import yaml

from dataclasses import dataclass
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Any, Dict, Optional, Sequence

from shared.callbacks.eval_callback import EvalCallback
from runner.env import make_env, make_eval_env
from runner.config import Config, RunArgs
from runner.running_utils import (
    ALGOS,
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
    plot_eval_callback,
    flatten_hyperparameters,
)
from shared.stats import EpisodesStats


@dataclasses.dataclass
class TrainArgs(RunArgs):
    """Arguments for a training run: ``RunArgs`` plus Weights & Biases options."""

    # W&B project name; train() skips all W&B work when this is falsy.
    wandb_project_name: Optional[str] = None
    # Forwarded to wandb.init as ``entity``.
    wandb_entity: Optional[str] = None
    # Forwarded to wandb.init as ``tags``; each instance gets its own list.
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)


def train(args: TrainArgs) -> None:
    """Run one training job end to end and record its results.

    Steps, in order: load hyperparameters and build the run ``Config``;
    optionally start a Weights & Biases run; seed; build the training
    environment, policy and algorithm; train with an ``EvalCallback``; save
    the final policy; run a final 10-episode evaluation; append the results
    to the YAML log at ``config.logs_path``; write TensorBoard hparams; and,
    when W&B is enabled, zip both model directories into the W&B run
    directory before finishing the run.

    Args:
        args: Run arguments. W&B logging is enabled only when
            ``args.wandb_project_name`` is truthy.
    """
    print(args)
    hyperparams = load_hyperparams(args.algo, args.env, os.getcwd())
    print(hyperparams)
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = bool(args.wandb_project_name)
    if wandb_enabled:
        # The TensorBoard patch is applied before the SummaryWriter below is
        # created; keep this ordering.
        wandb.tensorboard.patch(
            root_logdir=config.tensorboard_summary_path, pytorch=True
        )
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            config=hyperparams,  # type: ignore
            name=config.run_name,
            monitor_gym=True,
            save_code=True,
            tags=args.wandb_tags,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(config, tb_writer=tb_writer, **config.env_hyperparams)
    device = get_device(config.device, env)
    policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    eval_env = make_eval_env(config, **config.env_hyperparams)
    # A separate single-env evaluation environment is only built when videos
    # of the best model are requested (default: True).
    record_best_videos = config.eval_params.get("record_best_videos", True)
    callback = EvalCallback(
        policy,
        eval_env,
        tb_writer,
        best_model_path=config.model_dir_path(best=True),
        **config.eval_params,
        video_env=make_eval_env(config, override_n_envs=1, **config.env_hyperparams)
        if record_best_videos
        else None,
        best_video_dir=config.best_videos_dir,
    )
    algo.learn(config.n_timesteps, callback=callback)

    policy.save(config.model_dir_path(best=False))

    eval_stats = callback.evaluate(n_episodes=10, print_returns=True)

    plot_eval_callback(callback, tb_writer, config.run_name)

    log_dict: Dict[str, Any] = {
        "eval": eval_stats._asdict(),
    }
    if callback.best:
        log_dict["best_eval"] = callback.best._asdict()
    log_dict.update(hyperparams)
    log_dict.update(vars(args))
    # Append mode: each run adds its own top-level entry keyed by run name.
    with open(config.logs_path, "a") as f:
        yaml.dump({config.run_name: log_dict}, f)

    # callback.best may be unset (the YAML logging above already allows for
    # that), so only report the "best" metrics when a best evaluation exists
    # instead of crashing here after training has completed.
    best_eval_stats: Optional[EpisodesStats] = callback.best
    hparam_metrics: Dict[str, Any] = {}
    if best_eval_stats is not None:
        hparam_metrics["hparam/best_mean"] = best_eval_stats.score.mean
        hparam_metrics["hparam/best_result"] = (
            best_eval_stats.score.mean - best_eval_stats.score.std
        )
    hparam_metrics["hparam/last_mean"] = eval_stats.score.mean
    hparam_metrics["hparam/last_result"] = (
        eval_stats.score.mean - eval_stats.score.std
    )
    tb_writer.add_hparams(
        flatten_hyperparameters(hyperparams, vars(args)),
        hparam_metrics,
        None,
        config.run_name,
    )

    tb_writer.close()

    if wandb_enabled:
        # Zip the last and the best model directories into the W&B run dir.
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name()),
            "zip",
            config.model_dir_path(),
        )
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
            "zip",
            config.model_dir_path(best=True),
        )
        wandb.finish()