from typing import Dict

from facility_location.agent import (
    FacilityLocationMLPExtractor,
    FacilityLocationGNNExtractor,
    FacilityLocationAttentionGNNExtractor,
)
from facility_location.utils import Config


def get_policy_kwargs(cfg: Config) -> Dict:
    """Build the policy keyword-argument dict for the configured RL agent.

    The agent type is selected by ``cfg.agent``:

    * ``'rl-mlp'``  -> ``FacilityLocationMLPExtractor``; its feature width is
      the last entry of ``cfg.mlp_specs['hidden_units']`` (default ``(32, 32)``).
    * ``'rl-gnn'``  -> ``FacilityLocationGNNExtractor``.
    * ``'rl-agnn'`` -> ``FacilityLocationAttentionGNNExtractor``.

    Both GNN variants read ``num_gnn_layers`` (default 2) and ``node_dim``
    (default 32) from ``cfg.gnn_specs``.

    Args:
        cfg: Configuration object exposing ``agent`` plus dict-like
            ``mlp_specs``, ``gnn_specs``, ``agent_specs`` and ``env_specs``
            (each is accessed with ``.get``).

    Returns:
        A dict with keys ``policy_feature_dim``, ``value_feature_dim``,
        ``policy_hidden_units`` (default ``(32, 32, 1)``),
        ``value_hidden_units`` (default ``(32, 32, 1)``),
        ``features_extractor_class``, ``features_extractor_kwargs`` and
        ``popstar`` (default ``False``). Presumably passed as ``policy_kwargs``
        to the policy/algorithm constructor -- NOTE(review): confirm at caller.

    Raises:
        NotImplementedError: if ``cfg.agent`` is not a supported agent type.
    """
    # The two GNN-based agents share config handling and differ only in the
    # extractor class.
    gnn_extractors = {
        'rl-gnn': FacilityLocationGNNExtractor,
        'rl-agnn': FacilityLocationAttentionGNNExtractor,
    }

    if cfg.agent == 'rl-mlp':
        extractor_class = FacilityLocationMLPExtractor
        hidden_units = cfg.mlp_specs.get('hidden_units', (32, 32))
        # The extractor's per-node feature width is its last hidden layer size.
        node_dim = hidden_units[-1]
        extractor_kwargs = dict(hidden_units=hidden_units)
    elif cfg.agent in gnn_extractors:
        extractor_class = gnn_extractors[cfg.agent]
        num_gnn_layers = cfg.gnn_specs.get('num_gnn_layers', 2)
        node_dim = cfg.gnn_specs.get('node_dim', 32)
        extractor_kwargs = dict(num_gnn_layers=num_gnn_layers, node_dim=node_dim)
    else:
        supported = ['rl-mlp', *gnn_extractors]
        raise NotImplementedError(
            f"Unsupported agent type {cfg.agent!r}; expected one of {supported}")

    # Shared across all agent types: head input sizes come from the chosen
    # extractor, head architectures and the popstar flag come from the config.
    return dict(
        policy_feature_dim=extractor_class.get_policy_feature_dim(node_dim),
        value_feature_dim=extractor_class.get_value_feature_dim(node_dim),
        policy_hidden_units=cfg.agent_specs.get('policy_hidden_units', (32, 32, 1)),
        value_hidden_units=cfg.agent_specs.get('value_hidden_units', (32, 32, 1)),
        features_extractor_class=extractor_class,
        features_extractor_kwargs=extractor_kwargs,
        popstar=cfg.env_specs.get('popstar', False),
    )