from ray.rllib.algorithms.callbacks import DefaultCallbacks
import numpy as np


def create_self_play_callback(win_rate_thr, opponent_policies):
    class SelfPlayCallback(DefaultCallbacks):
        win_rate_threshold = win_rate_thr

        def __init__(self):
            super().__init__()
            # Number of frozen snapshots of the "learned" policy added so far.
            self.current_opponent = 0

        def on_train_result(self, *, algorithm, result, **kwargs):
            # Get the win rate for the train batch.
            # Note that normally, one should set up a proper evaluation config,
            # such that evaluation always happens on the already updated policy,
            # instead of on the already used train_batch.
            main_rew = result["hist_stats"].pop("policy_learned_reward")
            # In a zero-sum game, the per-episode total reward equals
            # r_main + r_opponent, so comparing r_main against the total is
            # equivalent to checking whether the opponent scored negatively,
            # i.e. whether "learned" won the episode.
            opponent_rew = result["hist_stats"].pop("episode_reward")
            if len(main_rew) != len(opponent_rew):
                raise ValueError(
                    f"len(main_rew)={len(main_rew)} != "
                    f"len(opponent_rew)={len(opponent_rew)}; "
                    f"hist_stats keys: {list(result['hist_stats'].keys())}"
                )
            won = 0
            for r_main, r_opponent in zip(main_rew, opponent_rew):
                if r_main > r_opponent:
                    won += 1
            win_rate = won / len(main_rew)
            result["win_rate"] = win_rate
            print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")

            # If the win rate is good enough -> Snapshot the current policy and
            # play against it next, keeping the snapshot fixed and only
            # improving the "learned" policy.
            if win_rate > self.win_rate_threshold:
                self.current_opponent += 1
                new_pol_id = f"learned_v{self.current_opponent}"
                print(
                    f"Iter={algorithm.iteration} ### Adding new opponent to "
                    f"the mix ({new_pol_id})."
                )

                # Re-define the mapping function, such that "learned" is forced
                # to play against any of the initial opponent policies or one
                # of its own previous snapshots.
                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                    # agent_id ends in [0|1] -> policy depends on the episode
                    # ID. This way, we make sure that "learned" sometimes plays
                    # as agent0 (start player) and sometimes as agent1 (player
                    # to move 2nd).
                    return (
                        "learned"
                        if episode.episode_id % 2 == int(agent_id[-1:])
                        else np.random.choice(
                            opponent_policies
                            + [
                                f"learned_v{i}"
                                for i in range(1, self.current_opponent + 1)
                            ]
                        )
                    )

                new_policy = algorithm.add_policy(
                    policy_id=new_pol_id,
                    policy_cls=type(algorithm.get_policy("learned")),
                    policy_mapping_fn=policy_mapping_fn,
                )

                # Set the weights of the new policy to those of the "learned"
                # policy. We'll keep training "learned", whereas `new_pol_id`
                # will remain fixed.
                learned_state = algorithm.get_policy("learned").get_state()
                new_policy.set_state(learned_state)
                # We need to sync the just-copied local weights (from the
                # "learned" policy) to all the remote workers as well.
                algorithm.workers.sync_weights()
            else:
                print("not good enough; will keep learning ...")

            # +1 accounts for the "learned" policy itself.
            result["league_size"] = (
                self.current_opponent + len(opponent_policies) + 1
            )

    return SelfPlayCallback
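
Below is a minimal sketch of how this factory could be wired into an RLlib training run. The "connect_four" environment ID, the 0.95 threshold, and the RandomPolicy baseline are illustrative assumptions, not part of the callback; the trained policy must be registered under the ID "learned", because the callback pops "policy_learned_reward" from hist_stats.

# Hypothetical wiring for the callback above. "connect_four" is an assumed,
# pre-registered two-player environment; the RandomPolicy import path matches
# older Ray 2.x releases and may have moved in your version.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy

opponent_policies = ["random"]

config = (
    PPOConfig()
    .environment("connect_four")  # assumption: registered elsewhere
    .multi_agent(
        policies={
            # The policy being trained; its ID must be "learned" to match
            # the "policy_learned_reward" key the callback reads.
            "learned": PolicySpec(),
            # A frozen baseline opponent to seed the league with.
            "random": PolicySpec(policy_class=RandomPolicy),
        },
        # Alternate which side "learned" plays, mirroring the callback's
        # episode-ID-based mapping function.
        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
            "learned"
            if episode.episode_id % 2 == int(agent_id[-1:])
            else "random"
        ),
        policies_to_train=["learned"],
    )
    .callbacks(
        create_self_play_callback(
            win_rate_thr=0.95,
            opponent_policies=opponent_policies,
        )
    )
)

algo = config.build()
for _ in range(100):
    result = algo.train()
    # result["win_rate"] and result["league_size"] are set by the callback.

Because policies_to_train lists only "learned", every snapshot the callback adds stays frozen: the league grows over time while a single policy keeps improving against it.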