苏泓源
update
a257639
raw
history blame
2.98 kB
from typing import Dict
from facility_location.agent import FacilityLocationMLPExtractor, FacilityLocationGNNExtractor, FacilityLocationAttentionGNNExtractor
from facility_location.utils import Config
def get_policy_kwargs(cfg: Config) -> Dict:
    """Build the ``policy_kwargs`` dict for the agent selected by ``cfg.agent``.

    Supported agents and the features extractor each one uses:
      * ``'rl-mlp'``  -> ``FacilityLocationMLPExtractor``, configured from
        ``cfg.mlp_specs['hidden_units']`` (default ``(32, 32)``).
      * ``'rl-gnn'``  -> ``FacilityLocationGNNExtractor``, configured from
        ``cfg.gnn_specs`` (``num_gnn_layers`` default 2, ``node_dim`` default 32).
      * ``'rl-agnn'`` -> ``FacilityLocationAttentionGNNExtractor``, configured
        from ``cfg.gnn_specs`` with the same keys and defaults as ``'rl-gnn'``.

    Args:
        cfg: Experiment configuration. Reads ``agent``, ``agent_specs``,
            ``env_specs`` and, depending on the agent, either ``mlp_specs``
            or ``gnn_specs`` (each accessed via ``.get``).

    Returns:
        A dict with keys ``policy_feature_dim``, ``value_feature_dim`` (both
        computed by the chosen extractor class from ``node_dim``),
        ``policy_hidden_units`` / ``value_hidden_units`` (from
        ``cfg.agent_specs``, default ``(32, 32, 1)``),
        ``features_extractor_class``, ``features_extractor_kwargs`` and
        ``popstar`` (from ``cfg.env_specs``, default ``False``).

    Raises:
        NotImplementedError: if ``cfg.agent`` is not one of the agents above.
    """
    agent = cfg.agent
    if agent == 'rl-mlp':
        extractor_class = FacilityLocationMLPExtractor
        hidden_units = cfg.mlp_specs.get('hidden_units', (32, 32))
        # Feature dimensions are derived from the width of the last hidden layer.
        node_dim = hidden_units[-1]
        extractor_kwargs = dict(hidden_units=hidden_units)
    elif agent in ('rl-gnn', 'rl-agnn'):
        # Both GNN variants share the same configuration; only the class differs.
        if agent == 'rl-gnn':
            extractor_class = FacilityLocationGNNExtractor
        else:
            extractor_class = FacilityLocationAttentionGNNExtractor
        num_gnn_layers = cfg.gnn_specs.get('num_gnn_layers', 2)
        node_dim = cfg.gnn_specs.get('node_dim', 32)
        extractor_kwargs = dict(
            num_gnn_layers=num_gnn_layers,
            node_dim=node_dim,
        )
    else:
        raise NotImplementedError(f'Unsupported agent: {agent!r}')

    policy_kwargs = dict(
        policy_feature_dim=extractor_class.get_policy_feature_dim(node_dim),
        value_feature_dim=extractor_class.get_value_feature_dim(node_dim),
        policy_hidden_units=cfg.agent_specs.get('policy_hidden_units', (32, 32, 1)),
        value_hidden_units=cfg.agent_specs.get('value_hidden_units', (32, 32, 1)),
        features_extractor_class=extractor_class,
        features_extractor_kwargs=extractor_kwargs,
        popstar=cfg.env_specs.get('popstar', False),
    )
    return policy_kwargs