from ray.rllib.algorithms.callbacks import DefaultCallbacks
import numpy as np


def create_self_play_callback(win_rate_thr, opponent_policies):
    """Build a DefaultCallbacks subclass implementing league-style self-play.

    Whenever the "learned" policy's win rate over a training iteration exceeds
    `win_rate_thr`, the current policy is snapshotted as a new frozen opponent
    ("learned_v<N>") and the policy mapping function is redefined so that
    "learned" keeps training against the growing pool of `opponent_policies`
    and earlier snapshots.
    """

    class SelfPlayCallback(DefaultCallbacks):
        win_rate_threshold = win_rate_thr

        def __init__(self):
            super().__init__()
            self.current_opponent = 0

        def on_train_result(self, *, algorithm, result, **kwargs):
            # Get the win rate for the train batch.
            # Note that normally one should set up a proper evaluation config so
            # that the win rate is measured with the freshly updated policy,
            # rather than estimated from the train batch that was just consumed.
            main_rew = result["hist_stats"].pop("policy_learned_reward")
            opponent_rew = result["hist_stats"].pop("episode_reward")

            if len(main_rew) != len(opponent_rew):
                raise ValueError(
                    "Histogram length mismatch: "
                    f"{len(main_rew)} learned-policy rewards vs. "
                    f"{len(opponent_rew)} episode rewards. "
                    f"Available hist_stats keys: {list(result['hist_stats'].keys())}"
                )

            # Count an episode as a win for "learned" whenever its reward exceeds
            # the total episode reward; with symmetric +1/-1 zero-sum rewards this
            # is equivalent to "learned" having won the game.
            won = 0
            for r_main, r_opponent in zip(main_rew, opponent_rew):
                if r_main > r_opponent:
                    won += 1
            win_rate = won / len(main_rew)

            result["win_rate"] = win_rate
            print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")

            # If win rate is good -> Snapshot current policy and play against
            # it next, keeping the snapshot fixed and only improving the "learned"
            # policy.
            if win_rate > self.win_rate_threshold:
                self.current_opponent += 1
                new_pol_id = f"learned_v{self.current_opponent}"
                print(
                    f"Iter={algorithm.iteration} ### Adding new opponent to the mix ({new_pol_id})."
                )

                # Re-define the mapping function so that "learned" always plays
                # against either one of the initial opponent policies or one of
                # its own frozen snapshots.
                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                    # agent_id = [0|1] -> policy depends on the episode ID.
                    # This way, both policies sometimes play agent0 (the start
                    # player) and sometimes agent1 (the player to move 2nd).
                    return (
                        "learned"
                        if episode.episode_id % 2 == int(agent_id[-1:])
                        else np.random.choice(
                            opponent_policies
                            + [
                                f"learned_v{i}"
                                for i in range(1, self.current_opponent + 1)
                            ]
                        )
                    )

                new_policy = algorithm.add_policy(
                    policy_id=new_pol_id,
                    policy_cls=type(algorithm.get_policy("learned")),
                    policy_mapping_fn=policy_mapping_fn,
                )

                # Set the weights of the new policy to the learned policy.
                # We'll keep training the learned policy, whereas `new_pol_id` will
                # remain fixed.
                learned_state = algorithm.get_policy("learned").get_state()
                new_policy.set_state(learned_state)
                # We need to sync the just copied local weights (from learned policy)
                # to all the remote workers as well.
                algorithm.workers.sync_weights()
            else:
                print("not good enough; will keep learning ...")

            result["league_size"] = self.current_opponent + len(opponent_policies) + 1

    return SelfPlayCallback
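

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the callback itself).
# It assumes Ray RLlib 2.x on the old API stack, PettingZoo's connect_four_v3
# environment, and a frozen, untrained "random" opponent policy; the env id
# "connect_four", the 0.85 win-rate threshold and the number of training
# iterations below are arbitrary choices for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pettingzoo.classic import connect_four_v3
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
    from ray.rllib.policy.policy import PolicySpec
    from ray.tune.registry import register_env

    # Wrap the AEC-style PettingZoo env so RLlib can treat it as a
    # multi-agent environment.
    register_env("connect_four", lambda cfg: PettingZooEnv(connect_four_v3.env()))

    def initial_policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Alternate which seat "learned" takes, mirroring the mapping function
        # that the callback installs after the first snapshot.
        return "learned" if episode.episode_id % 2 == int(agent_id[-1:]) else "random"

    config = (
        PPOConfig()
        .environment("connect_four")
        .multi_agent(
            policies={
                "learned": PolicySpec(),
                # "random" stays out of `policies_to_train`, so it keeps its
                # randomly initialized weights and acts as the first opponent.
                "random": PolicySpec(),
            },
            policy_mapping_fn=initial_policy_mapping_fn,
            policies_to_train=["learned"],
        )
        .callbacks(
            create_self_play_callback(
                win_rate_thr=0.85, opponent_policies=["random"]
            )
        )
    )

    algo = config.build()
    for _ in range(10):
        result = algo.train()
        print(f"win_rate={result['win_rate']} league_size={result['league_size']}")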