Spaces:
Runtime error
Runtime error
from typing import Dict | |
from facility_location.agent import FacilityLocationMLPExtractor, FacilityLocationGNNExtractor, FacilityLocationAttentionGNNExtractor | |
from facility_location.utils import Config | |
def get_policy_kwargs(cfg: Config) -> Dict:
    """Build the policy keyword arguments for the agent selected by ``cfg.agent``.

    Supported agents:
      * ``'rl-mlp'``  -> ``FacilityLocationMLPExtractor``
      * ``'rl-gnn'``  -> ``FacilityLocationGNNExtractor``
      * ``'rl-agnn'`` -> ``FacilityLocationAttentionGNNExtractor``

    The returned dict has the same keys for every agent type:
    ``policy_feature_dim``, ``value_feature_dim`` (both computed by the chosen
    extractor class from ``node_dim``), ``policy_hidden_units``,
    ``value_hidden_units`` (from ``cfg.agent_specs``, default ``(32, 32, 1)``),
    ``features_extractor_class``, ``features_extractor_kwargs`` and
    ``popstar`` (from ``cfg.env_specs``, default ``False``).
    NOTE(review): presumably forwarded as ``policy_kwargs`` to the RL policy
    constructor — confirm at the call site.

    Args:
        cfg: Configuration exposing ``agent``, ``agent_specs``, ``env_specs``
            and, depending on the agent, ``mlp_specs`` or ``gnn_specs``.

    Returns:
        Dict of policy keyword arguments.

    Raises:
        NotImplementedError: if ``cfg.agent`` is not one of the supported
            agent names.
    """
    agent = cfg.agent
    if agent == 'rl-mlp':
        # The MLP extractor's output width is its last hidden layer size.
        hidden_units = cfg.mlp_specs.get('hidden_units', (32, 32))
        node_dim = hidden_units[-1]
        extractor_class = FacilityLocationMLPExtractor
        extractor_kwargs = dict(hidden_units=hidden_units)
    elif agent in ('rl-gnn', 'rl-agnn'):
        # Both GNN variants share the same configuration; only the
        # extractor class differs.
        num_gnn_layers = cfg.gnn_specs.get('num_gnn_layers', 2)
        node_dim = cfg.gnn_specs.get('node_dim', 32)
        if agent == 'rl-gnn':
            extractor_class = FacilityLocationGNNExtractor
        else:
            extractor_class = FacilityLocationAttentionGNNExtractor
        extractor_kwargs = dict(
            num_gnn_layers=num_gnn_layers,
            node_dim=node_dim,
        )
    else:
        raise NotImplementedError(
            f"Unsupported agent {agent!r}; expected one of "
            f"'rl-mlp', 'rl-gnn', 'rl-agnn'."
        )

    policy_kwargs = dict(
        policy_feature_dim=extractor_class.get_policy_feature_dim(node_dim),
        value_feature_dim=extractor_class.get_value_feature_dim(node_dim),
        policy_hidden_units=cfg.agent_specs.get('policy_hidden_units', (32, 32, 1)),
        value_hidden_units=cfg.agent_specs.get('value_hidden_units', (32, 32, 1)),
        features_extractor_class=extractor_class,
        features_extractor_kwargs=extractor_kwargs,
        popstar=cfg.env_specs.get('popstar', False),
    )
    return policy_kwargs