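"""Self-play PPO training for Connect Four with Ray RLlib.

The "learned" policy is trained against scripted heuristic opponents
(always_same, beat_last, random, linear). A self-play callback freezes a
copy of the learned policy and adds it to the opponent pool whenever the
win rate exceeds the configured threshold.
"""
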
import argparse
import random

import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray import air, tune
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_torch
from ray.tune import CLIReporter, register_env

from connectfour.training.callbacks import create_self_play_callback
from connectfour.training.dummy_policies import (
    AlwaysSameHeuristic,
    BeatLastHeuristic,
    LinearHeuristic,
    RandomHeuristic,
)
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env

torch, nn = try_import_torch()


def get_cli_args():
    """
    Create CLI parser and return parsed arguments

    python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
    python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
    python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument("--num-gpus", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=2)

    parser.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=10000000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--win-rate-threshold",
        type=float,
        default=0.95,
        help="Win-rate at which we setup another opponent by freezing the "
        "current main policy and playing against a uniform distribution "
        "of previously frozen 'main's from here on.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


def select_policy(agent_id, episode, **kwargs):
    """Alternate the learned policy between the two seats based on episode id;
    the opposing seat gets a randomly chosen heuristic opponent."""
    if episode.episode_id % 2 == int(agent_id[-1:]):
        return "learned"
    return random.choice(["always_same", "beat_last", "random", "linear"])


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(
        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
    )

    # create the raw PettingZoo Connect Four environment
    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")

    # register the RLlib-wrapped environment under the name "connect4"
    register_env("connect4", lambda config: Connect4Env(env_creator(config)))

    config = (
        ppo.PPOConfig()
        .environment("connect4")
        .framework("torch")
        # custom model that respects Connect Four's action mask
        .training(model={"custom_model": Connect4MaskModel})
        # self-play callback: once the learned policy clears the win-rate
        # threshold, freeze a copy of it and add it to the opponent pool
        .callbacks(
            create_self_play_callback(
                win_rate_thr=args.win_rate_threshold,
                opponent_policies=["always_same", "beat_last", "random", "linear"],
            )
        )
        .rollouts(
            num_rollout_workers=args.num_workers,
            num_envs_per_worker=5,
        )
        .multi_agent(
            policies={
                # the PPO-trained policy
                "learned": PolicySpec(),
                # scripted heuristic opponents
                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
                "linear": PolicySpec(policy_class=LinearHeuristic),
                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
                "random": PolicySpec(policy_class=RandomHeuristic),
            },
            policy_mapping_fn=select_policy,
            # only the learned policy is optimized; the heuristics stay fixed
            policies_to_train=["learned"],
        )
    )

    # stop when either the timestep or the iteration limit is reached
    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/learned": "reward",
                    "win_rate": "win_rate",
                    "league_size": "league_size",
                },
                sort_by_metric=True,
            ),
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
        ),
    ).fit()

    print("Best checkpoint", results.get_best_result().checkpoint)

    ray.shutdown()