AioMedica / main.py
chris1nexus
First commit
d60982d
raw
history blame
5.66 kB
#!/usr/bin/env python
# coding: utf-8
from __future__ import absolute_import, division, print_function
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from utils.dataset import GraphDataset
from utils.lr_scheduler import LR_Scheduler
from tensorboardX import SummaryWriter
from helper import Trainer, Evaluator, collate
from option import Options
from models.GraphTransformer import Classifier
from models.weight_init import weight_init
import pickle
# ---- Run configuration ------------------------------------------------------
# Parse the command-line options; everything below is driven by `args`.
args = Options().parse()

# The number of classes is the number of entries in label_map.pkl, which is
# read from the dataset metadata directory.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point dataset_metadata_path at trusted data.
with open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb') as f_label_map:
    label_map = pickle.load(f_label_map)
n_class = len(label_map)

# torch.cuda.synchronize() raises when no CUDA device is present, so it is
# only called when a GPU is available (the script otherwise falls back to CPU).
if torch.cuda.is_available():
    torch.cuda.synchronize()
torch.backends.cudnn.deterministic = True

data_path = args.data_path
model_path = args.model_path
log_path = args.log_path
# makedirs (unlike mkdir) also creates missing parent directories.
for out_dir in (model_path, log_path):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
task_name = args.task_name
print(task_name)

###################################
# Mode flags taken from the command line; they select which parts of the
# epoch loop run.
train = args.train
test = args.test
graphcam = args.graphcam
print("train:", train, "test:", test, "graphcam:", graphcam)
##### Load datasets
print("preparing datasets and dataloaders......")
batch_size = args.batch_size

# The training split is only built when training was requested.
if train:
    # readlines() keeps the trailing newline of each entry; the lines are
    # handed to GraphDataset unchanged.
    with open(args.train_set) as f_ids:
        ids_train = f_ids.readlines()
    dataset_train = GraphDataset(os.path.join(data_path, ""), ids_train, args.dataset_metadata_path)
    dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True, drop_last=True)
    # drop_last=True means every batch is full, so this is the exact number
    # of samples seen per epoch.
    total_train_num = len(dataloader_train) * batch_size

# The validation split is always built: it is used in both train and test mode.
with open(args.val_set) as f_ids:
    ids_val = f_ids.readlines()
dataset_val = GraphDataset(os.path.join(data_path, ""), ids_val, args.dataset_metadata_path)
dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True)
# The validation loader keeps its last (possibly partial) batch, so the sample
# count comes from the dataset itself; batches * batch_size would over-count.
total_val_num = len(dataset_val)
# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##### creating models #############
print("creating models......")
num_epochs = args.num_epochs
learning_rate = args.lr

model = Classifier(n_class)
model = nn.DataParallel(model)
if args.resume:
    print('load model{}'.format(args.resume))
    # The model is still on the CPU at this point. Loading the checkpoint onto
    # the CPU lets a GPU-saved checkpoint be restored on a CPU-only machine;
    # the weights are moved to `device` together with the model below.
    model.load_state_dict(torch.load(args.resume, map_location="cpu"))
model = model.to(device)
# NOTE: custom initialisation via model.apply(weight_init) is currently disabled.

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)  # best:5e-4, 4e-3
# Learning rate is multiplied by 0.1 at epochs 20 and 100.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 100], gamma=0.1)  # gamma=0.3 # 30,90,130 # 20,90,130 -> 150

##################################
criterion = nn.CrossEntropyLoss()
# The TensorBoard writer and the text log are only created outside test mode.
# The log file stays open for the whole run and is closed after the epoch loop.
if not test:
    # os.path.join gives the same path as plain concatenation when log_path
    # ends with a separator, and still places the files inside log_path when
    # it does not.
    run_prefix = os.path.join(log_path, task_name)
    writer = SummaryWriter(log_dir=run_prefix)
    f_log = open(run_prefix + ".log", 'w')

# Helpers that run one batch through the model and keep running metrics.
trainer = Trainer(n_class)
evaluator = Evaluator(n_class)
# ---- Epoch loop --------------------------------------------------------------
# Each epoch: optionally train on dataloader_train, then evaluate on
# dataloader_val, save the model when validation accuracy improves, and append
# a line to the text log / TensorBoard. In test mode only a single evaluation
# pass is run and nothing is written to disk.
best_pred = 0.0
try:
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.
        total = 0.
        current_lr = optimizer.param_groups[0]['lr']
        print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' % (epoch+1, current_lr, best_pred))

        if train:
            for i_batch, sample_batched in enumerate(dataloader_train):
                preds, labels, loss = trainer.train(sample_batched, model)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Accumulate a plain Python float: summing the loss tensor
                # itself would keep autograd history alive across batches.
                train_loss += loss.item()
                total += len(labels)
                trainer.metrics.update(labels, preds)
                if (i_batch + 1) % args.log_interval_local == 0:
                    print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total, total_train_num, train_loss / total, trainer.get_scores()))
                    trainer.plot_cm()
            # Advance the schedule once per epoch, after the optimizer steps.
            # With milestones [20, 100] this yields the same per-epoch learning
            # rates as passing the epoch index on every batch.
            scheduler.step()

        # The end-of-epoch training summary needs the training loader's sample
        # count and at least one processed batch.
        if train and not test and total:
            print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total_train_num, total_train_num, train_loss / total, trainer.get_scores()))
            trainer.plot_cm()

        with torch.no_grad():
            model.eval()
            print("evaluating...")
            total = 0.
            for i_batch, sample_batched in enumerate(dataloader_val):
                preds, labels, _ = evaluator.eval_test(sample_batched, model, graphcam)
                total += len(labels)
                evaluator.metrics.update(labels, preds)
                if (i_batch + 1) % args.log_interval_local == 0:
                    print('[%d/%d] val agg acc: %.3f' % (total, total_val_num, evaluator.get_scores()))
                    evaluator.plot_cm()
            print('[%d/%d] val agg acc: %.3f' % (total_val_num, total_val_num, evaluator.get_scores()))
            evaluator.plot_cm()
            # torch.cuda.empty_cache()

            # Keep the weights of the best validation accuracy seen so far.
            val_acc = evaluator.get_scores()
            if val_acc > best_pred:
                best_pred = val_acc
                if not test:
                    print("saving model...")
                    torch.save(model.state_dict(), os.path.join(model_path, task_name + ".pth"))

            log = 'epoch [{}/{}] ------ acc: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, trainer.get_scores(), evaluator.get_scores()) + "\n"
            log += "================================\n"
            print(log)
            # Test mode: one evaluation pass only, no log file or TensorBoard.
            if test:
                break
            f_log.write(log)
            f_log.flush()
            writer.add_scalars('accuracy', {'train acc': trainer.get_scores(), 'val acc': evaluator.get_scores()}, epoch+1)

        # Start the next epoch with fresh running metrics.
        trainer.reset_metrics()
        evaluator.reset_metrics()
finally:
    # The log file and writer only exist outside test mode; close them even if
    # the loop raised so buffered output is not lost.
    if not test:
        f_log.close()
        writer.close()