import os
from typing import Dict, Optional, Text

import yaml
from stable_baselines3.common.utils import get_latest_run_id

class Config:

    def __init__(self, cfg_id: Text, global_seed: int, tmp: bool, root_dir: Text,
                 agent: Text = 'rl-gnn', reset_num_timesteps: bool = True,
                 cfg_dict: Optional[Dict] = None, model_path: Optional[Text] = None):
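        """Build a run config from a YAML file or an in-memory cfg_dict.

        Resolves TensorBoard/output directories and exposes all env, agent,
        and training hyperparameters as attributes.
        """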
        self.cfg_id = cfg_id
        self.seed = global_seed
        if cfg_dict is not None:
            cfg = cfg_dict
        else:
            file_path = './facility_location/cfg/{}.yaml'.format(self.cfg_id)
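
            # yaml.SafeLoader has no constructor for the python/tuple tag, so a
            # config entry like `foo: !!python/tuple [1, 2]` (hypothetical key)
            # would fail to load; register a constructor that builds tuples.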
            class TupleSafeLoader(yaml.SafeLoader):
                def construct_python_tuple(self, node):
                    return tuple(self.construct_sequence(node))

            TupleSafeLoader.add_constructor(
                'tag:yaml.org,2002:python/tuple',
                TupleSafeLoader.construct_python_tuple)

            def load_yaml(file_path):
                # Close the file promptly instead of leaking the handle.
                with open(file_path, 'r') as f:
                    return yaml.load(f, Loader=TupleSafeLoader)

            cfg = load_yaml(file_path)

        # create dirs
        self.root_dir = '/tmp/flp' if tmp else root_dir
        self.agent = agent
        self.multi = cfg.get('multi', False)
        self.tb_log_path = os.path.join(self.root_dir, 'runs')
        self.tb_log_name = f'{cfg_id}-agent-{agent}-seed-{global_seed}'
        latest_run_id = get_latest_run_id(self.tb_log_path, self.tb_log_name)
        if not reset_num_timesteps:
            # Continue training in the same directory.
            latest_run_id -= 1
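        # Follow SB3's TensorBoard convention of suffixing the run directory
        # with the next run id: '<name>_<latest_run_id + 1>'.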
        self.cfg_dir = os.path.join(
            self.root_dir, 'output',
            f'{cfg_id}-agent-{agent}-seed-{global_seed}_{latest_run_id + 1}')
        self.ckpt_save_path = os.path.join(self.cfg_dir, 'ckpt')
        self.best_model_path = os.path.join(self.cfg_dir, 'best-models')
        self.latest_model_path = os.path.join(self.cfg_dir, 'latest-models')
        self.load_model_path = model_path
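
        # cfg may define any of the keys read below; hypothetical YAML example:
        #     env_specs: {num_facilities: 10}
        #     agent_specs: {batch_size: 64}
        #     gamma: 0.99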
        # env
        self.env_specs = cfg.get('env_specs', dict())
        self.reward_specs = cfg.get('reward_specs', dict())
        self.obs_specs = cfg.get('obs_specs', dict())
        self.eval_specs = cfg.get('eval_specs', dict())

        # agent config
        self.agent_specs = cfg.get('agent_specs', dict())
        self.mlp_specs = cfg.get('mlp_specs', dict())
        self.gnn_specs = cfg.get('gnn_specs', dict())
        self.ts_specs = cfg.get('ts_specs', dict())
        self.popstar_specs = cfg.get('popstar_specs', dict())
        self.ga_specs = cfg.get('ga_specs', dict())

        # training config
        self.gamma = cfg.get('gamma', 0.99)
        self.tau = cfg.get('tau', 0.95)
        self.state_encoder_specs = cfg.get('state_encoder_specs', dict())
        self.policy_specs = cfg.get('policy_specs', dict())
        self.value_specs = cfg.get('value_specs', dict())
        self.lr = cfg.get('lr', 4e-4)
        self.weightdecay = cfg.get('weightdecay', 0.0)
        self.eps = cfg.get('eps', 1e-5)
        self.value_pred_coef = cfg.get('value_pred_coef', 0.5)
        self.entropy_coef = cfg.get('entropy_coef', 0.01)
        self.clip_epsilon = cfg.get('clip_epsilon', 0.2)
        self.max_num_iterations = cfg.get('max_num_iterations', 1000)
        self.num_episodes_per_iteration = cfg.get('num_episodes_per_iteration', 1000)
        self.max_sequence_length = cfg.get('max_sequence_length', 100)
        self.num_optim_epoch = cfg.get('num_optim_epoch', 4)
        self.mini_batch_size = cfg.get('mini_batch_size', 1024)
        self.save_model_interval = cfg.get('save_model_interval', 10)

    def log(self, logger, tb_logger):
        """Log cfg to logger and TensorBoard."""
        logger.info(f'id: {self.cfg_id}')
        logger.info(f'seed: {self.seed}')
        logger.info(f'env_specs: {self.env_specs}')
        logger.info(f'reward_specs: {self.reward_specs}')
        logger.info(f'obs_specs: {self.obs_specs}')
        logger.info(f'agent_specs: {self.agent_specs}')
        logger.info(f'gamma: {self.gamma}')
        logger.info(f'tau: {self.tau}')
        logger.info(f'state_encoder_specs: {self.state_encoder_specs}')
        logger.info(f'policy_specs: {self.policy_specs}')
        logger.info(f'value_specs: {self.value_specs}')
        logger.info(f'lr: {self.lr}')
        logger.info(f'weightdecay: {self.weightdecay}')
        logger.info(f'eps: {self.eps}')
        logger.info(f'value_pred_coef: {self.value_pred_coef}')
        logger.info(f'entropy_coef: {self.entropy_coef}')
        logger.info(f'clip_epsilon: {self.clip_epsilon}')
        logger.info(f'max_num_iterations: {self.max_num_iterations}')
        logger.info(f'num_episodes_per_iteration: {self.num_episodes_per_iteration}')
        logger.info(f'max_sequence_length: {self.max_sequence_length}')
        logger.info(f'num_optim_epoch: {self.num_optim_epoch}')
        logger.info(f'mini_batch_size: {self.mini_batch_size}')
        logger.info(f'save_model_interval: {self.save_model_interval}')
        if tb_logger is not None:
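            # The HPARAMS view in TensorBoard pairs hyperparameters with
            # metrics, so a placeholder metric is logged alongside them.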
            tb_logger.add_hparams(
                hparam_dict={
                    'id': self.cfg_id,
                    'seed': self.seed,
                    'env_specs': str(self.env_specs),
                    'reward_specs': str(self.reward_specs),
                    'obs_specs': str(self.obs_specs),
                    'agent_specs': str(self.agent_specs),
                    'gamma': self.gamma,
                    'tau': self.tau,
                    'state_encoder_specs': str(self.state_encoder_specs),
                    'policy_specs': str(self.policy_specs),
                    'value_specs': str(self.value_specs),
                    'lr': self.lr,
                    'weightdecay': self.weightdecay,
                    'eps': self.eps,
                    'value_pred_coef': self.value_pred_coef,
                    'entropy_coef': self.entropy_coef,
                    'clip_epsilon': self.clip_epsilon,
                    'max_num_iterations': self.max_num_iterations,
                    'num_episodes_per_iteration': self.num_episodes_per_iteration,
                    'max_sequence_length': self.max_sequence_length,
                    'num_optim_epoch': self.num_optim_epoch,
                    'mini_batch_size': self.mini_batch_size,
                    'save_model_interval': self.save_model_interval,
                },
                metric_dict={'hparam/placeholder': 0.0})
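

if __name__ == '__main__':
    # Hypothetical smoke test (names and values are illustrative): pass the
    # config as an in-memory dict so no YAML file is required, then log it.
    import logging

    logging.basicConfig(level=logging.INFO)
    demo = Config(cfg_id='demo', global_seed=0, tmp=True, root_dir='',
                  cfg_dict={'gamma': 0.98, 'env_specs': {'num_facilities': 10}})
    demo.log(logging.getLogger('demo'), tb_logger=None)
    print('output dir:', demo.cfg_dir)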