# demo/train.py
# Author: ybbwcwaps — commit e8e478e ("Add FakeVideoDetect")
import os
import time
from datetime import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
import torchinfo
import numpy as np
import options
from validate import validate, calculate_acc
from datasets import *
from utils.logger import create_logger
from utils.earlystop import EarlyStopping
from networks.trainer import Trainer
def main():
    """Train the FeatureTransformer model and validate it after every epoch.

    Flow:
      * parse ``options.TrainOptions``, create the logger, the ``Trainer`` and
        the train/val dataloaders (``k_split=0.8`` — presumably an 80/20
        train/val split; confirm against ``create_train_val_dataloader``);
      * for each of ``train_opt.niter`` epochs, run one pass over the
        training loader, log the loss every ``train_opt.loss_freq`` steps,
        log epoch train accuracy at threshold 0.5, and save a checkpoint
        every ``train_opt.save_epoch_freq`` epochs;
      * validate, log accuracy/AP to TensorBoard, and feed validation
        accuracy to ``EarlyStopping``. When it triggers,
        ``model.adjust_learning_rate()`` decides whether to continue with a
        new ``EarlyStopping`` (``delta=-0.002``) or stop.

    The TensorBoard writers are always closed on exit so pending events are
    flushed, including on early stop or an exception.
    """
    # Checkpointing at fixed global steps. It was permanently disabled in the
    # original via ``... and False``; kept as an explicit opt-in flag instead.
    save_at_iters = False
    iter_checkpoints = (10, 30, 50, 100, 1000, 5000, 10000)

    train_opt = options.TrainOptions().parse()
    # val_opt = options.TestOptions().parse()

    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

    model = Trainer(train_opt)
    logger.info(model.device)

    train_loader, val_loader = create_train_val_dataloader(
        train_opt,
        clip_model=None,
        transform=model.clip_model.preprocess,
        k_split=0.8,
    )
    logger.info(f"train {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")

    log_root = os.path.join(train_opt.checkpoints_dir, train_opt.name)
    train_writer = SummaryWriter(os.path.join(log_root, "train"))
    val_writer = SummaryWriter(os.path.join(log_root, "val"))
    try:
        early_stopping = EarlyStopping(
            save_path=train_opt.checkpoints_dir,
            patience=train_opt.earlystop_epoch,
            delta=-0.001,
            verbose=True,
        )

        # NOTE(review): input_size assumes (batch, 16, 768) inputs — presumably
        # 16 per-frame 768-d features; confirm against the dataset.
        logger.info(torchinfo.summary(
            model.model,
            input_size=(train_opt.batch_size, 16, 768),
            col_width=20,
            col_names=['input_size', 'output_size', 'num_params', 'trainable'],
            row_settings=['var_names'],
            verbose=0,
        ))
        logger.info("Length of train loader: %d" % len(train_loader))

        # Start timing after setup so "Iter time" is not inflated by model
        # summary / dataloader construction. It is still a cumulative average
        # and therefore includes validation time from earlier epochs.
        start_time = time.time()

        for epoch in range(train_opt.niter):
            y_true, y_pred = [], []
            pbar = tqdm(train_loader)
            for data in pbar:
                pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                model.total_steps += 1
                model.set_input(data)
                model.optimize_parameters()

                # NOTE(review): assumes model.loss is a scalar (0-d tensor or
                # Python number). float() keeps any tensor repr (device,
                # grad_fn) out of the logs and progress bar.
                loss_value = float(model.loss)

                # Collect per-sample probabilities and labels for the epoch
                # train accuracy; data[1] is treated as the label tensor.
                y_pred.extend(model.output.detach().sigmoid().flatten().tolist())
                y_true.extend(data[1].flatten().tolist())

                if model.total_steps % train_opt.loss_freq == 0:
                    logger.info("Train loss: {} at step: {}".format(loss_value, model.total_steps))
                    train_writer.add_scalar('loss', loss_value, model.total_steps)
                    logger.info("Iter time: {}".format((time.time() - start_time) / model.total_steps))

                if save_at_iters and model.total_steps in iter_checkpoints:
                    model.save_networks('model_iters_%s.pth' % model.total_steps)

                pbar.set_postfix_str(f"loss: {loss_value}, ")

            r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
            logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")

            if epoch % train_opt.save_epoch_freq == 0:
                logger.info('saving the model at the end of epoch %d' % epoch)
                model.save_networks('model_epoch_%s.pth' % epoch)

            # Validation
            model.eval()
            ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
            val_writer.add_scalar('accuracy', acc, model.total_steps)
            val_writer.add_scalar('ap', ap, model.total_steps)
            logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

            early_stopping(acc, model.model)
            if early_stopping.early_stop:
                cont_train = model.adjust_learning_rate()
                if cont_train:
                    logger.info("Learning rate dropped by 10, continue training...")
                    # Fresh patience window after the LR change.
                    early_stopping = EarlyStopping(
                        save_path=train_opt.checkpoints_dir,
                        patience=train_opt.earlystop_epoch,
                        delta=-0.002,
                        verbose=True,
                    )
                else:
                    logger.info("Early stopping.")
                    break
            model.train()
    finally:
        # Flush buffered TensorBoard events and release the event files even
        # when training ends via early stopping or an exception.
        train_writer.close()
        val_writer.close()


if __name__ == '__main__':
    main()