# Spaces: Running
# File size: 6,216 Bytes
# 719d0db
import random
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
from utils.utils import load_dataset, save_dataset
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask, get_cap_mask
from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator
class CFDatasetBase:
    """Generate a dataset of counterfactual (CF) tours with per-edge labels.

    Each base instance is perturbed by forcing one randomly chosen step of one
    randomly chosen route to visit a different node, the remaining tour is
    produced by ``CFTourGenerator``, and the result is labelled by
    ``GroundTruth``.
    """

    def __init__(self, problem, cf_generator, classifier, base_dataset, num_samples, random_seed, parallel, num_cpus):
        """Set up the generator, the classifier, the node mask and the data.

        Args:
            problem: Problem name; forwarded to ``GeneralSolver``,
                ``GroundTruth`` and ``NodeMask``.
            cf_generator: Solver name forwarded to ``GeneralSolver``.
            classifier: Classifier name forwarded to ``GroundTruth``.
            base_dataset: Path handed to ``load_dataset``.
            num_samples: Number of CF samples to produce; ``None`` means one
                per instance of the loaded dataset.
            random_seed: Seed applied to ``random`` when generation starts.
            parallel: If true, annotate with a ``multiprocessing.Pool``.
            num_cpus: Number of worker processes used when ``parallel``.
        """
        self.problem = problem
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed
        self.cf_generator = CFTourGenerator(cf_solver=GeneralSolver(problem, cf_generator))
        self.classifier = GroundTruth(problem, classifier)
        self.node_mask = NodeMask(problem)
        self.dataset = load_dataset(base_dataset)
        self.num_samples = len(self.dataset) if num_samples is None else num_samples

    def generate_cf_dataset(self):
        """Annotate base instances until ``num_samples`` CF samples exist.

        Instances for which ``annotate`` returns nothing are dropped and the
        shortfall is made up from the next instances of the (rotated) base
        dataset.

        Returns:
            list: The successfully annotated instances.

        Raises:
            ValueError: If samples are requested but the base dataset is
                empty (the retry loop could never make progress).
        """
        random.seed(self.seed)
        cf_dataset = []
        num_required_samples = self.num_samples
        if num_required_samples > 0 and len(self.dataset) == 0:
            raise ValueError("Cannot generate counterfactual samples from an empty base dataset.")
        print("Data generation started.", flush=True)
        while num_required_samples > 0:
            dataset = self.dataset[:num_required_samples]
            # Rotate so the next round starts right after the instances just taken.
            # NOTE(review): the rotation wraps around, so once the base dataset is
            # exhausted earlier instances are handed out again; in sequential mode
            # those may be the very objects already stored in cf_dataset -- confirm
            # this reuse is intended.
            self.dataset = np.roll(self.dataset, -num_required_samples)
            if self.parallel:
                instances = self.generate_labeldata_para(dataset, self.num_cpus)
            else:
                instances = self.generate_labeldata(dataset)
            # annotate() yields None for instances it could not perturb.
            cf_dataset.extend(filter(None, instances))
            num_required_samples = self.num_samples - len(cf_dataset)
            if num_required_samples > 0:
                print(f"No feasible tour was found in {num_required_samples} instances. Trying other {num_required_samples} instances.", flush=True)
        print("Data generation completed.", flush=True)
        return cf_dataset

    def generate_labeldata(self, dataset):
        """Annotate ``dataset`` sequentially; failed instances appear as ``None``."""
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]

    def generate_labeldata_para(self, dataset, num_cpus):
        """Annotate ``dataset`` with ``num_cpus`` worker processes, keeping input order."""
        # NOTE(review): workers may inherit the parent's `random` state depending on
        # the process start method, which would correlate their draws -- confirm
        # whether per-worker seeding is wanted.
        with Pool(num_cpus) as pool:
            annotation_data = list(tqdm(pool.imap(self.annotate, dataset), total=len(dataset), desc="Annotating instances"))
        return annotation_data

    def annotate(self, instance):
        """Turn ``instance`` into a labelled counterfactual sample.

        Args:
            instance: Mapping that provides at least ``"tour"`` (a list of
                routes) and ``"coords"``.

        Returns:
            The same ``instance`` object with ``"tour"`` replaced by the CF
            routes and ``"labels"`` added, or ``None`` when the chosen route
            is too short, no alternative node passes the mask, or the CF
            generator returns ``None``.
        """
        # generate a counterfactual route randomly
        routes = instance["tour"]
        vehicle_id = random.randint(0, len(routes) - 1)
        route = routes[vehicle_id]
        if len(route) - 2 <= 2:
            return None
        cf_step = random.randint(2, len(route) - 2)
        mask = self.node_mask.get_mask(route, cf_step, instance)
        node_id = np.arange(len(instance["coords"]))
        feasible_node_id = node_id[mask]
        # The node actually visited at cf_step is not a counterfactual choice.
        feasible_node_id = feasible_node_id[feasible_node_id != route[cf_step]].tolist()
        if not feasible_node_id:
            return None
        cf_visit = random.choice(feasible_node_id)
        cf_routes = self.cf_generator(routes, vehicle_id, cf_step, cf_visit, instance)
        if cf_routes is None:
            return None

        # annotate each edge
        inputs = self.classifier.get_inputs(cf_routes, 0, instance)
        labels = self.classifier(inputs, annotation=True)

        # update tours and labels
        instance["tour"] = cf_routes
        instance["labels"] = labels
        return instance
