import argparse, os, torch, time import torch.optim from utils.utils import load_embedder_ckpt_with_optim, adjust_learning_rate, freeze_text_embedder, AverageMeter from utils.utils_data import init_embedding_data def train_embedding(cur_epoch, model, optimizer, trainloader, testloader, device, cfg_em): torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = True acc_train_meter = AverageMeter() acc_test_meter = AverageMeter() loss_train_meter = AverageMeter() loss_test_meter = AverageMeter() time_train_meter = AverageMeter() time_test_meter = AverageMeter() freeze_text_embedder(model) for k,v in model.named_parameters(): print('{}: {}'.format(k, v.requires_grad)) for epoch in range(cur_epoch, cfg_em.epoch+1): optimizer = adjust_learning_rate(optimizer, epoch-1, cfg_em.lr_decay) lr = optimizer.param_groups[-1]['lr'] model.train() for idx, batch in enumerate(trainloader): for i in range(len(batch)): batch[i] = batch[i].to("cuda" if torch.cuda.is_available() else "cpu") time_start = time.time() out = model(batch, 'train') loss = out['loss_total'] acc = out['acc_type'] time_train_meter.update(time.time() - time_start) acc_train_meter.update(acc) loss_train_meter.update(loss) optimizer.zero_grad() loss.backward() optimizer.step() print(f'Epoch:{epoch}|Iter:{idx+1}/{len(trainloader)}|lr:{lr},' f'Loss: {loss_train_meter.avg:.3f},' f'Acc: {acc_train_meter.avg:.3f},' f'Time: {time_train_meter.avg:.3f},', flush=True) model.eval() for idx, batch in enumerate(testloader): for i in range(len(batch)): batch[i] = batch[i].to("cuda" if torch.cuda.is_available() else "cpu") time_start = time.time() out = model(batch, 'train') loss = out['loss_total'] acc = out['acc_type'] time_test_meter.update(time.time() - time_start) acc_test_meter.update(acc) loss_test_meter.update(loss) print(f'Epoch:{epoch}|Iter:{idx+1}/{len(testloader)}|lr:{lr},' f'Loss: {loss_test_meter.avg:.3f},' f'Acc: {acc_test_meter.avg:.3f},' f'Time: {time_test_meter.avg:.3f},', flush=True) torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()}, f'{cfg_em.check_dir}/embedder_model_epoch{epoch}_{acc_train_meter.avg:.3f}_{loss_train_meter.avg:.3f}_{acc_test_meter.avg:.3f}_{loss_test_meter.avg:.3f}.tar') acc_train_meter.reset() acc_test_meter.reset() loss_train_meter.reset() loss_test_meter.reset() time_train_meter.reset() time_test_meter.reset() print('Done!') os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "0" if __name__ == "__main__": parser = argparse.ArgumentParser() # load model parser.add_argument("--seed", type=int, default = 124) parser.add_argument("--pre_weight", type=str, default = '') parser.add_argument("--lr", type=float, default = 0.0001) parser.add_argument("--type_name", type=list, default = ['clear', 'low', 'haze', 'rain',\ 'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\ 'haze_snow', 'low_haze_rain', 'low_haze_snow']) parser.add_argument("--train-dir", type=str, default = './data/CDD-11_train/') parser.add_argument("--test-dir", type=str, default = './data/CDD-11_test/') parser.add_argument("--batch", type=int, default = 128) parser.add_argument("--num-workers", type=int, default = 0) parser.add_argument("--epoch", type=int, default = 200) parser.add_argument("--lr-decay", type=int, default = 50) parser.add_argument("--check-dir", type=str, default = "./ckpts") args = parser.parse_args() os.makedirs(args.check_dir,exist_ok=True) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") embedder, optimizer, cur_epoch, device = load_embedder_ckpt_with_optim(device, args) trainloader, testloader = init_embedding_data(args, 'train') train_embedding(cur_epoch, embedder, optimizer, trainloader, testloader, device, args)