# NOTE: removed non-code residue ("Spaces:" / "Running" banner text from a
# web capture) that made this file unparseable.
import argparse
import json
import multiprocessing
import os
import time

import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from tqdm import tqdm

from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.classifiers.ground_truth.ground_truth_base import FAIL_FLAG
from models.classifiers.nn_classifiers.nn_classifier import NNClassifier
from utils.data_utils.cvrp_dataset import CVRPDataloader
from utils.data_utils.pctsp_dataset import PCTSPDataloader
from utils.data_utils.pctsptw_dataset import PCTSPTWDataloader
from utils.data_utils.tsptw_dataset import TSPTWDataloader
from utils.util_calc import TemporalConfusionMatrix
from utils.utils import load_dataset, set_device
def load_eval_dataset(dataset_path, problem, model_type, batch_size, num_workers, parallel, num_cpus):
    """Load an evaluation dataset for the requested problem and model type.

    For ``model_type == "nn"`` this returns a ``DataLoader`` over the
    problem-specific sequential dataset, padded per batch; for any other
    model type it returns the raw deserialized dataset.

    Raises:
        NotImplementedError: if ``problem`` has no dataloader (nn path only).
    """
    if model_type != "nn":
        # Non-NN models (e.g. ground truth) consume the raw instances directly.
        return load_dataset(dataset_path)

    dataloader_classes = {
        "tsptw": TSPTWDataloader,
        "pctsp": PCTSPDataloader,
        "pctsptw": PCTSPTWDataloader,
        "cvrp": CVRPDataloader,
    }
    if problem not in dataloader_classes:
        raise NotImplementedError
    eval_dataset = dataloader_classes[problem](
        dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)

    #------------
    # dataloader
    #------------
    def pad_seq_length(batch):
        """Collate fn: post-pad each variable-length field to the batch max.

        Also adds a boolean "pad_mask" marking real (non-padded) positions.
        """
        collated = {}
        for key in batch[0].keys():
            # "mask" is boolean, so pad with True; everything else with 0.0.
            fill = True if key == "mask" else 0.0
            collated[key] = torch.nn.utils.rnn.pad_sequence(
                [sample[key] for sample in batch], batch_first=True, padding_value=fill)
        collated["pad_mask"] = torch.nn.utils.rnn.pad_sequence(
            [torch.full((sample["mask"].size(0), ), True) for sample in batch],
            batch_first=True, padding_value=False)
        return collated

    return DataLoader(eval_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=pad_seq_length,
                      num_workers=num_workers)
