|
import torch |
|
import torch.nn.functional as F |
|
from models.SwinIR import compute_gradient_penalty |
|
from tqdm import tqdm |
|
|
|
from .utils import get_lr, show_result |
|
from .utils_metrics import PSNR, SSIM |
|
|
|
|
|
|
|
# Loss weights and face-crop size used by fit_one_epoch (values unchanged from the original loop).
_GP_WEIGHT = 10
_IMAGE_LOSS_WEIGHT = 3.0
_ADV_LOSS_WEIGHT = 1.0
_PERCEPTION_LOSS_WEIGHT = 0.9
_FACE_LOSS_WEIGHT = 2.5
_FACE_SIZE = (112, 112)


def _relativistic_average_loss(bce_with_logits, real_logits, fake_logits, *, for_generator):
    """Return the relativistic-average BCE loss computed from raw discriminator logits.

    Both logit tensors are flattened, so a discriminator returning ``(B,)`` or
    ``(B, 1)`` behaves the same, and ``B == 1`` no longer collapses to a scalar.
    Targets are built with ``ones_like``/``zeros_like``, so they always match the
    logits in shape, dtype and device, including for a short final batch.

    Discriminator (``for_generator=False``): "real minus mean fake" is pushed
    towards 1 and "fake minus mean real" towards 0. The generator loss swaps
    both targets. The two BCE terms are averaged.
    """
    real_logits = real_logits.reshape(-1)
    fake_logits = fake_logits.reshape(-1)
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    if for_generator:
        real_target = torch.zeros_like(real_vs_fake)
        fake_target = torch.ones_like(fake_vs_real)
    else:
        real_target = torch.ones_like(real_vs_fake)
        fake_target = torch.zeros_like(fake_vs_real)
    return (bce_with_logits(real_vs_fake, real_target) + bce_with_logits(fake_vs_real, fake_target)) / 2


def _discriminator_step(D_model_train, D_optimizer, BCEWithLogits_loss, hr_images, fake_images):
    """Run one discriminator update and return its total loss as a Python float.

    ``fake_images`` must already be detached from the generator graph, so this
    backward pass never reaches the generator.
    """
    D_optimizer.zero_grad()

    real_logits = D_model_train(hr_images)
    fake_logits = D_model_train(fake_images)
    adversarial = _relativistic_average_loss(BCEWithLogits_loss, real_logits, fake_logits, for_generator=False)

    # NOTE(review): assumes compute_gradient_penalty makes its interpolated
    # inputs require grad itself (the usual WGAN-GP recipe) rather than relying
    # on fake_images already requiring grad -- confirm in models.SwinIR.
    gradient_penalty = compute_gradient_penalty(D_model_train, hr_images, fake_images)

    D_train_loss = _GP_WEIGHT * gradient_penalty + adversarial
    D_train_loss.backward()
    D_optimizer.step()
    return D_train_loss.item()


def _generator_step(D_model_train, VGG_feature_model, ResNeSt_model, G_optimizer,
                    BCEWithLogits_loss, L1_loss, Face_loss, hr_images, G_result):
    """Run one generator update from ``G_result`` and return its total loss as a float.

    The loss is a weighted sum of four terms:
    - pixel L1 loss,
    - relativistic adversarial loss,
    - VGG-feature L1 loss,
    - face-embedding term ``mean(1 - Face_loss(...))`` on 112x112 bicubic resizes.
    """
    G_optimizer.zero_grad()

    image_loss = L1_loss(G_result, hr_images)

    # Everything derived only from hr_images is a constant w.r.t. the generator.
    # Computing it without autograd leaves G's gradients unchanged and skips
    # useless backward work through D / VGG / ResNeSt.
    with torch.no_grad():
        real_logits = D_model_train(hr_images)
        vgg_target = VGG_feature_model(hr_images)
        hr_faces = F.interpolate(hr_images, size=_FACE_SIZE, mode='bicubic', align_corners=True)
        face_target = ResNeSt_model(hr_faces)

    fake_logits = D_model_train(G_result)
    adversarial_loss = _relativistic_average_loss(BCEWithLogits_loss, real_logits, fake_logits, for_generator=True)

    perception_loss = L1_loss(VGG_feature_model(G_result), vgg_target)

    sr_faces = F.interpolate(G_result, size=_FACE_SIZE, mode='bicubic', align_corners=True)
    face_loss = torch.mean(1. - Face_loss(ResNeSt_model(sr_faces), face_target))

    G_train_loss = (_IMAGE_LOSS_WEIGHT * image_loss
                    + _ADV_LOSS_WEIGHT * adversarial_loss
                    + _PERCEPTION_LOSS_WEIGHT * perception_loss
                    + _FACE_LOSS_WEIGHT * face_loss)
    G_train_loss.backward()
    G_optimizer.step()
    return G_train_loss.item()


