Spaces:
Runtime error
Runtime error
File size: 2,979 Bytes
a257639 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from typing import Dict
from facility_location.agent import FacilityLocationMLPExtractor, FacilityLocationGNNExtractor, FacilityLocationAttentionGNNExtractor
from facility_location.utils import Config
def get_policy_kwargs(cfg: Config) -> Dict:
    """Build the policy keyword-argument dict for the agent named by ``cfg.agent``.

    Supported agents and the feature extractor each one selects:

    * ``'rl-mlp'``  -> ``FacilityLocationMLPExtractor``, sized from ``cfg.mlp_specs``
    * ``'rl-gnn'``  -> ``FacilityLocationGNNExtractor``, sized from ``cfg.gnn_specs``
    * ``'rl-agnn'`` -> ``FacilityLocationAttentionGNNExtractor``, sized from ``cfg.gnn_specs``

    All agents share the policy/value head sizes (``cfg.agent_specs``) and the
    ``popstar`` flag (``cfg.env_specs``).

    NOTE(review): the returned dict is presumably passed as ``policy_kwargs`` to
    the RL model constructor — confirm against the caller.

    Args:
        cfg: Experiment configuration. Reads ``agent``, ``agent_specs``,
            ``env_specs`` and, depending on the agent, ``mlp_specs`` or
            ``gnn_specs`` (each accessed via ``.get`` with a default).

    Returns:
        Dict with keys ``policy_feature_dim``, ``value_feature_dim``,
        ``policy_hidden_units``, ``value_hidden_units``,
        ``features_extractor_class``, ``features_extractor_kwargs`` and ``popstar``.

    Raises:
        NotImplementedError: if ``cfg.agent`` is not a supported agent name.
    """
    if cfg.agent == 'rl-mlp':
        extractor_class = FacilityLocationMLPExtractor
        hidden_units = cfg.mlp_specs.get('hidden_units', (32, 32))
        # The MLP extractor's per-node output width is its last hidden layer width.
        node_dim = hidden_units[-1]
        extractor_kwargs = dict(hidden_units=hidden_units)
    elif cfg.agent in ('rl-gnn', 'rl-agnn'):
        # Plain and attention GNN extractors take identical configuration.
        if cfg.agent == 'rl-gnn':
            extractor_class = FacilityLocationGNNExtractor
        else:
            extractor_class = FacilityLocationAttentionGNNExtractor
        num_gnn_layers = cfg.gnn_specs.get('num_gnn_layers', 2)
        node_dim = cfg.gnn_specs.get('node_dim', 32)
        extractor_kwargs = dict(
            num_gnn_layers=num_gnn_layers,
            node_dim=node_dim,
        )
    else:
        raise NotImplementedError(f"Unsupported agent: {cfg.agent!r}")

    # Shared across all agents; key order matches the original per-branch dicts.
    return dict(
        policy_feature_dim=extractor_class.get_policy_feature_dim(node_dim),
        value_feature_dim=extractor_class.get_value_feature_dim(node_dim),
        policy_hidden_units=cfg.agent_specs.get('policy_hidden_units', (32, 32, 1)),
        value_hidden_units=cfg.agent_specs.get('value_hidden_units', (32, 32, 1)),
        features_extractor_class=extractor_class,
        features_extractor_kwargs=extractor_kwargs,
        popstar=cfg.env_specs.get('popstar', False),
    )
|