def eval_classifier(problem: str,
                    dataset,
                    model_type: str,
                    model_dir: str = None,
                    gpu: int = -1,
                    num_workers: int = 4,
                    batch_size: int = 128,
                    parallel: bool = True,
                    solver: str = "ortools",
                    num_cpus: int = 1):
    """Evaluate a classifier on an evaluation dataset.

    Args:
        problem: problem name ("tsptw", "pctsp", "pctsptw", or "cvrp").
        dataset: a DataLoader when model_type == "nn", otherwise the raw
            instance list produced by ``load_eval_dataset``.
        model_type: "nn" (trained neural classifier) or "ground_truth".
        model_dir: directory holding ``cmd_args.dat``, ``best_epoch.dat`` and
            ``model_epoch<N>.pth``; required when model_type == "nn".
        gpu: GPU id; -1 selects CPU.
        num_workers: unused here (the dataloader is built by the caller);
            kept for interface compatibility.
        batch_size: unused here; kept for interface compatibility.
        parallel: unused here; kept for interface compatibility.
        solver: solver backend name for the ground-truth model.
        num_cpus: worker-process count for ground-truth inference.

    Returns:
        model_type == "nn":
            (overall_f1 tensor, {seq_length: f1}, {seq_length: [f1 per step]},
             calc_time, {seq_length: confusion matrix}, {seq_length: #samples})
        model_type == "ground_truth":
            (overall_f1 float, calc_time)

    Raises:
        ValueError: invalid model_type, missing model_dir, or a trained
            model whose problem does not match the dataset's.
    """
    #--------------
    # gpu settings
    #--------------
    use_cuda, device = set_device(gpu)

    #-------
    # model
    #-------
    # pctsptw distinguishes three actions; all other problems are binary.
    num_classes = 3 if problem == "pctsptw" else 2
    if model_type == "nn":
        # raise (not assert): assertions are stripped under `python -O`.
        if model_dir is None:
            raise ValueError("please specify model_path when model_type is nn.")
        # Restore training-time hyper-parameters; a plain Namespace is the
        # correct attribute container (the original abused ArgumentParser).
        params = argparse.Namespace()
        with open(f"{model_dir}/cmd_args.dat", "r") as f:
            params.__dict__ = json.load(f)
        if params.problem != problem:
            raise ValueError("problem of the trained model should match that of the dataset")
        model = NNClassifier(problem=params.problem,
                             node_enc_type=params.node_enc_type,
                             edge_enc_type=params.edge_enc_type,
                             dec_type=params.dec_type,
                             emb_dim=params.emb_dim,
                             num_enc_mlp_layers=params.num_enc_mlp_layers,
                             num_dec_mlp_layers=params.num_dec_mlp_layers,
                             num_classes=num_classes,
                             dropout=params.dropout,
                             pos_encoder=params.pos_encoder)
        # load trained weights (the best epoch); map_location makes a
        # GPU-trained checkpoint loadable on a CPU-only machine.
        with open(f"{model_dir}/best_epoch.dat", "r") as f:
            best_epoch = int(f.read())
        print(f"loaded {model_dir}/model_epoch{best_epoch}.pth.")
        model.load_state_dict(torch.load(f"{model_dir}/model_epoch{best_epoch}.pth", map_location=device))
        if use_cuda:
            model.to(device)
        is_sequential = model.is_sequential
    elif model_type == "ground_truth":
        model = GroundTruth(problem=problem, solver_type=solver)
        is_sequential = False
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    #---------
    # Metrics
    #---------
    # NOTE: despite the "accuracy" names, every scalar metric here is macro F1.
    overall_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
    eval_accuracy_dict = {}      # seq_length -> MulticlassF1Score over that length's samples
    temp_confmat_dict = {}       # seq_length -> TemporalConfusionMatrix
    temporal_accuracy_dict = {}  # seq_length -> [MulticlassF1Score per time step]
    num_nodes_dist_dict = {}     # seq_length -> number of samples of that length

    #------------
    # Evaluation
    #------------
    if model_type == "nn":
        model.eval()
        eval_time = 0.0
        print("Evaluating models ...", end="")
        start_time = time.perf_counter()
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            for data in dataset:
                if use_cuda:
                    data = {key: value.to(device) for key, value in data.items()}
                if not is_sequential:
                    # Non-sequential models score each step independently:
                    # fold (batch, seq) into a single batch dimension.
                    shp = data["curr_node_id"].size()
                    data = {key: value.flatten(0, 1) for key, value in data.items()}
                probs = model(data)  # [batch_size x num_classes] or [batch_size x max_seq_length x num_classes]
                if not is_sequential:
                    probs = probs.view(*shp, -1)  # [batch_size x max_seq_length x num_classes]
                    data["labels"] = data["labels"].view(*shp)
                    data["pad_mask"] = data["pad_mask"].view(*shp)
                #------------
                # evaluation
                #------------
                start_eval_time = time.perf_counter()
                seq_lengths = data["pad_mask"].sum(-1)  # true (unpadded) length per sample
                # BUGFIX: the overall metric used to be updated inside the
                # per-length loop below, so each batch was counted once per
                # *unique* sequence length it contained. Update it once here.
                mask = data["pad_mask"].view(-1)
                overall_accuracy(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
                for seq_length_tensor in torch.unique(seq_lengths):
                    seq_length = seq_length_tensor.item()
                    if seq_length not in eval_accuracy_dict:
                        # first time we see this sequence length: create metrics
                        eval_accuracy_dict[seq_length] = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
                        temp_confmat_dict[seq_length] = TemporalConfusionMatrix(num_classes=num_classes, seq_length=seq_length, device=device)
                        temporal_accuracy_dict[seq_length] = [MulticlassF1Score(num_classes=num_classes, average="macro").to(device) for _ in range(seq_length)]
                        num_nodes_dist_dict[seq_length] = 0
                    # samples in this batch whose true length is seq_length
                    seq_length_mask = (seq_lengths == seq_length)  # [batch_size]
                    extracted_labels = data["labels"][seq_length_mask]
                    extracted_probs = probs[seq_length_mask]
                    extracted_mask = data["pad_mask"][seq_length_mask].view(-1)  # [(num_extracted*max_seq_length)]
                    eval_accuracy_dict[seq_length](extracted_probs.argmax(-1).view(-1)[extracted_mask], extracted_labels.view(-1)[extracted_mask])
                    # confusion matrix
                    # NOTE(review): this feeds the *full* batch (all lengths)
                    # into the per-length confusion matrix; presumably
                    # TemporalConfusionMatrix filters via pad_mask — confirm.
                    temp_confmat_dict[seq_length].update(probs.argmax(-1), data["labels"], data["pad_mask"])
                    # temporal accuracy: one metric per time step
                    for step in range(seq_length):
                        temporal_accuracy_dict[seq_length][step](extracted_probs[:, step, :], extracted_labels[:, step])
                    # number of samples whose sequence length is seq_length
                    num_nodes_dist_dict[seq_length] += len(extracted_labels)
                eval_time += time.perf_counter() - start_eval_time
        # model-inference time = wall time minus metric-computation time
        calc_time = time.perf_counter() - start_time - eval_time
        total_eval_accuracy = {key: value.compute().item() for key, value in eval_accuracy_dict.items()}
        overall_accuracy = overall_accuracy.compute()  # returned as a tensor (callers may .item() it)
        temporal_confmat = {key: value.compute() for key, value in temp_confmat_dict.items()}
        temporal_accuracy = {key: [value.compute().item() for value in values] for key, values in temporal_accuracy_dict.items()}
        print("done")
        return overall_accuracy, total_eval_accuracy, temporal_accuracy, calc_time, temporal_confmat, num_nodes_dist_dict
    else:
        eval_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
        print("Loading data ...", end=" ")
        with multiprocessing.Pool(num_cpus) as pool:
            input_list = list(pool.starmap(model.get_inputs, [(instance["tour"], 0, instance) for instance in dataset]))
        print("done")
        print("Infering labels ...", end="")
        # BUGFIX: the second pool was created without a context manager and
        # leaked if starmap raised; `with` guarantees cleanup either way.
        with multiprocessing.Pool(num_cpus) as pool:
            start_time = time.perf_counter()
            prob_list = list(pool.starmap(model, tqdm([(inputs, False, False) for inputs in input_list])))
            calc_time = time.perf_counter() - start_time
        print("done")
        print("Evaluating models ...", end="")
        for i, instance in enumerate(dataset):
            labels = instance["labels"]
            for vehicle_id in range(len(labels)):
                for step, label in labels[vehicle_id]:
                    pred_label = prob_list[i][vehicle_id][step-1]  # [num_classes]
                    if pred_label == FAIL_FLAG:
                        # solver failed on this step: substitute a guaranteed
                        # wrong class so the failure is penalized by the metric
                        pred_label = label - 1 if label != 0 else label + 1
                    eval_accuracy(torch.LongTensor([pred_label]).view(1, -1), torch.LongTensor([label]).view(1, -1))
        total_eval_accuracy = eval_accuracy.compute()
        print("done")
        return total_eval_accuracy.item(), calc_time
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    #-----------------
    # general settings
    #-----------------
    parser.add_argument("--gpu", default=-1, type=int, help="Used GPU Number: gpu=-1 indicates using cpu")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of workers in dataloader")
    # BUGFIX: an incomplete, duplicate `--parallel` was registered here;
    # argparse raises ArgumentError when the same option string is added
    # twice, crashing the script at startup. The real `--parallel` flag is
    # defined under "model settings" below.
    #-------------
    # data setting
    #-------------
    parser.add_argument("--dataset_path", type=str, help="Path to a dataset", required=True)
    #----------------
    # model settings
    #----------------
    parser.add_argument("--model_type", type=str, default="nn", help="Select from [nn, ground_truth]")
    # nn classifier
    parser.add_argument("--model_dir", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--parallel", action="store_true")
    # ground truth
    parser.add_argument("--solver", type=str, default="ortools")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    args = parser.parse_args()

    # The problem name is encoded as the dataset file's parent directory name.
    problem = str(os.path.basename(os.path.dirname(args.dataset_path)))
    dataset = load_eval_dataset(args.dataset_path, problem, args.model_type, args.batch_size, args.num_workers, args.parallel, args.num_cpus)
    eval_classifier(problem=problem,
                    dataset=dataset,
                    model_type=args.model_type,
                    model_dir=args.model_dir,
                    gpu=args.gpu,
                    num_workers=args.num_workers,
                    batch_size=args.batch_size,
                    parallel=args.parallel,
                    solver=args.solver,
                    num_cpus=args.num_cpus)