from ray.rllib.algorithms.callbacks import DefaultCallbacks
import numpy as np


def create_self_play_callback(win_rate_thr, opponent_policies):
    class SelfPlayCallback(DefaultCallbacks):
        win_rate_threshold = win_rate_thr

        def __init__(self):
            super().__init__()
            self.current_opponent = 0
        def on_train_result(self, *, algorithm, result, **kwargs):
            # Get the win rate for the train batch.
            # Note that normally, one should set up a proper evaluation config,
            # such that evaluation always happens on the already updated policy,
            # instead of on the already used train_batch.
            main_rew = result["hist_stats"].pop("policy_learned_reward")
            opponent_rew = result["hist_stats"].pop("episode_reward")
            if len(main_rew) != len(opponent_rew):
                raise ValueError(
                    "Reward histories have different lengths: "
                    f"len(main_rew)={len(main_rew)}, "
                    f"len(opponent_rew)={len(opponent_rew)}; "
                    f"remaining hist_stats keys: {list(result['hist_stats'].keys())}"
                )
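
            # Count the episodes in which the "learned" policy collected the
            # higher reward.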
            won = 0
            for r_main, r_opponent in zip(main_rew, opponent_rew):
                if r_main > r_opponent:
                    won += 1
            win_rate = won / len(main_rew)
            result["win_rate"] = win_rate
            print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")

            # If the win rate is good enough -> Snapshot the current policy and
            # play against it next, keeping the snapshot fixed and only
            # improving the "learned" policy.
            if win_rate > self.win_rate_threshold:
                self.current_opponent += 1
                new_pol_id = f"learned_v{self.current_opponent}"
                print(
                    f"Iter={algorithm.iteration} ### Adding new opponent to the mix ({new_pol_id})."
                )

                # Re-define the mapping function, such that "learned" is forced
                # to play against any of the previously frozen snapshots or one
                # of the base opponent policies.
                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                    # agent_id = [0|1] -> policy depends on the episode ID.
                    # This way, we make sure that both policies sometimes play
                    # agent0 (start player) and sometimes agent1 (player to
                    # move 2nd).
                    return (
                        "learned"
                        if episode.episode_id % 2 == int(agent_id[-1:])
                        else np.random.choice(
                            opponent_policies
                            + [
                                f"learned_v{i}"
                                for i in range(1, self.current_opponent + 1)
                            ]
                        )
                    )
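
                # Add a new policy (same class as "learned") to the local
                # worker and all remote rollout workers, together with the
                # updated mapping function. Its weights are copied from
                # "learned" right below.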
                new_policy = algorithm.add_policy(
                    policy_id=new_pol_id,
                    policy_cls=type(algorithm.get_policy("learned")),
                    policy_mapping_fn=policy_mapping_fn,
                )

                # Set the weights of the new policy to the learned policy.
                # We'll keep training the learned policy, whereas `new_pol_id`
                # will remain fixed.
                learned_state = algorithm.get_policy("learned").get_state()
                new_policy.set_state(learned_state)
                # We need to sync the just copied local weights (from the
                # learned policy) to all the remote workers as well.
                algorithm.workers.sync_weights()
            else:
                print("not good enough; will keep learning ...")

            # League size = frozen snapshots + fixed opponent policies + the
            # currently learning policy.
            result["league_size"] = self.current_opponent + len(opponent_policies) + 1

    return SelfPlayCallback
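

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: it shows one way the
    # factory above could be wired into an RLlib PPO setup. The environment
    # name "connect_four", agent ids ending in "0"/"1" (required by the
    # `int(agent_id[-1:])` parity trick above), and the "random" opponent
    # policy are assumptions for illustration only.
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.policy.policy import PolicySpec

    def initial_policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Alternate seats: "learned" controls one agent per episode, the fixed
        # opponent controls the other.
        return "learned" if episode.episode_id % 2 == int(agent_id[-1:]) else "random"

    config = (
        PPOConfig()
        .environment(env="connect_four")  # hypothetical registered env name
        .multi_agent(
            # In a real setup, "random" would typically point to a
            # random-action policy class via PolicySpec(policy_class=...).
            policies={"learned": PolicySpec(), "random": PolicySpec()},
            policy_mapping_fn=initial_policy_mapping_fn,
            policies_to_train=["learned"],
        )
        .callbacks(
            create_self_play_callback(
                win_rate_thr=0.95,
                opponent_policies=["random"],
            )
        )
    )

    algo = config.build()
    for _ in range(5):
        algo.train()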