route-explainer / models /classifiers /general_classifier.py
daisuke.kikuta
first commit
719d0db
import torch
import torch.nn as nn
import argparse
import json
import os
from models.classifiers.predictor import DecisionPredictor
from models.classifiers.meaningless_models import FixedClassPredictor, RandomPredictor
from models.classifiers.rule_based_models import kNearestPredictor
from models.classifiers.ground_truth.ground_truth import GroundTruth
class GeneralClassifier(nn.Module):
def __init__(self, problem, model_type):
super().__init__()
self.model_type = model_type
self.problem = problem
self.model = self.get_model(problem, model_type)
def change_model(self, problem, model_type):
if self.model_type != model_type or self.problem != problem:
self.model_type = model_type
self.problem = problem
self.model = self.get_model(problem, model_type)
def get_model(self, problem, model_type):
if model_type == "gnn":
model_path = "checkpoints/model_20230309_101058/model_epoch4.pth"
params = argparse.ArgumentParser()
model_dir = os.path.split(model_path)[0]
with open(f"{model_dir}/cmd_args.dat", "r") as f:
params.__dict__ = json.load(f)
model = DecisionPredictor(params.problem,
params.emb_dim,
params.num_mlp_layers,
params.num_classes,
params.dropout)
model.load_state_dict(torch.load(model_path))
return model
elif model_type == "gt(ortools)":
return GroundTruth(problem, solver_type="ortools")
elif model_type == "gt(lkh)":
return GroundTruth(problem, solver_type="lkh")
elif model_type == "gt(concorde)":
return GroundTruth(problem, solver_type="concorde")
elif model_type == "random":
return RandomPredictor(num_classes=2)
elif model_type == "fixed":
predicted_class = 0
return FixedClassPredictor(predicted_class=predicted_class, num_classes=2)
elif model_type == "knn":
k = 5
k_type = "num"
return kNearestPredictor(problem, k, k_type)
else:
assert False, f"Invalid model type: {model_type}"
def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
return self.model.get_inputs(tour, first_explained_step, node_feats, dist_matrix)
def forward(self, inputs):
return self.model(inputs)