# NOTE(review): the lines that were here ("Spaces: Sleeping", file size,
# commit hash 719d0db, and a run of column numbers) were non-code residue
# from a web-page scrape, not part of the program; kept out of the code path.
from utils.data_utils.tsptw_dataset import TSPTWDataset
from utils.data_utils.pctsp_dataset import PCTSPDataset
from utils.data_utils.pctsptw_dataset import PCTSPTWDataset
from utils.data_utils.cvrp_dataset import CVRPDataset
from utils.data_utils.cvrptw_dataset import CVRPTWDataset
from utils.utils import save_dataset
def generate_dataset(num_samples, args):
    """Instantiate the dataset generator for ``args.problem`` and generate samples.

    Args:
        num_samples: total number of samples to generate (train + valid + test).
        args: parsed CLI namespace; which attributes are read depends on the
            problem (``distribution`` for TSPTW, ``penalty_factor`` for
            PCTSP/PCTSPTW).

    Returns:
        Whatever the selected dataset's ``generate_dataset()`` returns
        (a sliceable collection of samples — see the caller's split logic).

    Raises:
        NotImplementedError: if ``args.problem`` is not one of
            tsptw / pctsp / pctsptw / cvrp / cvrptw.
    """
    # Keyword arguments shared by every dataset class; hoisted once so a new
    # common option only needs to be added in one place.
    common_kwargs = dict(coord_dim=args.coord_dim,
                         num_samples=num_samples,
                         num_nodes=args.num_nodes,
                         random_seed=args.random_seed,
                         solver=args.solver,
                         classifier=args.classifier,
                         annotation=args.annotation,
                         parallel=args.parallel,
                         num_cpus=args.num_cpus)
    if args.problem == "tsptw":
        data_generator = TSPTWDataset(distribution=args.distribution,
                                      **common_kwargs)
    elif args.problem == "pctsp":
        data_generator = PCTSPDataset(penalty_factor=args.penalty_factor,
                                      **common_kwargs)
    elif args.problem == "pctsptw":
        data_generator = PCTSPTWDataset(penalty_factor=args.penalty_factor,
                                        **common_kwargs)
    elif args.problem == "cvrp":
        data_generator = CVRPDataset(**common_kwargs)
    elif args.problem == "cvrptw":
        data_generator = CVRPTWDataset(**common_kwargs)
    else:
        raise NotImplementedError(f"unsupported problem: {args.problem}")
    return data_generator.generate_dataset()
if __name__ == "__main__":
    import argparse
    import os
    import numpy as np
    parser = argparse.ArgumentParser(description='')
    # common settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--data_type", type=str, nargs="*", default=["all"], help="data type: 'all' or combo. of ['train', 'valid', 'test'].")
    parser.add_argument("--num_samples", type=int, nargs="*", default=[1000, 100, 100])
    parser.add_argument("--num_nodes", type=int, default=20)
    parser.add_argument("--coord_dim", type=int, default=2, help="only coord_dim=2 is supported for now.")
    parser.add_argument("--solver", type=str, default="ortools", help="solver that outputs a tour")
    parser.add_argument("--classifier", type=str, default="ortools", help="classifier for annotation")
    parser.add_argument("--annotation", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    parser.add_argument("--output_dir", type=str, default="data")
    # for TSPTW
    parser.add_argument("--distribution", type=str, default="da_silva")
    # for PCTSP
    parser.add_argument("--penalty_factor", type=float, default=3.)
    args = parser.parse_args()

    # 3d problems are not supported
    assert args.coord_dim == 2, "only coord_dim=2 is supported for now."

    # calc num. of total samples (train + valid + test samples)
    if args.data_type[0] == "all":
        assert len(args.num_samples) == 3, "please specify # samples for each of the three types (train/valid/test) when you set data_type 'all'. (e.g., --num_samples 1280000 1000 1000)"
    else:
        assert len(args.data_type) == len(args.num_samples), "please match # data_types and # elements in num_samples-arg"
    num_samples = np.sum(args.num_samples)

    # generate a dataset containing all splits back-to-back
    dataset = generate_dataset(num_samples, args)

    # split the dataset into consecutive [start:end) slices, one per type.
    # NOTE(review): with 'all', the third split is named "eval" while the
    # --data_type help text says "test" — kept as-is because the name is baked
    # into output filenames; confirm the intended convention before changing.
    if args.data_type[0] == "all":
        types = ["train", "valid", "eval"]
    else:
        types = args.data_type
    # Walk a running offset instead of the previous insert(0, 0) trick, which
    # mutated args.num_samples in place (list aliasing).
    start = 0
    for type_name, split_size in zip(types, args.num_samples):
        end = start + split_size
        divided_dataset = dataset[start:end]
        output_fname = f"{args.output_dir}/{args.problem}/{type_name}_{args.problem}_{args.num_nodes}nodes_{split_size}samples_seed{args.random_seed}.pkl"
        # ensure the per-problem output directory exists before writing
        os.makedirs(os.path.dirname(output_fname), exist_ok=True)
        save_dataset(divided_dataset, output_fname)
        start = end