import random
from multiprocessing import Pool, cpu_count

import numpy as np
from tqdm import tqdm

from utils.utils import load_dataset, save_dataset
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask, get_cap_mask
from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator


class CFDatasetBase():
    """Generate a counterfactual (CF) dataset from a base dataset of solved instances.

    For each instance, one visit of one randomly chosen vehicle route is replaced
    by a randomly chosen feasible node. ``CFTourGenerator`` builds CF routes around
    that change, and ``GroundTruth`` labels the resulting routes.
    """

    def __init__(self, problem, cf_generator, classifier, base_dataset, num_samples, random_seed, parallel, num_cpus):
        """
        Args:
            problem: problem name; selects the node-mask function (see ``NodeMask``).
            cf_generator: passed to ``GeneralSolver`` to build CF tours
                (presumably a solver name such as "ortools").
            classifier: passed to ``GroundTruth`` to label tours.
            base_dataset: path handed to ``load_dataset``.
            num_samples: number of CF instances to produce; ``None`` means the
                size of the base dataset.
            random_seed: seed for Python's ``random`` module.
            parallel: annotate instances with a ``multiprocessing.Pool`` when True.
            num_cpus: number of pool workers (``None``/0 means ``cpu_count()``).
        """
        self.problem = problem
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed
        self.cf_generator = CFTourGenerator(cf_solver=GeneralSolver(problem, cf_generator))
        self.classifier = GroundTruth(problem, classifier)
        self.node_mask = NodeMask(problem)
        self.dataset = load_dataset(base_dataset)
        self.num_samples = len(self.dataset) if num_samples is None else num_samples

    def generate_cf_dataset(self):
        """Produce ``self.num_samples`` counterfactual instances.

        Each round takes the front of the base dataset. The dataset is then rotated
        by the number of instances just consumed, so a retry round moves on to
        other instances. Rounds repeat until enough annotations have succeeded.

        Returns:
            list of annotated instances (see ``annotate``).

        Raises:
            ValueError: if samples are requested but the base dataset is empty
                (the retry loop could otherwise never terminate).
        """
        random.seed(self.seed)
        cf_dataset = []
        num_required_samples = self.num_samples
        if num_required_samples > 0 and len(self.dataset) == 0:
            raise ValueError("Base dataset is empty; cannot generate counterfactual samples.")
        print("Data generation started.", flush=True)
        # NOTE(review): this still loops forever if annotate() can never succeed
        # on any instance of the base dataset.
        while num_required_samples > 0:
            dataset = self.dataset[:num_required_samples]
            # np.roll also converts a list dataset into a 1-D object array.
            self.dataset = np.roll(self.dataset, -num_required_samples)
            if self.parallel:
                instances = self.generate_labeldata_para(dataset, self.num_cpus)
            else:
                instances = self.generate_labeldata(dataset)
            # annotate() returns None for instances where no CF could be built.
            cf_dataset.extend(filter(None, instances))
            num_required_samples = self.num_samples - len(cf_dataset)
            if num_required_samples > 0:
                print(f"No feasible tour was found in {num_required_samples} instances. Trying other {num_required_samples} instances.", flush=True)
        print("Data generation completed.", flush=True)
        return cf_dataset

    def generate_labeldata(self, dataset):
        """Annotate ``dataset`` sequentially; entries are None where annotation failed."""
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]

    def generate_labeldata_para(self, dataset, num_cpus):
        """Annotate ``dataset`` with a process pool; entries are None where annotation failed.

        Each instance gets its own seed drawn from the parent's (seeded) RNG.
        Results therefore do not depend on how the pool schedules tasks, nor on
        whether workers are forked (identical inherited RNG state) or spawned
        (unseeded RNG).
        """
        tasks = [(instance, random.randrange(2 ** 32)) for instance in dataset]
        num_workers = num_cpus if num_cpus else cpu_count()
        # imap pickles the callable (a bound method, hence all of `self`,
        # including the base dataset) once per task batch; batching bounds that cost.
        chunksize = max(1, len(tasks) // (num_workers * 4))
        with Pool(num_workers) as pool:
            annotation_data = list(tqdm(pool.imap(self._annotate_seeded, tasks, chunksize=chunksize),
                                        total=len(tasks), desc="Annotating instances"))
        return annotation_data

    def _annotate_seeded(self, task):
        """Reseed this worker's RNG for one ``(instance, seed)`` task, then annotate."""
        instance, seed = task
        random.seed(seed)
        return self.annotate(instance)

    def annotate(self, instance):
        """Build and label one counterfactual version of ``instance``.

        Args:
            instance: mapping with at least "tour" (a sequence of per-vehicle
                routes) and "coords"; the whole instance is also passed to the
                mask, CF-generator and classifier helpers.

        Returns:
            A shallow copy of ``instance`` with "tour" replaced by the CF routes and
            a "labels" entry added, or None when no CF could be built. The input
            instance itself is not modified.
        """
        routes = instance["tour"]
        if len(routes) == 0:
            return None
        # Pick the route and the step whose visit will be changed.
        vehicle_id = random.randint(0, len(routes) - 1)
        route = routes[vehicle_id]
        # randint(2, len(route) - 2) needs len(route) >= 4; routes of length
        # exactly 4 are skipped as well (existing behavior kept).
        if len(route) - 2 <= 2:
            return None
        cf_step = random.randint(2, len(route) - 2)

        # Candidate nodes: allowed by the problem's mask at cf_step and different
        # from the node actually visited there.
        mask = self.node_mask.get_mask(route, cf_step, instance)
        node_id = np.arange(len(instance["coords"]))
        feasible_node_id = node_id[mask]
        feasible_node_id = feasible_node_id[feasible_node_id != route[cf_step]].tolist()
        if len(feasible_node_id) == 0:
            return None
        cf_visit = random.choice(feasible_node_id)
        cf_routes = self.cf_generator(routes, vehicle_id, cf_step, cf_visit, instance)
        if cf_routes is None:
            return None

        # Label the counterfactual routes.
        inputs = self.classifier.get_inputs(cf_routes, 0, instance)
        labels = self.classifier(inputs, annotation=True)

        # Return a copy so the base dataset entry is not overwritten.
        annotated = dict(instance)
        annotated["tour"] = cf_routes
        annotated["labels"] = labels
        return annotated


class NodeMask():
    """Dispatch to the node-feasibility mask function for a given problem."""

    def __init__(self, problem):
        """
        Args:
            problem: one of "tsptw", "pctsp", "pctsptw", "cvrp".

        Raises:
            NotImplementedError: if ``problem`` has no mask function.
        """
        self.problem = problem
        mask_funcs = {
            "tsptw": get_tsptw_mask,
            "pctsp": get_pctsp_mask,
            "pctsptw": get_pctsptw_mask,
            "cvrp": get_cvrp_mask,
        }
        if problem not in mask_funcs:
            raise NotImplementedError(f"NodeMask does not support problem '{problem}'.")
        self.mask_func = mask_funcs[problem]

    def get_mask(self, route, step, instance):
        """Return the mask produced by the problem-specific function for ``route`` at ``step``."""
        return self.mask_func(route, step, instance)


def get_tsptw_mask(route, step, instance):
    """Mask of nodes not visited and passing the time-window check at ``step``."""
    visited = get_visited_mask(route, step, instance)
    not_exceed_tw = get_tw_mask(route, step, instance)
    return ~visited & not_exceed_tw


def get_pctsp_mask(route, step, instance):
    """Mask of nodes not visited at ``step``."""
    visited = get_visited_mask(route, step, instance)
    return ~visited


def get_pctsptw_mask(route, step, instance):
    """Mask of nodes not visited and passing the time-window check at ``step``."""
    visited = get_visited_mask(route, step, instance)
    not_exceed_tw = get_tw_mask(route, step, instance)
    return ~visited & not_exceed_tw


def get_cvrp_mask(route, step, instance):
    """Mask of nodes not visited and passing the capacity check at ``step``."""
    visited = get_visited_mask(route, step, instance)
    less_than_cap = get_cap_mask(route, step, instance)
    return ~visited & less_than_cap


if __name__ == "__main__":
    import os
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--base_dataset", type=str, required=True)
    parser.add_argument("--cf_generator", type=str, default="ortools")
    parser.add_argument("--classifier", type=str, default="ortools")
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=4)
    parser.add_argument("--output_dir", type=str, default="data")
    args = parser.parse_args()

    dataset_gen = CFDatasetBase(args.problem, args.cf_generator, args.classifier, args.base_dataset,
                                args.num_samples, args.random_seed, args.parallel, args.num_cpus)
    cf_dataset = dataset_gen.generate_cf_dataset()
    # NOTE(review): basename() keeps the base file's extension, so a ".pkl" base
    # dataset yields a name ending in ".pkl.pkl".
    output_fname = f"{args.output_dir}/{args.problem}/cf_{dataset_gen.num_samples}samples_seed{args.random_seed}_base_{os.path.basename(args.base_dataset)}.pkl"
    save_dataset(cf_dataset, output_fname)