# Self-play PPO training entry point for Connect Four (PettingZoo + Ray RLlib).
import argparse
import random

import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray import air, tune
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_torch
from ray.tune import CLIReporter, register_env

from connectfour.training.callbacks import create_self_play_callback
from connectfour.training.dummy_policies import (
    AlwaysSameHeuristic,
    BeatLastHeuristic,
    LinearHeuristic,
    RandomHeuristic,
)
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env

# Torch is imported through RLlib's helper so the script fails gracefully
# (returns None) when torch is not installed.
torch, nn = try_import_torch()
def get_cli_args():
    """Create the CLI parser and return the parsed arguments.

    Example invocations:
        python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
        python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
        python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200

    Returns:
        argparse.Namespace with num_cpus, num_gpus, num_workers, stop_iters,
        stop_timesteps and win_rate_threshold.
    """
    parser = argparse.ArgumentParser()
    # 0 means "let Ray auto-detect" (translated to None at ray.init time).
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument("--num-gpus", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=10_000_000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--win-rate-threshold",
        type=float,
        default=0.95,
        help="Win-rate at which we setup another opponent by freezing the "
        "current main policy and playing against a uniform distribution "
        "of previously frozen 'main's from here on.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args
def select_policy(agent_id, episode, **kwargs):
    """Map an agent to a policy for this episode.

    The trailing digit of the agent id (e.g. "player_0" / "player_1") is
    matched against the episode id's parity so that the learned policy
    alternates seats between episodes; the opponent seat gets a uniformly
    random heuristic policy.
    """
    if episode.episode_id % 2 == int(agent_id[-1:]):
        return "learned"
    # Opponent seat: pick one of the scripted heuristics at random.
    return random.choice(["always_same", "beat_last", "random", "linear"])
if __name__ == "__main__":
    args = get_cli_args()

    # num_cpus=0 from the CLI means "auto-detect" (None) for Ray.
    ray.init(
        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
    )

    # Define how to make the environment.
    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
    # Register it under an RLlib-resolvable name, wrapped for RLlib's API.
    register_env("connect4", lambda config: Connect4Env(env_creator(config)))

    config = (
        ppo.PPOConfig()
        .environment("connect4")
        .framework("torch")
        .training(model={"custom_model": Connect4MaskModel})
        # Self-play: once the learned policy beats the current opponents at
        # win_rate_thr, the callback freezes a snapshot as a new opponent.
        .callbacks(
            create_self_play_callback(
                win_rate_thr=args.win_rate_threshold,
                opponent_policies=["always_same", "beat_last", "random", "linear"],
            )
        )
        .rollouts(
            num_rollout_workers=args.num_workers,
            num_envs_per_worker=5,
        )
        .multi_agent(
            policies={
                # Our trainable policy; the rest are frozen scripted heuristics.
                "learned": PolicySpec(),
                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                "linear": PolicySpec(policy_class=LinearHeuristic),
                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                "random": PolicySpec(policy_class=RandomHeuristic),
            },
            policy_mapping_fn=select_policy,
            policies_to_train=["learned"],
        )
    )

    # Stop on whichever bound is hit first.
    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/learned": "reward",
                    "win_rate": "win_rate",
                    "league_size": "league_size",
                },
                sort_by_metric=True,
            ),
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
        ),
    ).fit()

    print("Best checkpoint", results.get_best_result().checkpoint)

    ray.shutdown()