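"""Generate a counterfactual (CF) dataset for routing problems.

A CF instance is built by replacing one visit in an existing tour with a
randomly chosen feasible alternative, re-solving the route with a CF solver,
and labeling each edge of the result with a ground-truth classifier.

Usage sketch (the script and dataset file names below are illustrative):
    python cf_dataset_base.py --problem tsptw \
        --base_dataset data/tsptw/base_dataset.pkl --parallel --num_cpus 4
"""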
import random
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
from utils.utils import load_dataset, save_dataset
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask, get_cap_mask
from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator


class CFDatasetBase:
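    """Builds a CF dataset from a base dataset of solved routing instances.

    Each base instance is assumed to be a dict carrying at least "tour"
    (a list of routes) and "coords" entries, loadable via
    `utils.utils.load_dataset`.
    """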
    def __init__(self, problem, cf_generator, classifier, base_dataset, num_samples, random_seed, parallel, num_cpus):
        self.problem = problem
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed
        self.cf_generator = CFTourGenerator(cf_solver=GeneralSolver(problem, cf_generator))
        self.classifier = GroundTruth(problem, classifier)
        self.node_mask = NodeMask(problem)
        self.dataset = load_dataset(base_dataset)
        self.num_samples = len(self.dataset) if num_samples is None else num_samples

    def generate_cf_dataset(self):
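        """Collect `self.num_samples` CF instances.

        Instances for which no feasible CF tour is found are discarded and
        replaced with further instances drawn from the rolled base dataset.
        """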
        random.seed(self.seed)
        cf_dataset = []
        num_required_samples = self.num_samples
        end = False
        print("Data generation started.", flush=True)
        while not end:
            dataset = self.dataset[:num_required_samples]
            # rotate the base dataset so that retries draw fresh instances
            self.dataset = np.roll(self.dataset, -num_required_samples)
            if self.parallel:
                instances = self.generate_labeldata_para(dataset, self.num_cpus)
            else:
                instances = self.generate_labeldata(dataset)
            cf_dataset.extend(filter(None, instances))
            num_required_samples = self.num_samples - len(cf_dataset)
            if num_required_samples == 0:
                end = True
            else:
                print(f"No feasible tour was found for {num_required_samples} instances. Trying another {num_required_samples} instances.", flush=True)
        print("Data generation completed.", flush=True)
        return cf_dataset

    def generate_labeldata(self, dataset):
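        """Annotate instances sequentially."""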
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]

    def generate_labeldata_para(self, dataset, num_cpus):
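        """Annotate instances across `num_cpus` worker processes."""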
        with Pool(num_cpus) as pool:
            annotation_data = list(tqdm(pool.imap(self.annotate, dataset), total=len(dataset), desc="Annotating instances"))
        return annotation_data

    def annotate(self, instance):
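        """Replace one visit in a randomly chosen route with a feasible
        alternative and re-label every edge of the resulting tours.

        Returns the updated instance, or None when no feasible CF tour exists.
        """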
        # generate a counterfactual route randomly
        routes = instance["tour"]
        vehicle_id = random.randint(0, len(routes) - 1)
        if len(routes[vehicle_id]) - 2 <= 2:
            return None  # route is too short to replace an intermediate visit
        cf_step = random.randint(2, len(routes[vehicle_id]) - 2)
        route = routes[vehicle_id]
        mask = self.node_mask.get_mask(route, cf_step, instance)
        node_id = np.arange(len(instance["coords"]))
        feasible_node_id = node_id[mask]
        feasible_node_id = feasible_node_id[feasible_node_id != route[cf_step]].tolist()
        if len(feasible_node_id) == 0:
            return None  # no feasible alternative visit at this step
        cf_visit = random.choice(feasible_node_id)
        cf_routes = self.cf_generator(routes, vehicle_id, cf_step, cf_visit, instance)
        if cf_routes is None:
            return None  # CF solver failed to find a feasible tour

        # annotate each edge
        inputs = self.classifier.get_inputs(cf_routes, 0, instance)
        labels = self.classifier(inputs, annotation=True)

        # update tours and labels
        instance["tour"] = cf_routes
        instance["labels"] = labels
        return instance

class NodeMask:
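    """Dispatches to a problem-specific feasibility mask over candidate nodes."""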
    def __init__(self, problem):
        self.problem = problem

        if self.problem == "tsptw":
            self.mask_func = get_tsptw_mask
        elif self.problem == "pctsp":
            self.mask_func = get_pctsp_mask
        elif self.problem == "pctsptw":
            self.mask_func = get_pctsptw_mask
        elif self.problem == "cvrp":
            self.mask_func = get_cvrp_mask
        else:
            raise NotImplementedError(f"Unknown problem: {self.problem}")

    def get_mask(self, route, step, instance):
        return self.mask_func(route, step, instance)

def get_tsptw_mask(route, step, instance):
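    """Feasible nodes: not yet visited and reachable within their time windows."""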
    visited = get_visited_mask(route, step, instance)
    not_exceed_tw = get_tw_mask(route, step, instance)
    return ~visited & not_exceed_tw

def get_pctsp_mask(route, step, instance):
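    """Feasible nodes: any node not yet visited."""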
    visited = get_visited_mask(route, step, instance)
    return ~visited

def get_pctsptw_mask(route, step, instance):
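    """Feasible nodes: not yet visited and reachable within their time windows."""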
    visited = get_visited_mask(route, step, instance)
    not_exceed_tw = get_tw_mask(route, step, instance)
    return ~visited & not_exceed_tw

def get_cvrp_mask(route, step, instance):
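    """Feasible nodes: not yet visited and within the remaining vehicle capacity."""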
    visited = get_visited_mask(route, step, instance)
    less_than_cap = get_cap_mask(route, step, instance)
    return ~visited & less_than_cap

if __name__ == "__main__":
    import os
    import argparse
    parser = argparse.ArgumentParser(description='Generate a counterfactual dataset from a base routing dataset.')
    parser.add_argument("--problem",      type=str, default="tsptw")
    parser.add_argument("--base_dataset", type=str, required=True)
    parser.add_argument("--cf_generator", type=str, default="ortools")
    parser.add_argument("--classifier",   type=str, default="ortools")
    parser.add_argument("--num_samples",  type=int, default=None)
    parser.add_argument("--random_seed",  type=int, default=1234)
    parser.add_argument("--parallel",     action="store_true")
    parser.add_argument("--num_cpus",     type=int, default=4)
    parser.add_argument("--output_dir",   type=str, default="data")
    args = parser.parse_args()

    dataset_gen = CFDatasetBase(args.problem,
                                args.cf_generator,
                                args.classifier,
                                args.base_dataset,
                                args.num_samples,
                                args.random_seed,
                                args.parallel,
                                args.num_cpus)
    cf_dataset = dataset_gen.generate_cf_dataset()

    output_fname = f"{args.output_dir}/{args.problem}/cf_{dataset_gen.num_samples}samples_seed{args.random_seed}_base_{os.path.basename(args.base_dataset)}.pkl"
    save_dataset(cf_dataset, output_fname)