Spaces:
Runtime error
Runtime error
import os | |
import pickle | |
import setproctitle | |
from absl import app, flags | |
import time | |
import random | |
from typing import Tuple, Union, Text | |
import numpy as np | |
import torch as th | |
import sys | |
import gymnasium | |
sys.modules["gym"] = gymnasium | |
from stable_baselines3.common.evaluation import evaluate_policy | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.monitor import Monitor | |
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnvWrapper | |
from facility_location.agent.solver import PMPSolver | |
from facility_location.env import EvalPMPEnv, MULTIPMP | |
from facility_location.utils import Config | |
from facility_location.agent import MaskedFacilityLocationActorCriticPolicy | |
from facility_location.utils.policy import get_policy_kwargs | |
import warnings | |
warnings.filterwarnings('ignore') | |
AGENT = Union[PMPSolver, PPO] | |
def get_model(cfg: Config,
              env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
              device: str) -> PPO:
    """Construct a fresh PPO model using the masked facility-location policy.

    Args:
        cfg: Configuration passed to ``get_policy_kwargs`` to obtain the
            keyword arguments for the policy.
        env: Environment the PPO model is created for.
        device: Device string forwarded unchanged to the ``PPO`` constructor.

    Returns:
        The newly created ``PPO`` instance (created with ``verbose=1``).
    """
    return PPO(
        MaskedFacilityLocationActorCriticPolicy,
        env,
        verbose=1,
        policy_kwargs=get_policy_kwargs(cfg),
        device=device,
    )
def get_agent(cfg: Config,
              env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
              model_path: Text) -> AGENT:
    """Create the evaluation agent and load trained parameters into it.

    A fresh PPO model is built for ``env`` via ``get_model`` and the
    parameters of the model stored at ``model_path`` are copied into it.

    Args:
        cfg: Run configuration; ``cfg.agent`` selects the agent type.
        env: Environment the agent is bound to.
        model_path: Path of the saved PPO model handed to ``PPO.load``.

    Returns:
        The PPO model carrying the loaded parameters.

    Raises:
        ValueError: If ``cfg.agent`` is not one of ``'rl-mlp'``,
            ``'rl-gnn'`` or ``'rl-agnn'``.
    """
    if cfg.agent not in ('rl-mlp', 'rl-gnn', 'rl-agnn'):
        raise ValueError(f'Agent {cfg.agent} not supported.')

    # Prefer the first GPU, but fall back to CPU so the agent can still be
    # built on hosts without CUDA instead of failing at model creation.
    device = 'cuda:0' if th.cuda.is_available() else 'cpu'

    test_model = get_model(cfg, env, device=device)
    trained_model = PPO.load(model_path, device=device)
    test_model.set_parameters(trained_model.get_parameters())
    return test_model
def evaluate(agent: AGENT,
             env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
             num_cases: int,
             return_episode_rewards: bool):
    """Evaluate ``agent`` on ``env``, dispatching on the agent's type.

    Args:
        agent: Agent to evaluate; only ``PPO`` instances are handled.
        env: Environment to run the evaluation in.
        num_cases: Number of cases, forwarded to ``evaluate_ppo``.
        return_episode_rewards: Forwarded to ``evaluate_ppo``.

    Returns:
        Whatever ``evaluate_ppo`` returns for a PPO agent.

    Raises:
        ValueError: If ``agent`` is not a ``PPO`` instance.
    """
    if not isinstance(agent, PPO):
        raise ValueError(f'Agent {agent} not supported.')
    return evaluate_ppo(agent, env, num_cases, return_episode_rewards=return_episode_rewards)
from stable_baselines3.common.callbacks import BaseCallback | |
def evaluate_ppo(agent: PPO, env: EvalPMPEnv, num_cases: int, return_episode_rewards: bool) -> Union[float, list]:
    """Run ``evaluate_policy`` for a PPO agent and return its first result.

    Args:
        agent: PPO model to evaluate.
        env: Environment to evaluate on.
        num_cases: Passed to ``evaluate_policy`` as ``n_eval_episodes``.
        return_episode_rewards: Passed through to ``evaluate_policy``.

    Returns:
        Only the first element of the pair returned by ``evaluate_policy``.
        Per the Stable-Baselines3 documentation this is the list of
        per-episode rewards when ``return_episode_rewards`` is true and the
        mean reward otherwise. The second element is discarded.
    """
    rewards, _ = evaluate_policy(agent, env, n_eval_episodes=num_cases, return_episode_rewards=return_episode_rewards)
    return rewards
def main(data_npy, boost=False):
    """Evaluate the saved PPO agent on one case and print the elapsed time.

    Seeds the torch, NumPy and ``random`` generators with 0, builds a
    ``MULTIPMP`` environment from ``data_npy`` and ``boost``, wraps it in
    ``Monitor`` and ``DummyVecEnv``, loads the agent from the hard-coded
    model path and runs a single evaluation case.

    Args:
        data_npy: Input data forwarded unchanged to ``MULTIPMP``.
        boost: Flag forwarded unchanged to ``MULTIPMP``.
    """
    # Seed every RNG with the same fixed value, in the original order.
    for seed_fn in (th.manual_seed, np.random.seed, random.seed):
        seed_fn(0)

    model_path = './facility_location/best_model.zip'
    cfg = Config('plot', 0, False, '/data2/suhongyuan/flp', 'rl-gnn', model_path=model_path)

    # Keep the monitored env under its own name so the factory lambda does
    # not depend on when ``eval_env`` gets rebound.
    monitored_env = Monitor(MULTIPMP(cfg, data_npy, boost))
    eval_env = DummyVecEnv([lambda: monitored_env])
    agent = get_agent(cfg, eval_env, model_path)

    # Only the wall-clock duration is reported; the evaluation result
    # itself is discarded.
    start_time = time.time()
    evaluate(agent, eval_env, 1, return_episode_rewards=True)
    eval_time = time.time() - start_time
    print(f'\t time: {eval_time}')
if __name__ == "__main__":
    # NOTE(review): absl's app.run calls main with the remaining argv list as
    # its single positional argument, so data_npy receives argv here —
    # confirm that this is the intended input when run as a script.
    app.run(main)