# MFLP / facility_location / multi_eval.py
# Author: 苏泓源 (Su Hongyuan)
# Multi-instance evaluation entry point for the facility-location (p-median) RL agent.
import os
import pickle
import setproctitle
from absl import app, flags
import time
import random
from typing import Tuple, Union, Text
import numpy as np
import torch as th
import sys
import gymnasium
# Shim: stable-baselines3 imports "gym"; alias it to gymnasium so SB3 uses the
# maintained API. Must run BEFORE any stable_baselines3 import below.
sys.modules["gym"] = gymnasium
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnvWrapper
from facility_location.agent.solver import PMPSolver
from facility_location.env import EvalPMPEnv, MULTIPMP
from facility_location.utils import Config
from facility_location.agent import MaskedFacilityLocationActorCriticPolicy
from facility_location.utils.policy import get_policy_kwargs
import warnings
# Silence all warnings globally (deliberate: SB3/gym emit many deprecation warnings).
warnings.filterwarnings('ignore')
# NOTE(review): os, pickle, setproctitle, and flags appear unused in this module —
# possibly needed by other chunks of the file; verify before removing.
# Type alias for any agent this script can evaluate (exact solver or trained PPO).
AGENT = Union[PMPSolver, PPO]
def get_model(cfg: Config,
              env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
              device: str) -> PPO:
    """Construct a fresh PPO model over *env* using the project's masked
    facility-location policy, configured from *cfg*.

    The returned model is untrained; callers load weights into it separately.
    """
    return PPO(
        MaskedFacilityLocationActorCriticPolicy,
        env,
        verbose=1,
        policy_kwargs=get_policy_kwargs(cfg),
        device=device,
    )
def get_agent(cfg: Config,
              env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
              model_path: Text) -> AGENT:
    """Build the evaluation agent named by ``cfg.agent`` and load its weights.

    For the RL agents, a fresh PPO model is built over *env* and the parameters
    of the trained model at *model_path* are copied into it (so the evaluation
    env/policy wiring comes from ``get_model`` rather than the checkpoint).

    Raises:
        ValueError: if ``cfg.agent`` is not one of the supported RL variants.
    """
    if cfg.agent in ['rl-mlp', 'rl-gnn', 'rl-agnn']:
        # Fix: the original hard-coded 'cuda:0', which crashes on CPU-only
        # hosts. Fall back to CPU when CUDA is unavailable.
        device = 'cuda:0' if th.cuda.is_available() else 'cpu'
        test_model = get_model(cfg, env, device=device)
        # Load the checkpoint onto the same device to avoid a needless
        # GPU allocation (or failure) during deserialization.
        trained_model = PPO.load(model_path, device=device)
        test_model.set_parameters(trained_model.get_parameters())
        agent = test_model
    else:
        raise ValueError(f'Agent {cfg.agent} not supported.')
    return agent
def evaluate(agent: AGENT,
             env: Union[VecEnvWrapper, DummyVecEnv, EvalPMPEnv],
             num_cases: int,
             return_episode_rewards: bool):
    """Dispatch evaluation to the routine matching the agent's type.

    Currently only PPO agents are supported; anything else raises ValueError.
    """
    # Guard-clause form: reject unsupported agents up front.
    if not isinstance(agent, PPO):
        raise ValueError(f'Agent {agent} not supported.')
    return evaluate_ppo(agent, env, num_cases,
                        return_episode_rewards=return_episode_rewards)
# NOTE(review): BaseCallback is imported mid-file and never used here —
# verify against the rest of the file before removing.
from stable_baselines3.common.callbacks import BaseCallback
def evaluate_ppo(agent: PPO, env: EvalPMPEnv, num_cases: int, return_episode_rewards: bool) -> Union[float, list]:
    """Evaluate *agent* on *env* for *num_cases* episodes via SB3's evaluate_policy.

    Returns only the first element of evaluate_policy's pair: the mean reward
    (float) when ``return_episode_rewards`` is False, or the per-episode reward
    list when it is True. (The original ``Tuple[float, float]`` annotation was
    incorrect — the second element, std/lengths, is discarded.)
    """
    rewards, _ = evaluate_policy(agent, env, n_eval_episodes=num_cases, return_episode_rewards=return_episode_rewards)
    return rewards
def main(data_npy, boost=False):
    """Evaluate the bundled best PPO model on the problem data in *data_npy*.

    Seeds torch/numpy/random for reproducibility, wraps the MULTIPMP env in
    Monitor + DummyVecEnv, runs a single evaluation episode, and prints the
    wall-clock evaluation time.
    """
    # Deterministic evaluation: seed every RNG source with 0.
    th.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    model_path = './facility_location/best_model.zip'
    # NOTE(review): the data root '/data2/suhongyuan/flp' is machine-specific —
    # confirm it is required (or ignored) by Config in evaluation mode.
    cfg = Config('plot', 0, False, '/data2/suhongyuan/flp', 'rl-gnn',
                 model_path=model_path)

    # Env stack expected by SB3: raw env -> Monitor -> single-env DummyVecEnv.
    eval_env = DummyVecEnv([lambda: Monitor(MULTIPMP(cfg, data_npy, boost))])

    agent = get_agent(cfg, eval_env, model_path)

    start_time = time.time()
    _ = evaluate(agent, eval_env, 1, return_episode_rewards=True)
    eval_time = time.time() - start_time
    print(f'\t time: {eval_time}')
if __name__ == '__main__':
    # NOTE(review): absl's app.run invokes main(argv) — so when run as a
    # script, data_npy receives the parsed argv list, not a data path.
    # This looks intended only for import-time use of main(); confirm.
    app.run(main)