import os
import time
from datetime import datetime

import numpy as np
import torch
import torchinfo
from tensorboardX import SummaryWriter
from tqdm import tqdm

import options
from validate import validate, calculate_acc
from datasets import create_train_val_dataloader
from utils.logger import create_logger
from utils.earlystop import EarlyStopping
from networks.trainer import Trainer
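
# Training entry point: fits the FeatureTransformer head on CLIP features, with
# TensorBoard logging, periodic checkpoints, per-epoch validation, and
# accuracy-based early stopping that decays the learning rate before stopping.
# Illustrative invocation (flag names are defined by options.TrainOptions and
# are assumptions here):
#   python train.py --name my_exp --checkpoints_dir ./checkpoints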
if __name__ == '__main__':
    train_opt = options.TrainOptions().parse()

    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")
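
    # Trainer (networks/trainer.py) wraps the model together with its optimizer,
    # loss, and checkpoint helpers such as save_networks().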
    model = Trainer(train_opt)
    logger.info(model.device)

    # Build train/val loaders (k_split=0.8) using the CLIP preprocessing transform.
    train_loader, val_loader = create_train_val_dataloader(
        train_opt, clip_model=None, transform=model.clip_model.preprocess, k_split=0.8)
    logger.info(f"train batches: {len(train_loader)}")
    logger.info(f"val batches: {len(val_loader)}")
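
    # Separate SummaryWriter instances keep the train and val curves in their
    # own TensorBoard subdirectories.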
    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))

    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir,
                                   patience=train_opt.earlystop_epoch, delta=-0.001, verbose=True)
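
    # Log a torchinfo summary; each sample is a sequence of 16 tokens with
    # 768-dim features.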
    start_time = time.time()
    logger.info(torchinfo.summary(model.model, input_size=(train_opt.batch_size, 16, 768),
                                  col_width=20,
                                  col_names=['input_size', 'output_size', 'num_params', 'trainable'],
                                  row_settings=['var_names'], verbose=0))

    for epoch in range(train_opt.niter):
        y_true, y_pred = [], []
        pbar = tqdm(train_loader)
        for i, data in enumerate(pbar):
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

            model.total_steps += 1

            model.set_input(data)
            model.optimize_parameters()

            # Accumulate sigmoid scores and labels for the epoch-level train accuracy.
            y_pred.extend(model.output.sigmoid().flatten().tolist())
            y_true.extend(data[1].flatten().tolist())
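
            # Every loss_freq steps, log the current loss and the mean wall-clock time per step.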
            if model.total_steps % train_opt.loss_freq == 0:
                logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                train_writer.add_scalar('loss', model.loss, model.total_steps)
                logger.info("Iter time: {}".format((time.time() - start_time) / model.total_steps))

            pbar.set_postfix_str(f"loss: {model.loss}")

        r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
        logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")
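
        # Save a checkpoint every save_epoch_freq epochs.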
        if epoch % train_opt.save_epoch_freq == 0:
            logger.info('saving the model at the end of epoch %d' % epoch)
            model.save_networks('model_epoch_%s.pth' % epoch)
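
        # Per-epoch validation: validate() returns average precision and the
        # r/f/overall accuracies.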
        model.eval()
        ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
        val_writer.add_scalar('accuracy', acc, model.total_steps)
        val_writer.add_scalar('ap', ap, model.total_steps)
        logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
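
        # Early stopping keys on validation accuracy: on a plateau, drop the
        # learning rate by 10x and restart the stopper; stop for good once
        # adjust_learning_rate() reports that training should not continue.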
        early_stopping(acc, model.model)
        if early_stopping.early_stop:
            cont_train = model.adjust_learning_rate()
            if cont_train:
                logger.info("Learning rate dropped by 10, continue training...")
                early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir,
                                               patience=train_opt.earlystop_epoch, delta=-0.002, verbose=True)
            else:
                logger.info("Early stopping.")
                break

        model.train()