from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassF1Score

from models.classifiers.nn_classifiers.nn_classifier import NNClassifier
from models.loss_functions import GeneralCrossEntropy
from utils.data_utils.tsptw_dataset import TSPTWDataloader
from utils.data_utils.pctsp_dataset import PCTSPDataloader
from utils.data_utils.pctsptw_dataset import PCTSPTWDataloader
from utils.data_utils.cvrp_dataset import CVRPDataloader
from utils.utils import set_device, count_trainable_params, fix_seed


def main(args):
    #---------------
    # seed settings
    #---------------
    fix_seed(args.seed)

    #--------------
    # gpu settings
    #--------------
    use_cuda, device = set_device(args.gpu)

    #-------------------
    # model & optimizer
    #-------------------
    num_classes = 3 if args.problem == "pctsptw" else 2
    model = NNClassifier(problem=args.problem,
                         node_enc_type=args.node_enc_type,
                         edge_enc_type=args.edge_enc_type,
                         dec_type=args.dec_type,
                         emb_dim=args.emb_dim,
                         num_enc_mlp_layers=args.num_enc_mlp_layers,
                         num_dec_mlp_layers=args.num_dec_mlp_layers,
                         num_classes=num_classes,
                         dropout=args.dropout,
                         pos_encoder=args.pos_encoder)
    is_sequential = model.is_sequential
    if use_cuda:
        model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # count the number of trainable parameters
    num_trainable_params = count_trainable_params(model)
    print(f"num_trainable_params: {num_trainable_params}")
    with open(f"{args.model_checkpoint_path}/num_trainable_params.dat", "w") as f:
        f.write(str(num_trainable_params))

    # loss function
    if not is_sequential:
        assert args.loss_function != "seq_cbce", "Non-sequential models do not support the loss function: seq_cbce"
    loss_func = GeneralCrossEntropy(weight_type=args.loss_function,
                                    beta=args.cb_beta,
                                    is_sequential=is_sequential)

    #---------
    # dataset
    #---------
    if args.problem == "tsptw":
        dataloader_class = TSPTWDataloader
    elif args.problem == "pctsp":
        dataloader_class = PCTSPDataloader
    elif args.problem == "pctsptw":
        dataloader_class = PCTSPTWDataloader
    elif args.problem == "cvrp":
        dataloader_class = CVRPDataloader
    else:
        raise NotImplementedError(f"Unsupported problem type: {args.problem}")
    train_dataset = dataloader_class(args.train_dataset_path,
                                     sequential=is_sequential,
                                     parallel=args.parallel,
                                     num_cpus=args.num_cpus)
    if args.valid_dataset_path is not None:
        valid_dataset = dataloader_class(args.valid_dataset_path,
                                         sequential=is_sequential,
                                         parallel=args.parallel,
                                         num_cpus=args.num_cpus)

    #------------
    # dataloader
    #------------
    if is_sequential:
        def pad_seq_length(batch):
            # post-pad every field to the longest sequence in the batch and
            # record which positions are real (pad_mask) vs. padding
            data = {}
            for key in batch[0].keys():
                padding_value = True if key == "mask" else 0.0  # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch],
                                                            batch_first=True,
                                                            padding_value=padding_value)
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch],
                                                       batch_first=True,
                                                       padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data
        collate_fn = pad_seq_length
    else:
        collate_fn = None
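
    # A minimal illustration of what the collate function produces; the tensors
    # below are hypothetical, not taken from an actual dataset. Given two
    # sequential samples of lengths 3 and 5, pad_seq_length post-pads every
    # field to length 5 and adds a boolean "pad_mask" marking real positions:
    #
    #   batch = [{"labels": torch.tensor([0, 1, 0]), "mask": torch.ones(3, dtype=torch.bool)},
    #            {"labels": torch.tensor([1, 0, 0, 1, 1]), "mask": torch.ones(5, dtype=torch.bool)}]
    #   out = pad_seq_length(batch)
    #   out["labels"]    # tensor([[0, 1, 0, 0, 0], [1, 0, 0, 1, 1]])
    #   out["pad_mask"]  # tensor([[ True,  True,  True, False, False],
    #                    #         [ True,  True,  True,  True,  True]])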
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  num_workers=args.num_workers)
    if args.valid_dataset_path is not None:
        valid_dataloader = DataLoader(valid_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      collate_fn=collate_fn,
                                      num_workers=args.num_workers)

    #---------
    # metrics
    #---------
    macro_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
    if use_cuda:
        macro_f1.to(device)

    #---------------
    # training loop
    #---------------
    best_valid_f1 = 0.0
    model.train()
    # args.epochs + 1 iterations so the model is saved and evaluated once more
    # after the final weight update
    with tqdm(range(args.epochs + 1)) as tq1:
        for epoch in tq1:
            #--------------------------
            # save the current weights
            #--------------------------
            torch.save(model.cpu().state_dict(), f"{args.model_checkpoint_path}/model_epoch{epoch}.pth")
            model.to(device)

            #------------
            # validation
            #------------
            model.eval()
            with torch.no_grad():
                tq1.set_description(f"Epoch {epoch}")

                # check train macro-F1
                for data in train_dataloader:
                    if use_cuda:
                        data = {key: value.to(device) for key, value in data.items()}
                    probs = model(data)
                    if is_sequential:
                        # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                        mask = data["pad_mask"].view(-1)
                        macro_f1(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
                    else:
                        macro_f1(probs.argmax(-1).view(-1), data["labels"].view(-1))
                train_macro_f1 = macro_f1.compute()
                macro_f1.reset()

                # check valid macro-F1
                if args.valid_dataset_path is not None:
                    for data in valid_dataloader:
                        if use_cuda:
                            data = {key: value.to(device) for key, value in data.items()}
                        probs = model(data)
                        if is_sequential:
                            # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                            mask = data["pad_mask"].view(-1)
                            macro_f1(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
                        else:
                            macro_f1(probs.argmax(-1).view(-1), data["labels"].view(-1))
                    valid_macro_f1 = macro_f1.compute()
                    macro_f1.reset()
            model.train()

            if args.valid_dataset_path is not None:
                tq1.set_postfix(Train_macro_f1=train_macro_f1.item(), Valid_macro_f1=valid_macro_f1.item())
                # update the best epoch
                if valid_macro_f1 >= best_valid_f1:
                    best_valid_f1 = valid_macro_f1
                    with open(f"{args.model_checkpoint_path}/best_epoch.dat", "w") as f:
                        f.write(str(epoch))
            else:
                tq1.set_postfix(Train_macro_f1=train_macro_f1.item())

            #--------------------
            # update the weights
            #--------------------
            if epoch < args.epochs:
                with tqdm(train_dataloader, leave=False) as tq:
                    tq.set_description(f"Epoch {epoch}")
                    for data in tq:
                        if use_cuda:
                            data = {key: value.to(device) for key, value in data.items()}
                        out = model(data)
                        if is_sequential:
                            loss = loss_func(out, data["labels"], data["pad_mask"])
                        else:
                            loss = loss_func(out, data["labels"])
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        tq.set_postfix(Loss=loss.item())
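
# For reference on the "cbce"/"seq_cbce" options: GeneralCrossEntropy is
# expected to weight each class by its per-class label counts, following the
# effective-number weighting that an earlier inline experiment in this script
# used. A minimal sketch, assuming that formulation (illustrative only; the
# actual implementation lives in models.loss_functions):
#
#   bincount = labels.view(-1).bincount(minlength=num_classes).float()
#   weight = (1.0 - beta) / (1.0 - beta**bincount)  # rarer classes get larger weights
#   loss = F.nll_loss(log_probs, labels, weight=weight)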


if __name__ == "__main__":
    import argparse
    import datetime
    import json
    import os

    now = datetime.datetime.now()
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument("-p", "--problem", default="tsptw", type=str, help="Problem type: [tsptw, pctsp, pctsptw, cvrp]")
    parser.add_argument("--gpu", default=-1, type=int, help="GPU id to use (-1 runs on CPU)")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of workers in the dataloader")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed for reproducibility")
    # data settings
    parser.add_argument("-train", "--train_dataset_path", type=str, help="Path to the training dataset", required=True)
    parser.add_argument("-valid", "--valid_dataset_path", type=str, default=None, help="Path to the validation dataset")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=4)
    # training settings
    parser.add_argument("-e", "--epochs", default=100, type=int, help="Number of epochs")
    parser.add_argument("-b", "--batch_size", default=256, type=int, help="Batch size")
    parser.add_argument("--lr", default=0.001, type=float, help="Learning rate")
    parser.add_argument("--cb_beta", default=0.99, type=float, help="Beta of the class-balanced cross-entropy loss")
    parser.add_argument("--model_checkpoint_path", type=str, default=f"checkpoints/model_{now.strftime('%Y%m%d_%H%M%S')}")
    # model settings
    parser.add_argument("-loss", "--loss_function", type=str, default="seq_cbce", help="[seq_cbce, cbce, wce, ce]")
    parser.add_argument("-node_enc", "--node_enc_type", type=str, default="mlp")
    parser.add_argument("-edge_enc", "--edge_enc_type", type=str, default="attn")
    parser.add_argument("-dec", "--dec_type", type=str, default="lstm")
    parser.add_argument("-pe", "--pos_encoder", type=str, default="sincos")
    parser.add_argument("--emb_dim", type=int, default=128)
    parser.add_argument("--num_enc_mlp_layers", type=int, default=2)
    parser.add_argument("--num_dec_mlp_layers", type=int, default=3)
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout probability")
    args = parser.parse_args()

    os.makedirs(args.model_checkpoint_path, exist_ok=True)
    with open(f"{args.model_checkpoint_path}/cmd_args.dat", "w") as f:
        json.dump(args.__dict__, f, indent=2)

    main(args)
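
# Example invocation (the script name and dataset paths are placeholders):
#
#   python train_classifier.py -p tsptw \
#       -train path/to/train_dataset -valid path/to/valid_dataset \
#       -e 100 -b 256 --lr 1e-3 --gpu 0
#
# Per-epoch checkpoints (model_epoch{n}.pth), the best-epoch index
# (best_epoch.dat), and the parsed arguments (cmd_args.dat) are written to
# --model_checkpoint_path, which defaults to checkpoints/model_<timestamp>.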