def fit_one_epoch(writer, G_model_train, D_model_train, G_model, D_model, VGG_feature_model, ResNeSt_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, Face_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
    """Train the generator and discriminator for one epoch, then log and checkpoint.

    Each iteration does one discriminator update and then one generator update:
    - Discriminator: relativistic-average BCE plus a gradient penalty weighted by 10.
    - Generator: 3.0 * pixel L1 + 1.0 * adversarial + 0.9 * VGG-feature L1
      + 2.5 * face-embedding term.

    Both updates share a single generator forward pass. The discriminator sees
    a detached copy of it, so its backward pass does not reach the generator.

    Args:
        writer: object with ``add_scalar(tag, value, step)``; receives the epoch
            means of G_loss, D_loss, G_PSNR, G_SSIM and the current lr
            (presumably a TensorBoard SummaryWriter -- assumption).
        G_model_train, D_model_train: modules used for forward passes
            (possibly parallel wrappers of G_model / D_model -- assumption).
        G_model, D_model: modules saved whole with ``torch.save`` under ``logs/``.
        VGG_feature_model: feature extractor used for the perceptual L1 loss.
        ResNeSt_model: face embedding network; inputs are resized to 112x112.
        G_optimizer, D_optimizer: optimizers stepped once per iteration.
        BCEWithLogits_loss: callable ``(logits, targets)`` for the adversarial terms.
        L1_loss: callable ``(pred, target)`` for the pixel and feature losses.
        Face_loss: callable compared as ``1 - Face_loss(a, b)``, presumably a
            cosine similarity (assumption).
        epoch: zero-based epoch index (shown and saved as ``epoch + 1``).
        epoch_size: maximum number of batches consumed from ``gen``.
        gen: iterable yielding ``(lr_images, hr_images)`` numpy arrays.
        Epoch: total number of epochs, used only for display.
        cuda: if true, batches are moved to the GPU with ``.cuda()``.
        batch_size: unused. It is kept for interface compatibility; BCE targets
            are now derived from the discriminator output itself.
        save_interval: ``show_result`` is called every ``save_interval``
            iterations; a falsy value disables it.
    """
    import os

    G_total_loss = 0
    D_total_loss = 0
    G_total_PSNR = 0
    G_total_SSIM = 0
    num_iters = 0  # batches actually processed (may be < epoch_size)

    with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3, ncols=150) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break

            lr_images, hr_images = batch
            lr_images = torch.from_numpy(lr_images).type(torch.FloatTensor)
            hr_images = torch.from_numpy(hr_images).type(torch.FloatTensor)
            if cuda:
                lr_images, hr_images = lr_images.cuda(), hr_images.cuda()

            # G's weights do not change during the D step, so one forward pass
            # serves both updates.
            G_result = G_model_train(lr_images)

            D_total_loss += _discriminator_step(D_model_train, D_optimizer, BCEWithLogits_loss,
                                                hr_images, G_result.detach())
            G_total_loss += _generator_step(D_model_train, VGG_feature_model, ResNeSt_model, G_optimizer,
                                            BCEWithLogits_loss, L1_loss, Face_loss, hr_images, G_result)
            num_iters += 1

            # Metrics use the generator output from before this iteration's G step.
            with torch.no_grad():
                G_total_PSNR += PSNR(G_result, hr_images).item()
                G_total_SSIM += SSIM(G_result, hr_images).item()

            pbar.set_postfix(**{'G_loss' : G_total_loss / num_iters,
                                'D_loss' : D_total_loss / num_iters,
                                'G_PSNR' : G_total_PSNR / num_iters,
                                'G_SSIM' : G_total_SSIM / num_iters,
                                'lr'     : get_lr(G_optimizer)})
            pbar.update(1)

            if save_interval and iteration % save_interval == 0:
                show_result(epoch + 1, G_model_train, lr_images, hr_images)

    # Average over the batches actually seen. Dividing by epoch_size would be
    # wrong when gen is shorter, and the max() guards against an empty gen.
    denom = max(num_iters, 1)
    mean_G_loss = G_total_loss / denom
    mean_D_loss = D_total_loss / denom

    writer.add_scalar('G_loss', mean_G_loss, epoch + 1)
    writer.add_scalar('D_loss', mean_D_loss, epoch + 1)
    writer.add_scalar('G_PSNR', G_total_PSNR / denom, epoch + 1)
    writer.add_scalar('G_SSIM', G_total_SSIM / denom, epoch + 1)
    writer.add_scalar('lr', get_lr(G_optimizer), epoch + 1)

    print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
    print('G Loss: %.4f || D Loss: %.4f ' % (mean_G_loss, mean_D_loss))
    print('Saving state, iter:', str(epoch+1))

    os.makedirs('logs', exist_ok=True)
    torch.save(G_model, 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), mean_G_loss, mean_D_loss))
    torch.save(D_model, 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), mean_G_loss, mean_D_loss))
|
|