# FakeVideoDetect / train.py
# ybbwcwaps
# AI Video
# 3cc4a06
import os
import time
from datetime import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
import torchinfo
import numpy as np
import options
from validate import validate, calculate_acc
from datasetss import *
from utilss.logger import create_logger
from utilss.earlystop import EarlyStopping
from networks.trainer import Trainer
if __name__ == '__main__':
    # Training entry point: builds the Trainer, creates train/val loaders,
    # then runs up to `train_opt.niter` epochs. Each epoch trains, logs the
    # epoch-level train accuracy, optionally checkpoints, validates, and
    # applies early stopping with learning-rate-decay restarts.
    train_opt = options.TrainOptions().parse()

    # The logger writes into the checkpoint directory.
    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

    model = Trainer(train_opt)
    logger.info(model.device)

    # The dataloaders reuse the model's CLIP preprocessing transform.
    # NOTE(review): k_split=0.8 presumably sends 80% of the samples to the
    # training split and the rest to validation. Confirm in datasetss.
    train_loader, val_loader = create_train_val_dataloader(
        train_opt, clip_model=None, transform=model.clip_model.preprocess, k_split=0.8
    )
    logger.info(f"train {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")

    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))
    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch,
                                   delta=-0.001, verbose=True)

    # NOTE(review): the summary is computed for inputs of shape
    # (batch_size, 16, 768). Presumably these are 16 per-clip feature vectors
    # of width 768. Confirm against the model definition.
    logger.info(torchinfo.summary(model.model, input_size=(train_opt.batch_size, 16, 768), col_width=20,
                                  col_names=['input_size', 'output_size', 'num_params', 'trainable'],
                                  row_settings=['var_names'], verbose=0))
    logger.info("Length of train loader: %d" % (len(train_loader)))

    # Start the clock after the one-off summary pass, so "Iter time"
    # measures training steps only.
    start_time = time.time()

    # Fixed-iteration checkpoints. These were unconditionally disabled
    # before (`... and False`); flip the flag to re-enable them.
    save_iter_checkpoints = False
    iter_checkpoint_steps = {10, 30, 50, 100, 1000, 5000, 10000}

    try:
        for epoch in range(train_opt.niter):
            y_true, y_pred = [], []
            pbar = tqdm(train_loader)
            for data in pbar:
                pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                model.total_steps += 1
                model.set_input(data)
                model.optimize_parameters()

                # Collect per-sample probabilities and labels for the epoch
                # accuracy. Detach first so the sigmoid used only for logging
                # is not recorded in the autograd graph.
                y_pred.extend(model.output.detach().sigmoid().flatten().tolist())
                y_true.extend(data[1].flatten().tolist())

                if model.total_steps % train_opt.loss_freq == 0:
                    logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                    train_writer.add_scalar('loss', model.loss, model.total_steps)
                    logger.info("Iter time: {}".format((time.time() - start_time) / model.total_steps))

                if save_iter_checkpoints and model.total_steps in iter_checkpoint_steps:
                    model.save_networks('model_iters_%s.pth' % model.total_steps)

                pbar.set_postfix_str(f"loss: {model.loss}, ")

            # Epoch-level training accuracy at a 0.5 decision threshold.
            r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
            logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")
            # Mirror the validation writer so train/val accuracy can be compared.
            train_writer.add_scalar('accuracy', acc0, model.total_steps)

            if epoch % train_opt.save_epoch_freq == 0:
                logger.info('saving the model at the end of epoch %d' % (epoch))
                model.save_networks('model_epoch_%s.pth' % epoch)

            # Validation
            model.eval()
            ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
            val_writer.add_scalar('accuracy', acc, model.total_steps)
            val_writer.add_scalar('ap', ap, model.total_steps)
            logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

            # Early stopping monitors validation accuracy. When it triggers,
            # adjust_learning_rate() reports whether training should go on.
            # If it should, patience restarts with delta=-0.002; otherwise
            # training stops.
            early_stopping(acc, model.model)
            if early_stopping.early_stop:
                if model.adjust_learning_rate():
                    logger.info("Learning rate dropped by 10, continue training...")
                    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir,
                                                   patience=train_opt.earlystop_epoch,
                                                   delta=-0.002, verbose=True)
                else:
                    logger.info("Early stopping.")
                    break
            model.train()
    finally:
        # Flush and close the TensorBoard event files on normal completion,
        # early stop, or an exception, so the last scalars are not lost.
        train_writer.close()
        val_writer.close()