import argparse
import random

import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray import air, tune
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_torch
from ray.tune import CLIReporter, register_env

from connectfour.training.callbacks import create_self_play_callback
from connectfour.training.dummy_policies import (
    AlwaysSameHeuristic,
    BeatLastHeuristic,
    LinearHeuristic,
    RandomHeuristic,
)
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env

torch, nn = try_import_torch()


def get_cli_args():
    """Create the CLI parser and return the parsed arguments.

    Example invocations:
        python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
        python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
        python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument("--num-gpus", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=10000000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--win-rate-threshold",
        type=float,
        default=0.95,
        help="Win-rate at which we set up another opponent by freezing the "
        "current main policy and playing against a uniform distribution "
        "of previously frozen 'main's from here on.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


def select_policy(agent_id, episode, **kwargs):
    """Assign the learned policy to one seat per episode (alternating on
    episode parity) and a randomly chosen heuristic opponent to the other."""
    if episode.episode_id % 2 == int(agent_id[-1:]):
        return "learned"
    else:
        return random.choice(["always_same", "beat_last", "random", "linear"])


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(
        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
    )

    # Define how to make the environment.
    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
    # Register the wrapped environment under an RLlib name.
    register_env("connect4", lambda config: Connect4Env(env_creator(config)))

    config = (
        ppo.PPOConfig()
        .environment("connect4")
        .framework("torch")
        .training(model={"custom_model": Connect4MaskModel})
        .callbacks(
            create_self_play_callback(
                win_rate_thr=args.win_rate_threshold,
                opponent_policies=["always_same", "beat_last", "random", "linear"],
            )
        )
        .rollouts(
            num_rollout_workers=args.num_workers,
            num_envs_per_worker=5,
        )
        .multi_agent(
            policies={
                "learned": PolicySpec(),
                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                "linear": PolicySpec(policy_class=LinearHeuristic),
                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                "random": PolicySpec(policy_class=RandomHeuristic),
            },
            policy_mapping_fn=select_policy,
            policies_to_train=["learned"],
        )
    )

    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/learned": "reward",
                    "win_rate": "win_rate",
                    "league_size": "league_size",
                },
                sort_by_metric=True,
            ),
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
        ),
    ).fit()

    print("Best checkpoint", results.get_best_result().checkpoint)

    ray.shutdown()
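
    # The best checkpoint printed above can later be restored for evaluation.
    # A minimal sketch, assuming a Ray 2.x install where
    # ray.rllib.algorithms.algorithm.Algorithm.from_checkpoint is available
    # (not part of this script; shown here only as a hint):
    #
    #     from ray.rllib.algorithms.algorithm import Algorithm
    #
    #     algo = Algorithm.from_checkpoint(results.get_best_result().checkpoint)
    #     learned_policy = algo.get_policy("learned")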