# route-explainer/train.py
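"""Training script for route-explainer's neural classifiers.

Trains an NNClassifier on one of four routing problems (TSPTW, PCTSP, PCTSPTW,
CVRP), reports macro F1 on the training set (and on a validation set, if one is
given) every epoch, and writes per-epoch checkpoints, the best validation epoch,
and the parsed command-line arguments under --model_checkpoint_path.
"""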
from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassF1Score
from models.classifiers.nn_classifiers.nn_classifier import NNClassifier
from models.loss_functions import GeneralCrossEntropy
from utils.data_utils.tsptw_dataset import TSPTWDataloader
from utils.data_utils.pctsp_dataset import PCTSPDataloader
from utils.data_utils.pctsptw_dataset import PCTSPTWDataloader
from utils.data_utils.cvrp_dataset import CVRPDataloader
from utils.utils import set_device, count_trainable_params, fix_seed


def main(args):
    #---------------
    # seed settings
    #---------------
    fix_seed(args.seed)

    #--------------
    # gpu settings
    #--------------
    use_cuda, device = set_device(args.gpu)
    #-------------------
    # model & optimizer
    #-------------------
    num_classes = 3 if args.problem == "pctsptw" else 2
    model = NNClassifier(problem=args.problem,
                         node_enc_type=args.node_enc_type,
                         edge_enc_type=args.edge_enc_type,
                         dec_type=args.dec_type,
                         emb_dim=args.emb_dim,
                         num_enc_mlp_layers=args.num_enc_mlp_layers,
                         num_dec_mlp_layers=args.num_dec_mlp_layers,
                         num_classes=num_classes,
                         dropout=args.dropout,
                         pos_encoder=args.pos_encoder)
    is_sequential = model.is_sequential
    if use_cuda:
        model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # count the number of trainable parameters and record it
    num_trainable_params = count_trainable_params(model)
    print(f"num_trainable_params: {num_trainable_params}")
    with open(f"{args.model_checkpoint_path}/num_trainable_params.dat", "w") as f:
        f.write(str(num_trainable_params))

    # loss function
    if not is_sequential:
        assert args.loss_function != "seq_cbce", "Non-sequential models do not support the loss function: seq_cbce"
    loss_func = GeneralCrossEntropy(weight_type=args.loss_function, beta=args.cb_beta, is_sequential=is_sequential)
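    # NOTE (assumption): given the cb_beta argument and the "cbce"/"seq_cbce" options,
    # GeneralCrossEntropy presumably applies class-balanced weighting in the style of
    # Cui et al. (2019), i.e. weight_c = (1 - beta) / (1 - beta^{n_c}) for a class with
    # n_c samples, while "wce"/"ce" would be weighted/plain cross entropy.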
    #---------
    # dataset
    #---------
    dataloader_cls = {
        "tsptw": TSPTWDataloader,
        "pctsp": PCTSPDataloader,
        "pctsptw": PCTSPTWDataloader,
        "cvrp": CVRPDataloader,
    }
    if args.problem not in dataloader_cls:
        raise NotImplementedError(f"Unsupported problem: {args.problem}")
    DatasetCls = dataloader_cls[args.problem]
    train_dataset = DatasetCls(args.train_dataset_path, sequential=is_sequential, parallel=args.parallel, num_cpus=args.num_cpus)
    if args.valid_dataset_path is not None:
        valid_dataset = DatasetCls(args.valid_dataset_path, sequential=is_sequential, parallel=args.parallel, num_cpus=args.num_cpus)
    #------------
    # dataloader
    #------------
    if is_sequential:
        def pad_seq_length(batch):
            data = {}
            for key in batch[0].keys():
                # pad the boolean "mask" tensor with True and all other tensors with 0.0
                padding_value = True if key == "mask" else 0.0
                # post-padding to the max sequence length in the batch
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            # pad_mask is True at valid steps and False at padded ones
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data
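        # Example (illustrative): for two samples with sequence lengths 2 and 4,
        # every tensor is post-padded to length 4 and
        # pad_mask == [[True, True, False, False],
        #              [True, True, True,  True]],
        # so padded steps can be masked out of the metrics and the sequential loss.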
        collate_fn = pad_seq_length
    else:
        collate_fn = None
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  num_workers=args.num_workers)
    if args.valid_dataset_path is not None:
        valid_dataloader = DataLoader(valid_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      collate_fn=collate_fn,
                                      num_workers=args.num_workers)
    #---------
    # metrics
    #---------
    macro_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro")
    if use_cuda:
        macro_accuracy.to(device)
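    # NOTE: despite its name, macro_accuracy is a macro-averaged F1 metric, so the
    # "Train_accuracy"/"Valid_accuracy" values reported below are macro F1 scores.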
    #---------------
    # training loop
    #---------------
    best_valid_accuracy = 0.0
    model.train()
    # one extra iteration so that the weights from the final epoch are also saved and evaluated
    with tqdm(range(args.epochs + 1)) as tq1:
        for epoch in tq1:
            #--------------------------
            # save the current weights
            #--------------------------
            # move the model to the cpu so the checkpoint loads without a gpu, then move it back
            torch.save(model.cpu().state_dict(), f"{args.model_checkpoint_path}/model_epoch{epoch}.pth")
            model.to(device)

            #------------
            # validation
            #------------
            model.eval()
            with torch.no_grad():
                tq1.set_description(f"Epoch {epoch}")
                # check train accuracy
                for data in train_dataloader:
                    if use_cuda:
                        data = {key: value.to(device) for key, value in data.items()}
                    probs = model(data)
                    if is_sequential:
                        # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                        mask = data["pad_mask"].view(-1)
                        macro_accuracy(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
                    else:
                        macro_accuracy(probs.argmax(-1).view(-1), data["labels"].view(-1))
                train_macro_accuracy = macro_accuracy.compute()
                macro_accuracy.reset()
                # check valid accuracy
                if args.valid_dataset_path is not None:
                    for data in valid_dataloader:
                        if use_cuda:
                            data = {key: value.to(device) for key, value in data.items()}
                        probs = model(data)
                        if is_sequential:
                            # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                            mask = data["pad_mask"].view(-1)
                            macro_accuracy(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
                        else:
                            macro_accuracy(probs.argmax(-1).view(-1), data["labels"].view(-1))
                    valid_macro_accuracy = macro_accuracy.compute()
                    macro_accuracy.reset()
            model.train()
            postfix = {"Train_accuracy": train_macro_accuracy.item()}
            if args.valid_dataset_path is not None:
                postfix["Valid_accuracy"] = valid_macro_accuracy.item()
                # update the best epoch
                if valid_macro_accuracy >= best_valid_accuracy:
                    best_valid_accuracy = valid_macro_accuracy
                    with open(f"{args.model_checkpoint_path}/best_epoch.dat", "w") as f:
                        f.write(str(epoch))
            tq1.set_postfix(**postfix)
            #--------------------
            # update the weights
            #--------------------
            # the extra (final) iteration only saves and evaluates, so skip the update there
            if epoch < args.epochs:
                with tqdm(train_dataloader, leave=False) as tq:
                    tq.set_description(f"Epoch {epoch}")
                    for data in tq:
                        if use_cuda:
                            data = {key: value.to(device) for key, value in data.items()}
                        out = model(data)
                        if is_sequential:
                            loss = loss_func(out, data["labels"], data["pad_mask"])
                        else:
                            loss = loss_func(out, data["labels"])
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        tq.set_postfix(Loss=loss.item())


if __name__ == "__main__":
    import datetime
    import json
    import os
    import argparse

    now = datetime.datetime.now()
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument("-p", "--problem", default="tsptw", type=str, help="Problem type: [tsptw, pctsp, pctsptw, cvrp]")
    parser.add_argument("--gpu", default=-1, type=int, help="GPU number to use; gpu=-1 indicates using the cpu")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of dataloader workers")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed for reproducibility")
    # data settings
    parser.add_argument("-train", "--train_dataset_path", type=str, help="Path to the training dataset", required=True)
    parser.add_argument("-valid", "--valid_dataset_path", type=str, default=None, help="Path to the validation dataset (optional)")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=4)
    # training settings
    parser.add_argument("-e", "--epochs", default=100, type=int, help="Number of epochs")
    parser.add_argument("-b", "--batch_size", default=256, type=int, help="Batch size")
    parser.add_argument("--lr", default=0.001, type=float, help="Learning rate")
    parser.add_argument("--cb_beta", default=0.99, type=float, help="Beta for class-balanced cross entropy")
    parser.add_argument("--model_checkpoint_path", type=str, default=f"checkpoints/model_{now.strftime('%Y%m%d_%H%M%S')}")
    # model settings
    parser.add_argument("-loss", "--loss_function", type=str, default="seq_cbce", help="[seq_cbce, cbce, wce, ce]")
    parser.add_argument("-node_enc", "--node_enc_type", type=str, default="mlp")
    parser.add_argument("-edge_enc", "--edge_enc_type", type=str, default="attn")
    parser.add_argument("-dec", "--dec_type", type=str, default="lstm")
    parser.add_argument("-pe", "--pos_encoder", type=str, default="sincos")
    parser.add_argument("--emb_dim", type=int, default=128)
    parser.add_argument("--num_enc_mlp_layers", type=int, default=2)
    parser.add_argument("--num_dec_mlp_layers", type=int, default=3)
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout probability")
    args = parser.parse_args()

    os.makedirs(args.model_checkpoint_path, exist_ok=True)
    # save the command-line arguments alongside the checkpoints
    with open(f"{args.model_checkpoint_path}/cmd_args.dat", "w") as f:
        json.dump(args.__dict__, f, indent=2)
    main(args)
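# Example invocation (dataset paths are illustrative):
#   python train.py -p tsptw -train data/tsptw/train.pkl -valid data/tsptw/valid.pkl \
#       --gpu 0 -e 100 -b 256 --lr 1e-3 -loss seq_cbce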