Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import argparse | |
import json | |
import os | |
from models.classifiers.predictor import DecisionPredictor | |
from models.classifiers.meaningless_models import FixedClassPredictor, RandomPredictor | |
from models.classifiers.rule_based_models import kNearestPredictor | |
from models.classifiers.ground_truth.ground_truth import GroundTruth | |
class GeneralClassifier(nn.Module):
    """Wrap one of several decision predictors behind a single ``nn.Module``.

    The concrete predictor is selected by ``model_type`` (see ``get_model``)
    and can be swapped at runtime with ``change_model``. ``get_inputs`` and
    ``forward`` simply delegate to the wrapped predictor.

    Args:
        problem: Problem identifier forwarded to the ground-truth and kNN
            predictors.
        model_type: One of ``"gnn"``, ``"gt(ortools)"``, ``"gt(lkh)"``,
            ``"gt(concorde)"``, ``"random"``, ``"fixed"`` or ``"knn"``.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """

    # Checkpoint loaded for model_type == "gnn". Its directory must also hold
    # ``cmd_args.dat``, the JSON file with the arguments used to build the model.
    GNN_CHECKPOINT = "checkpoints/model_20230309_101058/model_epoch4.pth"
    # Ground-truth variants: model_type -> GroundTruth ``solver_type``.
    GT_SOLVERS = {
        "gt(ortools)": "ortools",
        "gt(lkh)": "lkh",
        "gt(concorde)": "concorde",
    }
    # Settings for the baseline predictors (override in a subclass if needed).
    NUM_CLASSES = 2
    FIXED_PREDICTED_CLASS = 0
    KNN_K = 5
    KNN_K_TYPE = "num"

    def __init__(self, problem, model_type):
        super().__init__()
        self.model_type = model_type
        self.problem = problem
        self.model = self.get_model(problem, model_type)

    def change_model(self, problem, model_type):
        """Switch to the predictor for ``(problem, model_type)`` if it differs.

        The replacement is built *before* any attribute is updated. If
        construction fails (``ValueError`` for an unknown type, or an I/O
        error while loading a checkpoint), the classifier keeps its previous,
        self-consistent state.
        """
        if self.model_type == model_type and self.problem == problem:
            return
        new_model = self.get_model(problem, model_type)
        self.model_type = model_type
        self.problem = problem
        self.model = new_model

    def get_model(self, problem, model_type):
        """Construct the predictor named by ``model_type``.

        Args:
            problem: Problem identifier, used by the ground-truth and kNN
                predictors.
            model_type: Predictor name (see the class docstring).

        Returns:
            The newly constructed predictor object.

        Raises:
            ValueError: If ``model_type`` is not recognised. (Previously an
                ``assert``, which ``python -O`` strips, silently returning None.)
        """
        if model_type == "gnn":
            # NOTE(review): the GNN is built from the problem stored in its
            # checkpoint args, not from ``problem`` -- confirm this is intended.
            return self._load_gnn(self.GNN_CHECKPOINT)
        if model_type in self.GT_SOLVERS:
            return GroundTruth(problem, solver_type=self.GT_SOLVERS[model_type])
        if model_type == "random":
            return RandomPredictor(num_classes=self.NUM_CLASSES)
        if model_type == "fixed":
            return FixedClassPredictor(
                predicted_class=self.FIXED_PREDICTED_CLASS,
                num_classes=self.NUM_CLASSES,
            )
        if model_type == "knn":
            return kNearestPredictor(problem, self.KNN_K, self.KNN_K_TYPE)
        raise ValueError(f"Invalid model type: {model_type}")

    @staticmethod
    def _load_gnn(model_path):
        """Rebuild a DecisionPredictor from its saved arguments and load weights."""
        model_dir = os.path.dirname(model_path)
        # os.path.join avoids the "/cmd_args.dat" (filesystem root) path the
        # old f-string produced when ``model_path`` had no directory part.
        args_path = os.path.join(model_dir, "cmd_args.dat")
        with open(args_path, "r", encoding="utf-8") as f:
            params = argparse.Namespace(**json.load(f))
        model = DecisionPredictor(
            params.problem,
            params.emb_dim,
            params.num_mlp_layers,
            params.num_classes,
            params.dropout,
        )
        # map_location="cpu" lets a checkpoint saved on GPU load on CPU-only
        # hosts; load_state_dict copies the tensors into the model's own
        # (freshly created) parameters either way.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        return model

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate input preparation to the wrapped predictor."""
        return self.model.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def forward(self, inputs):
        """Run the wrapped predictor on ``inputs``."""
        return self.model(inputs)