|
|
|
import os |
|
from config import Config |
|
|
|
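# Read the training configuration (model, optimiser, paths, GPU ids) from training.yml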
opt = Config('training.yml') |
|
|
|
|
|
|
|
|
|
# Expose only the GPUs requested in the config. This must happen before torch
# initialises CUDA, otherwise CUDA_VISIBLE_DEVICES is silently ignored.
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

import torch

print("CUDA available:", torch.cuda.is_available())
|
|
|
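# Let cuDNN benchmark and pick the fastest convolution algorithms for fixed input sizes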
torch.backends.cudnn.benchmark = True |
|
|
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import utilss |
|
from torch.utils.data import DataLoader |
|
import random |
|
import time |
|
import numpy as np |
|
|
|
from data_RGB import get_training_data, get_validation_data |
|
from Restormer import Restormer |
|
import losses |
|
from warmup_scheduler import GradualWarmupScheduler |
|
from tqdm import tqdm |
|
from pdb import set_trace as stx |
|
|
|
|
|
|
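# Fix random seeds for reproducibility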
random.seed(1234) |
|
np.random.seed(1234) |
|
torch.manual_seed(1234) |
|
torch.cuda.manual_seed_all(1234) |
|
|
|
start_epoch = 1 |
|
mode = opt.MODEL.MODE |
|
session = opt.MODEL.SESSION |
|
|
|
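# Per-session output directories for results and model checkpoints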
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session) |
|
model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session) |
|
|
|
utilss.dir_utils.mkdir(result_dir) |
|
utilss.dir_utils.mkdir(model_dir) |
|
|
|
train_dir = opt.TRAINING.TRAIN_DIR |
|
val_dir = opt.TRAINING.VAL_DIR |
|
|
|
|
|
|
|
|
|
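# Build the Restormer restoration model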
model_restoration = Restormer() |
|
|
|
|
|
device_ids = [i for i in range(torch.cuda.device_count())] |
|
|
|
|
|
|
|
|
|
|
|
|
if torch.cuda.device_count() > 1: |
|
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") |
|
model_restoration.cuda() |
|
|
|
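# Adam optimiser, starting from the initial learning rate in the config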
new_lr = opt.OPTIM.LR_INITIAL |
|
|
|
optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8) |
|
|
|
|
|
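# LR schedule: gradual warmup for warmup_epochs, then cosine annealing down to LR_MIN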
warmup_epochs = 3 |
|
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS - warmup_epochs, |
|
eta_min=opt.OPTIM.LR_MIN) |
|
|
|
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, |
|
after_scheduler=scheduler_cosine) |
|
|
|
|
|
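# Optionally resume from the most recent checkpoint and fast-forward the scheduler to the resumed epoch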
if opt.TRAINING.RESUME: |
|
path_chk_rest = utilss.get_last_path(model_dir, '_latest.pth') |
|
utilss.load_checkpoint(model_restoration, path_chk_rest) |
|
start_epoch = utilss.load_start_epoch(path_chk_rest) + 1 |
|
utilss.load_optim(optimizer, path_chk_rest) |
|
|
|
for i in range(1, start_epoch): |
|
scheduler.step() |
|
new_lr = scheduler.get_lr()[0] |
|
print('------------------------------------------------------------------------------') |
|
print("==> Resuming Training with learning rate:", new_lr) |
|
print('------------------------------------------------------------------------------') |
|
|
|
if len(device_ids) > 1: |
|
print("有多个GPU,可以进行数据并行运算") |
|
model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids) |
|
|
|
|
|
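# Training objective: Charbonnier loss plus an edge loss with a small weight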
criterion_char = losses.CharbonnierLoss() |
|
criterion_edge = losses.EdgeLoss() |
|
|
|
|
|
|
|
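# Datasets and dataloaders; patch sizes come from the TRAINING section of the config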
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) |
|
|
|
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=8, |
|
drop_last=False, pin_memory=True) |
|
|
|
|
|
val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS}) |
|
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=4, drop_last=False, |
|
pin_memory=True) |
|
|
|
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS))
print('===> Datasets loaded')
|
|
|
best_psnr = 0 |
|
best_epoch = 0 |
|
|
|
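# Main training loop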
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1): |
|
epoch_start_time = time.time() |
|
epoch_loss = 0 |
|
train_id = 1 |
|
|
|
model_restoration.train() |
|
for i, data in enumerate(tqdm(train_loader), 0): |
|
|
|
|
|
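        # Clear gradients by setting them to None (same effect as optimizer.zero_grad(set_to_none=True))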
for param in model_restoration.parameters(): |
|
param.grad = None |
|
|
|
target = data[0].cuda() |
|
input_ = data[1].cuda() |
|
|
|
|
|
print("before in model,input_.shape: ",input_.shape) |
|
|
|
if hasattr(torch.cuda, 'empty_cache'): |
|
torch.cuda.empty_cache() |
|
|
|
restored = model_restoration(input_) |
|
|
|
|
|
loss_char = criterion_char(restored, target) |
|
loss_edge = criterion_edge(restored, target) |
|
|
|
        loss = loss_char + 0.05 * loss_edge
|
|
|
loss.backward() |
|
optimizer.step() |
|
epoch_loss += loss.item() |
|
|
|
|
|
|
|
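    # Periodic validation: compute mean PSNR over the validation set and track the best checkpoint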
if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0: |
|
model_restoration.eval() |
|
psnr_val_rgb = [] |
|
        for ii, data_val in enumerate(val_loader, 0):
|
target = data_val[0].cuda() |
|
input_ = data_val[1].cuda() |
|
|
|
|
|
|
|
            with torch.no_grad():
                restored = model_restoration(input_)
|
|
|
|
|
for res, tar in zip(restored, target): |
|
psnr_val_rgb.append(utilss.torchPSNR(res, tar)) |
|
|
|
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item() |
|
|
|
if psnr_val_rgb > best_psnr: |
|
best_psnr = psnr_val_rgb |
|
best_epoch = epoch |
|
torch.save({'epoch': epoch, |
|
'state_dict': model_restoration.state_dict(), |
|
'optimizer': optimizer.state_dict() |
|
}, os.path.join(model_dir, "model_best.pth")) |
|
|
|
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr)) |
|
|
|
torch.save({'epoch': epoch, |
|
'state_dict': model_restoration.state_dict(), |
|
'optimizer': optimizer.state_dict() |
|
}, os.path.join(model_dir, f"model_epoch_{epoch}.pth")) |
|
|
|
scheduler.step() |
|
|
|
print("------------------------------------------------------------------") |
|
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.8f}".format(epoch, time.time() - epoch_start_time, |
|
epoch_loss, scheduler.get_lr()[0])) |
|
print("------------------------------------------------------------------") |
|
|
|
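    # Always refresh the rolling "latest" checkpoint so training can be resumed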
torch.save({'epoch': epoch, |
|
'state_dict': model_restoration.state_dict(), |
|
'optimizer': optimizer.state_dict() |
|
}, os.path.join(model_dir, "model_latest.pth")) |
|
|