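"""Dataset generation script for routing problems.

Builds a dataset for one of TSPTW, PCTSP, PCTSPTW, CVRP, or CVRPTW via the
corresponding *Dataset class, then splits it into train/valid/test portions
and saves each split with save_dataset (pickle files, per the output names below).
"""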
from utils.data_utils.tsptw_dataset import TSPTWDataset
from utils.data_utils.pctsp_dataset import PCTSPDataset
from utils.data_utils.pctsptw_dataset import PCTSPTWDataset
from utils.data_utils.cvrp_dataset import CVRPDataset
from utils.data_utils.cvrptw_dataset import CVRPTWDataset
from utils.utils import save_dataset

def generate_dataset(num_samples, args):
    """Instantiate the dataset generator for the requested problem and generate the samples."""
    if args.problem == "tsptw":
        data_generator = TSPTWDataset(coord_dim=args.coord_dim,
                                      num_samples=num_samples,
                                      num_nodes=args.num_nodes,
                                      random_seed=args.random_seed,
                                      solver=args.solver,
                                      classifier=args.classifier,
                                      annotation=args.annotation,
                                      parallel=args.parallel,
                                      num_cpus=args.num_cpus,
                                      distribution=args.distribution)
    elif args.problem == "pctsp":
        data_generator = PCTSPDataset(coord_dim=args.coord_dim,
                                      num_samples=num_samples,
                                      num_nodes=args.num_nodes,
                                      random_seed=args.random_seed,
                                      solver=args.solver,
                                      classifier=args.classifier,
                                      annotation=args.annotation,
                                      parallel=args.parallel,
                                      num_cpus=args.num_cpus,
                                      penalty_factor=args.penalty_factor)
    elif args.problem == "pctsptw":
        data_generator = PCTSPTWDataset(coord_dim=args.coord_dim,
                                        num_samples=num_samples,
                                        num_nodes=args.num_nodes,
                                        random_seed=args.random_seed,
                                        solver=args.solver,
                                        classifier=args.classifier,
                                        annotation=args.annotation,
                                        parallel=args.parallel,
                                        num_cpus=args.num_cpus,
                                        penalty_factor=args.penalty_factor)
    elif args.problem == "cvrp":
        data_generator = CVRPDataset(coord_dim=args.coord_dim,
                                     num_samples=num_samples,
                                     num_nodes=args.num_nodes,
                                     random_seed=args.random_seed,
                                     solver=args.solver,
                                     classifier=args.classifier,
                                     annotation=args.annotation,
                                     parallel=args.parallel,
                                     num_cpus=args.num_cpus)
    elif args.problem == "cvrptw":
        data_generator = CVRPTWDataset(coord_dim=args.coord_dim,
                                       num_samples=num_samples,
                                       num_nodes=args.num_nodes,
                                       random_seed=args.random_seed,
                                       solver=args.solver,
                                       classifier=args.classifier,
                                       annotation=args.annotation,
                                       parallel=args.parallel,
                                       num_cpus=args.num_cpus)
    else:
        raise NotImplementedError(f"problem '{args.problem}' is not supported.")
    return data_generator.generate_dataset()
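
# A minimal sketch of calling generate_dataset() directly, bypassing the CLI
# below. The SimpleNamespace stands in for the parsed argparse result; the
# field names mirror the arguments defined in __main__, and the keyword
# arguments are assumed from the constructor calls above.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(problem="tsptw", coord_dim=2, num_nodes=20,
#                          random_seed=1234, solver="ortools",
#                          classifier="ortools", annotation=False,
#                          parallel=False, num_cpus=1, distribution="da_silva")
#   dataset = generate_dataset(num_samples=1200, args=args)
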
if __name__ == "__main__":
    import argparse
    import os
    import numpy as np

    parser = argparse.ArgumentParser(description="Generate (and optionally annotate) datasets for routing problems.")
    # common settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--data_type", type=str, nargs="*", default=["all"], help="data type: 'all' or any combination of ['train', 'valid', 'test'].")
    parser.add_argument("--num_samples", type=int, nargs="*", default=[1000, 100, 100])
    parser.add_argument("--num_nodes", type=int, default=20)
    parser.add_argument("--coord_dim", type=int, default=2, help="only coord_dim=2 is supported for now.")
    parser.add_argument("--solver", type=str, default="ortools", help="solver that outputs a tour")
    parser.add_argument("--classifier", type=str, default="ortools", help="classifier for annotation")
    parser.add_argument("--annotation", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    parser.add_argument("--output_dir", type=str, default="data")
    # for TSPTW
    parser.add_argument("--distribution", type=str, default="da_silva")
    # for PCTSP
    parser.add_argument("--penalty_factor", type=float, default=3.)
    args = parser.parse_args()
    # 3D problems are not supported
    assert args.coord_dim == 2, "only coord_dim=2 is supported for now."

    # compute the total number of samples (train + valid + test)
    if args.data_type[0] == "all":
        assert len(args.num_samples) == 3, "please specify the number of samples for each of the three types (train/valid/test) when data_type is 'all' (e.g., --num_samples 1280000 1000 1000)."
    else:
        assert len(args.data_type) == len(args.num_samples), "the number of entries in --data_type must match the number of entries in --num_samples."
    num_samples = np.sum(args.num_samples)
    # generate the full dataset
    dataset = generate_dataset(num_samples, args)

    # split the dataset into the requested types
    if args.data_type[0] == "all":
        types = ["train", "valid", "test"]
    else:
        types = args.data_type
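
    # Worked example of the slicing below: with the default --num_samples
    # 1000 100 100, the cumulative offsets yield dataset[0:1000] -> train,
    # dataset[1000:1100] -> valid, and dataset[1100:1200] -> test.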
    num_sample_list = [0] + list(args.num_samples)  # prepend 0 so the cumulative offset starts at the beginning
    start = 0
    for i, type_name in enumerate(types):
        start += num_sample_list[i]
        end = start + num_sample_list[i + 1]
        divided_dataset = dataset[start:end]
        output_fname = f"{args.output_dir}/{args.problem}/{type_name}_{args.problem}_{args.num_nodes}nodes_{num_sample_list[i + 1]}samples_seed{args.random_seed}.pkl"
        save_dataset(divided_dataset, output_fname)
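
# Example invocations, assuming this file is saved as generate_dataset.py
# (a hypothetical name); the flags are those defined by the argparse parser above:
#   python generate_dataset.py --problem tsptw --num_nodes 20 \
#       --num_samples 1280000 1000 1000 --annotation --parallel
#   python generate_dataset.py --problem pctsp --data_type test --num_samples 100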