|
import argparse, os, torch, time
|
|
import torch.optim
|
|
|
|
from utils.utils import load_embedder_ckpt_with_optim, adjust_learning_rate, freeze_text_embedder, AverageMeter
|
|
from utils.utils_data import init_embedding_data
|
|
|
|
|
|
|
|
def _to_scalar(value):
    """Return a graph-free value suitable for a running-average meter.

    Single-element tensors become plain Python numbers; other tensors are
    only detached; non-tensor values are returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach()
        if value.numel() == 1:
            return value.item()
    return value


def _run_embedder_epoch(model, loader, device, epoch, lr,
                        acc_meter, loss_meter, time_meter, optimizer=None):
    """Run one pass over ``loader`` and log running averages per iteration.

    When ``optimizer`` is given the pass is a training pass (gradients are
    tracked and a step is taken per batch); when it is ``None`` the pass is
    evaluation-only and gradient tracking is switched off.
    """
    training = optimizer is not None
    num_iters = len(loader)

    # Without this, the evaluation pass would build autograd graphs that are
    # never consumed by backward() and therefore never released.
    with torch.set_grad_enabled(training):
        for idx, batch in enumerate(loader):
            # Each batch is treated as a sequence of tensors, as in the
            # original index-based loop.
            batch = [item.to(device) for item in batch]

            time_start = time.time()
            # NOTE(review): the mode string is 'train' for evaluation too,
            # as in the original code — presumably needed so the model
            # returns the loss; confirm against the model's forward().
            out = model(batch, 'train')
            loss = out['loss_total']
            acc = out['acc_type']
            time_meter.update(time.time() - time_start)

            # Feed the meters detached scalars so they do not keep every
            # iteration's autograd history alive (AverageMeter is defined
            # elsewhere; presumably it accumulates a running sum).
            acc_meter.update(_to_scalar(acc))
            loss_meter.update(_to_scalar(loss))

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f'Epoch:{epoch}|Iter:{idx+1}/{num_iters}|lr:{lr},'
                  f'Loss: {loss_meter.avg:.3f},'
                  f'Acc: {acc_meter.avg:.3f},'
                  f'Time: {time_meter.avg:.3f},', flush=True)


def train_embedding(cur_epoch, model, optimizer, trainloader, testloader, device, cfg_em):
    """Train the embedder, evaluate it and save a checkpoint after each epoch.

    Args:
        cur_epoch: First epoch number to run (inclusive).
        model: Embedder called as ``model(batch, 'train')``; the result must
            provide ``'loss_total'`` and ``'acc_type'``.
        optimizer: Optimizer stepped on every training batch; it is passed
            through ``adjust_learning_rate`` at the start of each epoch.
        trainloader: Iterable of training batches (sequences of tensors).
        testloader: Iterable of evaluation batches (sequences of tensors).
        device: Device every batch tensor is moved to. If ``None``, falls
            back to CUDA when available, otherwise CPU.
        cfg_em: Configuration object; ``epoch`` (last epoch, inclusive),
            ``lr_decay`` and ``check_dir`` (checkpoint directory) are read.
    """
    # Keep cuDNN enabled but turn off its benchmark mode.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    acc_train_meter = AverageMeter()
    acc_test_meter = AverageMeter()
    loss_train_meter = AverageMeter()
    loss_test_meter = AverageMeter()
    time_train_meter = AverageMeter()
    time_test_meter = AverageMeter()

    freeze_text_embedder(model)
    # Log which parameters remain trainable after freezing.
    for k, v in model.named_parameters():
        print('{}: {}'.format(k, v.requires_grad))

    for epoch in range(cur_epoch, cfg_em.epoch + 1):
        optimizer = adjust_learning_rate(optimizer, epoch - 1, cfg_em.lr_decay)
        lr = optimizer.param_groups[-1]['lr']

        model.train()
        _run_embedder_epoch(model, trainloader, device, epoch, lr,
                            acc_train_meter, loss_train_meter, time_train_meter,
                            optimizer=optimizer)

        model.eval()
        _run_embedder_epoch(model, testloader, device, epoch, lr,
                            acc_test_meter, loss_test_meter, time_test_meter)

        # The checkpoint file name records the epoch's average train/test
        # accuracy and loss.
        torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()},
                   f'{cfg_em.check_dir}/embedder_model_epoch{epoch}_{acc_train_meter.avg:.3f}_{loss_train_meter.avg:.3f}_{acc_test_meter.avg:.3f}_{loss_test_meter.avg:.3f}.tar')

        # Start the next epoch's statistics from scratch.
        acc_train_meter.reset()
        acc_test_meter.reset()
        loss_train_meter.reset()
        loss_test_meter.reset()
        time_train_meter.reset()
        time_test_meter.reset()

    print('Done!')
|
|
|
|
# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
|
|
def build_arg_parser():
    """Build the command-line parser for the embedder training script.

    ``--type_name`` takes one or more space-separated names, e.g.
    ``--type_name clear low haze``. The previous ``type=list`` converter
    split a single value into individual characters.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=124)
    parser.add_argument("--pre_weight", type=str, default='')
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--type_name", type=str, nargs='+',
                        default=['clear', 'low', 'haze', 'rain',
                                 'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain',
                                 'haze_snow', 'low_haze_rain', 'low_haze_snow'])
    parser.add_argument("--train-dir", type=str, default='./data/CDD-11_train/')
    parser.add_argument("--test-dir", type=str, default='./data/CDD-11_test/')
    parser.add_argument("--batch", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--lr-decay", type=int, default=50)
    parser.add_argument("--check-dir", type=str, default="./ckpts")

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Make sure the checkpoint directory exists before training writes to it.
    os.makedirs(args.check_dir, exist_ok=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    embedder, optimizer, cur_epoch, device = load_embedder_ckpt_with_optim(device, args)
    trainloader, testloader = init_embedding_data(args, 'train')
    train_embedding(cur_epoch, embedder, optimizer, trainloader, testloader, device, args)
|
|
|