sgoodfriend's picture
DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
3fd02ed
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os

# NOTE: keep this assignment above every import that may (transitively) load
# torch -- including the rl_algo_impls imports below -- since the MPS fallback
# flag is expected to be read when torch initializes.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import dataclasses
import shutil
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Sequence

import yaml
from torch.utils.tensorboard.writer import SummaryWriter
import wandb

from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
from rl_algo_impls.runner.running_utils import (
    ALGOS,
    get_device,
    hparam_dict,
    load_hyperparams,
    make_policy,
    plot_eval_callback,
    set_seeds,
)
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
    MicrortsRewardDecayCallback,
)
from rl_algo_impls.shared.stats import EpisodesStats
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
@dataclass
class TrainArgs(RunArgs):
    """Arguments for a training run: ``RunArgs`` plus Weights & Biases settings.

    ``train`` enables Weights & Biases logging only when ``wandb_project_name``
    is a non-empty string; the other ``wandb_*`` fields are forwarded to
    ``wandb.init``.
    """

    # wandb project name; None/empty disables wandb logging in train().
    wandb_project_name: Optional[str] = None
    # Passed to wandb.init as `entity`.
    wandb_entity: Optional[str] = None
    # Passed to wandb.init as `tags`; default_factory avoids a shared mutable default.
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
    # Passed to wandb.init as `group`.
    wandb_group: Optional[str] = None
def train(args: TrainArgs):
    """Train ``args.algo`` on ``args.env`` and record the results.

    Steps:
      1. Load the hyperparameters for the algo/env pair and build the run
         ``Config`` rooted at the current working directory.
      2. If ``args.wandb_project_name`` is non-empty, patch TensorBoard for
         wandb and start a wandb run.
      3. Build the training env, the policy and the algorithm, plus an
         ``EvalCallback`` (and, when configured/detected, MicroRTS
         reward-decay and self-play callbacks).
      4. Run ``algo.learn`` for ``config.n_timesteps`` and save the final
         model.
      5. Run a final 10-episode evaluation and record the results in the YAML
         log file, TensorBoard hparams and, when wandb is enabled, a zipped
         copy of the model directory in the wandb run directory.

    Args:
        args: Training run arguments.
    """
    print(args)
    hyperparams = load_hyperparams(args.algo, args.env)
    print(hyperparams)
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = bool(args.wandb_project_name)
    if wandb_enabled:
        # Patched before wandb.init so TensorBoard output under the summary
        # path is hooked into the wandb run.
        wandb.tensorboard.patch(
            root_logdir=config.tensorboard_summary_path, pytorch=True
        )
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            config=asdict(hyperparams),
            name=config.run_name(),
            monitor_gym=True,
            save_code=True,
            tags=args.wandb_tags,
            group=args.wandb_group,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config, env)

    def policy_factory():
        """Build a fresh policy for this run's algo, env, device and policy
        hyperparameters (also handed to ``SelfPlayCallback`` below)."""
        return make_policy(args.algo, env, device, **config.policy_hyperparams)

    policy = policy_factory()
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    num_parameters = policy.num_parameters()
    num_trainable_parameters = policy.num_trainable_parameters()
    if wandb_enabled:
        wandb.run.summary["num_parameters"] = num_parameters  # type: ignore
        wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters  # type: ignore
    else:
        print(
            f"num_parameters = {num_parameters} ; "
            f"num_trainable_parameters = {num_trainable_parameters}"
        )

    eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
    record_best_videos = config.eval_hyperparams.get("record_best_videos", True)
    eval_callback = EvalCallback(
        policy,
        eval_env,
        tb_writer,
        best_model_path=config.model_dir_path(best=True),
        **config.eval_callback_params(),
        # Separate single-env eval env for best-model videos; only built when
        # record_best_videos is enabled (default True).
        video_env=make_eval_env(
            config,
            EnvHyperparams(**config.env_hyperparams),
            override_hparams={"n_envs": 1},
        )
        if record_best_videos
        else None,
        best_video_dir=config.best_videos_dir,
        additional_keys_to_log=config.additional_keys_to_log,
        wandb_enabled=wandb_enabled,
    )

    callbacks: List[Callback] = [eval_callback]
    if config.hyperparams.microrts_reward_decay_callback:
        callbacks.append(MicrortsRewardDecayCallback(config, env))
    # Self-play is switched on implicitly when the env stack contains a
    # SelfPlayWrapper.
    self_play_wrapper = find_wrapper(env, SelfPlayWrapper)
    if self_play_wrapper:
        callbacks.append(SelfPlayCallback(policy, policy_factory, self_play_wrapper))

    algo.learn(config.n_timesteps, callbacks=callbacks)

    # Final (not necessarily best) model.
    policy.save(config.model_dir_path(best=False))

    eval_stats = eval_callback.evaluate(n_episodes=10, print_returns=True)

    plot_eval_callback(eval_callback, tb_writer, config.run_name())

    best = eval_callback.best
    log_dict: Dict[str, Any] = {
        "eval": eval_stats._asdict(),
    }
    if best:
        log_dict["best_eval"] = best._asdict()
    log_dict.update(asdict(hyperparams))
    # Run args are applied last, so they win on any key collision.
    log_dict.update(vars(args))
    with open(config.logs_path, "a") as f:
        yaml.dump({config.run_name(): log_dict}, f)

    # The callback may not have a recorded best evaluation (guarded above for
    # the YAML log); fall back to the final evaluation rather than failing
    # with an AttributeError on None.
    best_eval_stats: EpisodesStats = best if best else eval_stats
    tb_writer.add_hparams(
        hparam_dict(hyperparams, vars(args)),
        {
            "hparam/best_mean": best_eval_stats.score.mean,
            "hparam/best_result": best_eval_stats.score.mean
            - best_eval_stats.score.std,
            "hparam/last_mean": eval_stats.score.mean,
            "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
        },
        None,
        config.run_name(),
    )

    tb_writer.close()

    if wandb_enabled:
        # Upload the model directory as a zip via the wandb run directory.
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name()),  # type: ignore
            "zip",
            config.model_dir_path(),
        )
        wandb.finish()