from ray.rllib.algorithms.callbacks import DefaultCallbacks
import numpy as np


def create_self_play_callback(win_rate_thr, opponent_policies):
    class SelfPlayCallback(DefaultCallbacks):
        win_rate_threshold = win_rate_thr

        def __init__(self):
            super().__init__()
            self.current_opponent = 0

        def on_train_result(self, *, algorithm, result, **kwargs):
            # Get the win rate for the train batch.
            # Note that normally, one should set up a proper evaluation config,
            # such that evaluation always happens on the already updated policy,
            # instead of on the already used train_batch.
            main_rew = result["hist_stats"].pop("policy_learned_reward")
            opponent_rew = result["hist_stats"].pop("episode_reward")
            if len(main_rew) != len(opponent_rew):
                raise Exception(
                    "len(main_rew) != len(opponent_rew)",
                    len(main_rew),
                    len(opponent_rew),
                    result["hist_stats"].keys(),
                    "episode len",
                    len(opponent_rew),
                )
            won = 0
            for r_main, r_opponent in zip(main_rew, opponent_rew):
                if r_main > r_opponent:
                    won += 1
            win_rate = won / len(main_rew)

            result["win_rate"] = win_rate
            print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")

            # If win rate is good -> Snapshot current policy and play against
            # it next, keeping the snapshot fixed and only improving the "learned"
            # policy.
            if win_rate > self.win_rate_threshold:
                self.current_opponent += 1
                new_pol_id = f"learned_v{self.current_opponent}"
                print(
                    f"Iter={algorithm.iteration} ### Adding new opponent to the mix ({new_pol_id})."
                )

                # Re-define the mapping function, such that "learned" is forced
                # to play against any of the previously played policies
                # (excluding "random").
                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                    # agent_id = [0|1] -> policy depends on episode ID.
                    # This way, we make sure that both policies sometimes play
                    # agent0 (start player) and sometimes agent1 (player to
                    # move 2nd).
                    return (
                        "learned"
                        if episode.episode_id % 2 == int(agent_id[-1:])
                        else np.random.choice(
                            opponent_policies
                            + [
                                f"learned_v{i}"
                                for i in range(1, self.current_opponent + 1)
                            ]
                        )
                    )

                new_policy = algorithm.add_policy(
                    policy_id=new_pol_id,
                    policy_cls=type(algorithm.get_policy("learned")),
                    policy_mapping_fn=policy_mapping_fn,
                )

                # Set the weights of the new policy to the learned policy.
                # We'll keep training the learned policy, whereas `new_pol_id` will
                # remain fixed.
                learned_state = algorithm.get_policy("learned").get_state()
                new_policy.set_state(learned_state)
                # We need to sync the just copied local weights (from learned policy)
                # to all the remote workers as well.
                algorithm.workers.sync_weights()
            else:
                print("not good enough; will keep learning ...")

            result["league_size"] = (
                self.current_opponent + len(opponent_policies) + 1
            )

    return SelfPlayCallback
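# Example usage (a minimal sketch, not part of the original callback code):
# it shows how the returned callback class could be wired into an RLlib
# AlgorithmConfig. The PPO algorithm choice, the "connect_four_v3" env name,
# the 0.85 threshold, and the "random" opponent policy id are assumptions
# made for illustration; adapt them to the actual multi-agent setup.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

SelfPlayCallback = create_self_play_callback(
    win_rate_thr=0.85,          # assumed threshold for snapshotting a new opponent
    opponent_policies=["random"],
)

config = (
    PPOConfig()
    .environment("connect_four_v3")  # hypothetical, must be registered beforehand
    .callbacks(SelfPlayCallback)
    .multi_agent(
        policies={
            "learned": PolicySpec(),
            # Untrained policy acting as the initial fixed opponent.
            "random": PolicySpec(),
        },
        # Initial mapping: "learned" vs. "random"; the callback replaces this
        # function whenever a new "learned_vN" snapshot is added.
        policy_mapping_fn=lambda agent_id, episode, worker, **kw: (
            "learned"
            if episode.episode_id % 2 == int(agent_id[-1:])
            else "random"
        ),
        policies_to_train=["learned"],
    )
)

algo = config.build()
for _ in range(10):
    result = algo.train()
    # "win_rate" and "league_size" are written into the result dict by the callback.
    print(result["win_rate"], result["league_size"])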