class NodeMask:
    """Select and apply the node-mask function that belongs to a problem."""

    def __init__(self, problem):
        """Bind the mask function for ``problem``.

        Args:
            problem: One of ``"tsptw"``, ``"pctsp"``, ``"pctsptw"`` or
                ``"cvrp"``.

        Raises:
            NotImplementedError: If ``problem`` is not one of the above.
        """
        self.problem = problem
        if self.problem == "tsptw":
            self.mask_func = get_tsptw_mask
        elif self.problem == "pctsp":
            self.mask_func = get_pctsp_mask
        elif self.problem == "pctsptw":
            self.mask_func = get_pctsptw_mask
        elif self.problem == "cvrp":
            self.mask_func = get_cvrp_mask
        else:
            # Fail here rather than later with an AttributeError on mask_func.
            raise NotImplementedError(f"Unsupported problem: {problem!r}")

    def get_mask(self, route, step, instance):
        """Return the result of the bound mask function for the given arguments."""
        return self.mask_func(route, step, instance)
def get_tsptw_mask(route, step, instance):
    """Build the node mask used for TSPTW.

    The result is the inverse of ``get_visited_mask`` combined with
    ``get_tw_mask`` via ``&``.
    """
    already_visited = get_visited_mask(route, step, instance)
    within_tw = get_tw_mask(route, step, instance)
    unvisited = ~already_visited
    return unvisited & within_tw
def get_pctsp_mask(route, step, instance):
    """Build the node mask used for PCTSP: the inverse of ``get_visited_mask``."""
    already_visited = get_visited_mask(route, step, instance)
    unvisited = ~already_visited
    return unvisited
def get_pctsptw_mask(route, step, instance):
    """Build the node mask used for PCTSPTW.

    The result is the inverse of ``get_visited_mask`` combined with
    ``get_tw_mask`` via ``&``.
    """
    already_visited = get_visited_mask(route, step, instance)
    within_tw = get_tw_mask(route, step, instance)
    unvisited = ~already_visited
    return unvisited & within_tw
def get_cvrp_mask(route, step, instance):
    """Build the node mask used for CVRP.

    The result is the inverse of ``get_visited_mask`` combined with
    ``get_cap_mask`` via ``&``.
    """
    already_visited = get_visited_mask(route, step, instance)
    within_cap = get_cap_mask(route, step, instance)
    unvisited = ~already_visited
    return unvisited & within_cap
if __name__ == "__main__":
    # Command-line entry point: build a CF dataset from a base dataset and save it.
    import os
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--base_dataset", type=str, required=True)
    parser.add_argument("--cf_generator", type=str, default="ortools")
    parser.add_argument("--classifier", type=str, default="ortools")
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=4)
    parser.add_argument("--output_dir", type=str, default="data")
    args = parser.parse_args()

    dataset_gen = CFDatasetBase(args.problem,
                                args.cf_generator,
                                args.classifier,
                                args.base_dataset,
                                args.num_samples,
                                args.random_seed,
                                args.parallel,
                                args.num_cpus)
    cf_dataset = dataset_gen.generate_cf_dataset()

    output_fname = f"{args.output_dir}/{args.problem}/cf_{dataset_gen.num_samples}samples_seed{args.random_seed}_base_{os.path.basename(args.base_dataset)}.pkl"
    # Make sure the target directory exists so a long generation run is not
    # lost at save time; a no-op when the directory is already there.
    os.makedirs(os.path.dirname(output_fname), exist_ok=True)
    save_dataset(cf_dataset, output_fname)