import os
import copy

import numpy as np
import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.worker import BaseLearner, BattleInteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.league_demo.game_env import GameEnv
from dizoo.league_demo.demo_league import DemoLeague
from dizoo.league_demo.league_demo_collector import LeagueDemoCollector
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config


class EvalPolicy1:
    """Scripted evaluation opponent that plays the environment's known optimal strategy."""

    def __init__(self, optimal_policy: list) -> None:
        assert len(optimal_policy) == 2
        self.optimal_policy = optimal_policy

    def forward(self, data: dict) -> dict:
        # Sample one discrete action per ready env with the optimal probabilities.
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=self.optimal_policy, size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass


class EvalPolicy2:
    """Scripted evaluation opponent that acts uniformly at random."""

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass
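

# Training entry point: builds the league, one PPO learner/collector pair per
# active player and three battle evaluators for the main player, then runs the
# league collect/train/snapshot loop.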
def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        LeagueDemoCollector,
        BattleInteractionSerialEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
    env_type = cfg.env.env_type
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
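    # Three independent evaluator env managers, one per evaluation opponent.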
    evaluator_env1 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env2 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env3 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )

    evaluator_env1.seed(seed, dynamic_seed=False)
    evaluator_env2.seed(seed, dynamic_seed=False)
    evaluator_env3.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
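
    # League bookkeeping plus the two scripted evaluation opponents: eval_policy1
    # plays the env's known optimal strategy, eval_policy2 plays uniformly at random.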
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    league = DemoLeague(cfg.policy.other.league)
    eval_policy1 = EvalPolicy1(evaluator_env1._env_ref.optimal_policy)
    eval_policy2 = EvalPolicy2()
    policies = {}
    learners = {}
    collectors = {}
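
    # One model, policy, learner and battle collector per active league player.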
    for player_id in league.active_players_ids:
        model = VAC(**cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)
        policies[player_id] = policy
        collector_env = BaseEnvManager(
            env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
        )
        collector_env.seed(seed)

        learners[player_id] = BaseLearner(
            cfg.policy.learn.learner,
            policy.learn_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_learner'
        )
        collectors[player_id] = LeagueDemoCollector(
            cfg.policy.collect.collector,
            collector_env,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_collector',
        )
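
    # A shared 'historical' policy slot: snapshot opponents load their saved weights into it on demand.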
    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    policies['historical'] = policy

    # A frozen copy of the freshly initialized model is the 'init' evaluation opponent.
    eval_policy3 = PPOPolicy(cfg.policy, model=copy.deepcopy(model)).collect_mode

    main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
    main_player = league.get_player_by_id(main_key)
    main_learner = learners[main_key]
    main_collector = collectors[main_key]
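
    # Evaluators 1 and 2 measure the main player against the fixed optimal and the
    # uniform random opponents (their stop_values gate training termination);
    # evaluator 3 tracks progress against the main player's own initial weights.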
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = BattleInteractionSerialEvaluator(
        evaluator1_cfg,
        evaluator_env1, [policies[main_key].collect_mode, eval_policy1],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='fixed_evaluator'
    )
    evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator2_cfg.stop_value = cfg.env.stop_value[1]
    evaluator2 = BattleInteractionSerialEvaluator(
        evaluator2_cfg,
        evaluator_env2, [policies[main_key].collect_mode, eval_policy2],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='uniform_evaluator'
    )
    evaluator3_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator3_cfg.stop_value = 99999999  # sentinel: this evaluator never stops training
    evaluator3 = BattleInteractionSerialEvaluator(
        evaluator3_cfg,
        evaluator_env3, [policies[main_key].collect_mode, eval_policy3],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='init_evaluator'
    )
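
    # Give the league a checkpoint loader and persist the initial 'reset' checkpoint.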
    def load_checkpoint_fn(player_id: str, ckpt_path: str):
        state_dict = torch.load(ckpt_path)
        policies[player_id].learn_mode.load_state_dict(state_dict)

    torch.save(policies['historical'].learn_mode.state_dict(), league.reset_checkpoint_path)
    league.load_checkpoint = load_checkpoint_fn

    # Save an initial checkpoint for each active player and force a first snapshot,
    # so every player starts with at least one historical opponent available.
    for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
        torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
        league.judge_snapshot(player_id, force=True)
    init_main_player_rating = league.metric_env.create_rating(mu=0)
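
    # Main league loop: periodically evaluate the main player against the three
    # reference opponents, then run one collect/train job for every active player.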
    stop_flag1, stop_flag2 = False, False  # defined up front so the stop check cannot hit a NameError
    count = 0
    while True:
        if evaluator1.should_eval(main_learner.train_iter):
            stop_flag1, episode_info = evaluator1.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # The fixed optimal opponent's rating is pinned (mu=10, near-zero sigma),
            # so only the main player's rating moves.
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=10, sigma=1e-8), win_loss_result
            )
        if evaluator2.should_eval(main_learner.train_iter):
            stop_flag2, episode_info = evaluator2.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # The uniform random opponent's rating is pinned at mu=0.
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=0, sigma=1e-8), win_loss_result
            )
        if evaluator3.should_eval(main_learner.train_iter):
            _, episode_info = evaluator3.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # Rate the main player against its own initial snapshot; both ratings update.
            main_player.rating, init_main_player_rating = league.metric_env.rate_1vs1(
                main_player.rating, init_main_player_rating, win_loss_result
            )
            tb_logger.add_scalar(
                'league/init_main_player_trueskill', init_main_player_rating.exposure, main_collector.envstep
            )
        if stop_flag1 and stop_flag2:
            break
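
        # One league job per active player: pick an opponent from the league,
        # collect battle episodes, then train on the launch player's data.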
        for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
            tb_logger.add_scalar(
                'league/{}_trueskill'.format(player_id),
                league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
            )
            collector, learner = collectors[player_id], learners[player_id]
            job = league.get_job_info(player_id)
            opponent_player_id = job['player_id'][1]

            if 'historical' in opponent_player_id:
                # Historical opponents share one policy instance; load the snapshot weights into it.
                opponent_policy = policies['historical'].collect_mode
                opponent_path = job['checkpoint_path'][1]
                opponent_policy.load_state_dict(torch.load(opponent_path, map_location='cpu'))
            else:
                opponent_policy = policies[opponent_player_id].collect_mode
            collector.reset_policy([policies[player_id].collect_mode, opponent_policy])
            train_data, episode_info = collector.collect(train_iter=learner.train_iter)
            # Index 0 selects the launch player's side of the battle data.
            train_data, episode_info = train_data[0], episode_info[0]
            for d in train_data:
                # Use the raw reward directly as the advantage.
                d['adv'] = d['reward']

            for i in range(cfg.policy.learn.update_per_collect):
                learner.train(train_data, collector.envstep)
            # Persist the updated weights once per collect, not per gradient step.
            torch.save(learner.policy.state_dict(), player_ckpt_path)

            player_info = learner.learn_info
            player_info['player_id'] = player_id
            league.update_active_player(player_info)
            league.judge_snapshot(player_id)

            job_finish_info = {
                'eval_flag': True,
                'launch_player': job['launch_player'],
                'player_id': job['player_id'],
                'result': [e['result'] for e in episode_info],
            }
            league.finish_job(job_finish_info)

        if main_collector.envstep >= max_env_step or main_learner.train_iter >= max_train_iter:
            break
        if count % 100 == 0:
            print(repr(league.payoff))
        count += 1


if __name__ == "__main__":
    main(league_demo_ppo_config)