import argparse
import random
import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray import air, tune
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_torch
from ray.tune import CLIReporter, register_env
from connectfour.training.callbacks import create_self_play_callback
from connectfour.training.dummy_policies import (
    AlwaysSameHeuristic,
    BeatLastHeuristic,
    LinearHeuristic,
    RandomHeuristic,
)
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env
torch, nn = try_import_torch()


def get_cli_args():
    """
    Create the CLI argument parser and return the parsed arguments.

    Example invocations:
    python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
    python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
    python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument("--num-gpus", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=10000000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--win-rate-threshold",
        type=float,
        default=0.95,
        help="Win-rate at which we set up another opponent by freezing the "
        "current main policy and playing against a uniform distribution "
        "of previously frozen 'main's from here on.",
    )
    args = parser.parse_args()
    print(f"Running with the following CLI args: {args}")
    return args
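

# Policy mapping: in each episode the learned policy takes one seat (chosen by
# episode-id parity) and the other seat is played by a randomly sampled
# heuristic opponent.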
def select_policy(agent_id, episode, **kwargs):
    if episode.episode_id % 2 == int(agent_id[-1:]):
        return "learned"
    else:
        return random.choice(["always_same", "beat_last", "random", "linear"])


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(
        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
    )

    # define how to build the PettingZoo Connect Four environment
    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
    # register the wrapped environment under a name RLlib can resolve
    register_env("connect4", lambda config: Connect4Env(env_creator(config)))
    config = (
        ppo.PPOConfig()
        .environment("connect4")
        .framework("torch")
        .training(model={"custom_model": Connect4MaskModel})
        .callbacks(
            create_self_play_callback(
                win_rate_thr=args.win_rate_threshold,
                opponent_policies=["always_same", "beat_last", "random", "linear"],
            )
        )
        .rollouts(
            num_rollout_workers=args.num_workers,
            num_envs_per_worker=5,
        )
        .multi_agent(
            policies={
                "learned": PolicySpec(),
                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                "linear": PolicySpec(policy_class=LinearHeuristic),
                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                "random": PolicySpec(policy_class=RandomHeuristic),
            },
            policy_mapping_fn=select_policy,
            policies_to_train=["learned"],
        )
    )
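
    # Stop criteria: end the run when either the timestep budget or the
    # iteration budget is reached, whichever comes first.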
    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }
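
    # Run training through Ray Tune. The CLIReporter columns surface the
    # self-play metrics ("win_rate", "league_size") reported by the callback,
    # and checkpoints are saved every 10 iterations plus once at the end.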
    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/learned": "reward",
                    "win_rate": "win_rate",
                    "league_size": "league_size",
                },
                sort_by_metric=True,
            ),
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
        ),
    ).fit()
print("Best checkpoint", results.get_best_result().checkpoint)
ray.shutdown()
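
    # A rough sketch of how the printed checkpoint could later be restored for
    # evaluation (exact restore API depends on the installed Ray version):
    #   from ray.rllib.algorithms.algorithm import Algorithm
    #   algo = Algorithm.from_checkpoint(results.get_best_result().checkpoint)
    #   learned_policy = algo.get_policy("learned")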