File size: 2,586 Bytes
719d0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
import argparse
import json
import os
from models.classifiers.predictor import DecisionPredictor
from models.classifiers.meaningless_models import FixedClassPredictor, RandomPredictor
from models.classifiers.rule_based_models import kNearestPredictor
from models.classifiers.ground_truth.ground_truth import GroundTruth

class GeneralClassifier(nn.Module):
    """Build one of several decision classifiers by name and delegate calls to it.

    Supported ``model_type`` values (see :meth:`get_model`): ``"gnn"``,
    ``"gt(ortools)"``, ``"gt(lkh)"``, ``"gt(concorde)"``, ``"random"``,
    ``"fixed"`` and ``"knn"``.

    Attributes:
        problem: Problem identifier forwarded to the underlying model constructors.
        model_type: Name of the currently loaded model type.
        model: The wrapped model; ``forward`` and ``get_inputs`` delegate to it.
    """

    # Checkpoint loaded for model_type == "gnn". Its directory must also contain
    # ``cmd_args.dat`` (JSON) with the DecisionPredictor constructor arguments.
    # Override on a subclass or instance to load a different checkpoint.
    GNN_MODEL_PATH = "checkpoints/model_20230309_101058/model_epoch4.pth"

    # Ground-truth model types -> ``solver_type`` passed to GroundTruth.
    _GT_SOLVERS = {
        "gt(ortools)": "ortools",
        "gt(lkh)": "lkh",
        "gt(concorde)": "concorde",
    }

    def __init__(self, problem, model_type):
        super().__init__()
        self.model_type = model_type
        self.problem = problem
        self.model = self.get_model(problem, model_type)

    def change_model(self, problem, model_type):
        """Replace the wrapped model if ``problem`` or ``model_type`` differs.

        The new model is built before any attribute is touched, so if
        :meth:`get_model` raises (e.g. unknown ``model_type``) the classifier
        keeps its previous, self-consistent state.
        """
        if self.model_type != model_type or self.problem != problem:
            new_model = self.get_model(problem, model_type)
            # nn.Module.__setattr__ raises TypeError when a registered child
            # module is overwritten by a non-Module value; deleting the
            # attribute first lets us switch between Module and non-Module
            # models in either direction.
            del self.model
            self.model = new_model
            self.model_type = model_type
            self.problem = problem

    def get_model(self, problem, model_type):
        """Construct the model identified by ``model_type``.

        Args:
            problem: Problem identifier forwarded to GroundTruth and
                kNearestPredictor. NOTE(review): the "gnn" branch ignores it
                and uses the ``problem`` stored in the checkpoint's
                cmd_args.dat instead — confirm this is intended.
            model_type: One of "gnn", "gt(ortools)", "gt(lkh)",
                "gt(concorde)", "random", "fixed", "knn".

        Returns:
            The newly constructed model.

        Raises:
            ValueError: If ``model_type`` is not recognised.
        """
        if model_type == "gnn":
            return self._load_gnn(self.GNN_MODEL_PATH)
        if model_type in self._GT_SOLVERS:
            return GroundTruth(problem, solver_type=self._GT_SOLVERS[model_type])
        if model_type == "random":
            return RandomPredictor(num_classes=2)
        if model_type == "fixed":
            predicted_class = 0
            return FixedClassPredictor(predicted_class=predicted_class, num_classes=2)
        if model_type == "knn":
            k = 5
            k_type = "num"
            return kNearestPredictor(problem, k, k_type)
        # Explicit raise instead of `assert False`: asserts vanish under -O.
        raise ValueError(f"Invalid model type: {model_type}")

    @staticmethod
    def _load_gnn(model_path):
        """Rebuild a DecisionPredictor from its saved cmd args and load its weights."""
        model_dir = os.path.dirname(model_path)
        args_path = os.path.join(model_dir, "cmd_args.dat")
        with open(args_path, "r", encoding="utf-8") as f:
            params = argparse.Namespace(**json.load(f))
        model = DecisionPredictor(params.problem,
                                  params.emb_dim,
                                  params.num_mlp_layers,
                                  params.num_classes,
                                  params.dropout)
        # Deserialize onto the CPU so a GPU-saved checkpoint also loads on
        # CPU-only hosts; load_state_dict copies the values into the model's
        # own parameters wherever they live.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate input construction to the wrapped model's ``get_inputs``."""
        return self.model.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def forward(self, inputs):
        """Run the wrapped model on ``inputs`` and return its output."""
        return self.model(inputs)