# Self-play PPO training for Connect Four (PettingZoo) using Ray RLlib / Tune.
import argparse
import random
import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray import air, tune
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_torch
from ray.tune import CLIReporter, register_env
from connectfour.training.callbacks import create_self_play_callback
from connectfour.training.dummy_policies import (
AlwaysSameHeuristic,
BeatLastHeuristic,
LinearHeuristic,
RandomHeuristic,
)
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env
torch, nn = try_import_torch()
def get_cli_args():
    """Build the command-line parser and return the parsed arguments.

    Example invocations:
    python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
    python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
    python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
    """
    cli = argparse.ArgumentParser()
    # Resource knobs: 0 CPUs means "let Ray auto-detect" (see ray.init below).
    cli.add_argument("--num-cpus", type=int, default=0)
    cli.add_argument("--num-gpus", type=int, default=0)
    cli.add_argument("--num-workers", type=int, default=2)
    # Stopping criteria for the Tune run.
    cli.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    cli.add_argument(
        "--stop-timesteps",
        type=int,
        default=10000000,
        help="Number of timesteps to train.",
    )
    # Self-play trigger: once "learned" beats this win rate, a new opponent
    # snapshot is introduced by the self-play callback.
    cli.add_argument(
        "--win-rate-threshold",
        type=float,
        default=0.95,
        help="Win-rate at which we setup another opponent by freezing the "
        "current main policy and playing against a uniform distribution "
        "of previously frozen 'main's from here on.",
    )
    parsed = cli.parse_args()
    print(f"Running with following CLI args: {parsed}")
    return parsed
def select_policy(agent_id, episode, **kwargs):
    """Map an agent to a policy id for the current episode.

    The parity of the episode id against the trailing digit of the agent id
    decides which seat the trainable "learned" policy occupies, so it gets to
    play both first and second across episodes; the opposing seat is filled by
    a uniformly sampled heuristic opponent.
    """
    seat = int(agent_id[-1:])
    if episode.episode_id % 2 != seat:
        # Opponent seat: pick one of the frozen heuristic policies at random.
        return random.choice(["always_same", "beat_last", "random", "linear"])
    return "learned"
if __name__ == "__main__":
    args = get_cli_args()
    # Start Ray locally; `num_cpus or None` lets Ray auto-detect when --num-cpus is 0.
    ray.init(
        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
    )
    # define how to make the environment
    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
    # register that way to make the environment under an rllib name
    register_env("connect4", lambda config: Connect4Env(env_creator(config)))
    # PPO config: torch framework, action-masking model, and a self-play
    # callback that adds a new opponent once the win-rate threshold is reached
    # (per the --win-rate-threshold help text).
    config = (
        ppo.PPOConfig()
        .environment("connect4")
        .framework("torch")
        .training(model={"custom_model": Connect4MaskModel})
        .callbacks(
            create_self_play_callback(
                win_rate_thr=args.win_rate_threshold,
                opponent_policies=["always_same", "beat_last", "random", "linear"],
            )
        )
        .rollouts(
            num_rollout_workers=args.num_workers,
            num_envs_per_worker=5,
        )
        .multi_agent(
            # One trainable policy plus four fixed heuristic opponents.
            policies={
                "learned": PolicySpec(),
                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                "linear": PolicySpec(policy_class=LinearHeuristic),
                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                "random": PolicySpec(policy_class=RandomHeuristic),
            },
            # select_policy alternates which seat "learned" occupies per episode.
            policy_mapping_fn=select_policy,
            # Only "learned" receives gradient updates; the heuristics stay fixed.
            policies_to_train=["learned"],
        )
    )
    # Stop on whichever criterion is met first: total timesteps or iterations.
    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }
    # Run PPO under Tune; checkpoint every 10 iterations and once at the end.
    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=2,
            progress_reporter=CLIReporter(
                # Map result keys to short column headers for console output.
                # win_rate / league_size are presumably emitted by the
                # self-play callback — confirm against callbacks.py.
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/learned": "reward",
                    "win_rate": "win_rate",
                    "league_size": "league_size",
                },
                sort_by_metric=True,
            ),
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
        ),
    ).fit()
    print("Best checkpoint", results.get_best_result().checkpoint)
    ray.shutdown()