File size: 5,512 Bytes
73ca179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn.functional as F
from models.SwinIR import compute_gradient_penalty
from tqdm import tqdm

from .utils import get_lr, show_result
from .utils_metrics import PSNR, SSIM



def fit_one_epoch(writer, G_model_train, D_model_train, G_model, D_model, VGG_feature_model, ResNeSt_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, Face_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
    """Train the face super-resolution GAN for one epoch.

    Per batch, the discriminator is updated first with a relativistic-average
    adversarial loss (real-vs-mean(fake) and fake-vs-mean(real) logits) plus a
    gradient penalty weighted by 10; the generator is then updated with a
    weighted sum of pixel L1 loss, the (label-swapped) relativistic adversarial
    loss, a VGG perceptual loss, and a face-identity loss. Running averages of
    G/D loss and PSNR/SSIM are shown on the progress bar and written to
    TensorBoard; the full model objects are saved under ``logs/`` at epoch end.

    Args:
        writer: TensorBoard ``SummaryWriter``.
        G_model_train, D_model_train: generator / discriminator used for the
            forward passes (possibly parallel-wrapped copies of the next two).
        G_model, D_model: underlying modules serialized at the end of the epoch.
        VGG_feature_model: feature extractor for the perceptual loss.
        ResNeSt_model: face-recognition backbone for the identity loss.
        G_optimizer, D_optimizer: optimizers for generator / discriminator.
        BCEWithLogits_loss, L1_loss, Face_loss: loss callables. ``Face_loss``
            is used as ``mean(1 - Face_loss(a, b))`` — presumably a similarity
            in [-1, 1] such as cosine; verify against the caller.
        epoch: zero-based index of the current epoch.
        epoch_size: number of batches to consume from ``gen`` this epoch.
        Epoch: total number of epochs (progress display only).
        gen: iterable yielding ``(lr_images, hr_images)`` numpy batch pairs.
        cuda: move tensors to GPU when True.
        batch_size: nominal batch size (labels are sized from the actual batch
            below, so a short final batch does not cause a shape mismatch).
        save_interval: iterations between preview dumps via ``show_result``.
    """
    G_total_loss = 0
    D_total_loss = 0
    G_total_PSNR = 0
    G_total_SSIM = 0

    with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3, ncols=150) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break

            with torch.no_grad():
                lr_images, hr_images    = batch
                lr_images, hr_images    = torch.from_numpy(lr_images).type(torch.FloatTensor), torch.from_numpy(hr_images).type(torch.FloatTensor)
                # Size the labels from the actual batch: the last batch of an
                # epoch may be smaller than the nominal `batch_size`.
                y_real, y_fake          = torch.ones(lr_images.size(0)), torch.zeros(lr_images.size(0))
                if cuda:
                    lr_images, hr_images, y_real, y_fake  = lr_images.cuda(), hr_images.cuda(), y_real.cuda(), y_fake.cuda()

            #-------------------------------------------------#
            #   Train the discriminator
            #-------------------------------------------------#
            D_optimizer.zero_grad()

            # Detach the fake images so D's backward pass does not traverse the
            # generator graph; G gets its own forward/backward pass below.
            G_result                = G_model_train(lr_images).detach()
            # Squeeze BOTH heads so the logits match the 1-D label tensors
            # (the original squeezed only the fake branch).
            D_result_r              = D_model_train(hr_images).squeeze()
            D_result_f              = D_model_train(G_result).squeeze()
            # Relativistic average logits (RaGAN): each side is judged against
            # the mean score of the other side.
            D_result_rf             = D_result_r - D_result_f.mean()
            D_result_fr             = D_result_f - D_result_r.mean()
            D_train_loss_rf         = BCEWithLogits_loss(D_result_rf, y_real)
            D_train_loss_fr         = BCEWithLogits_loss(D_result_fr, y_fake)
            gradient_penalty        = compute_gradient_penalty(D_model_train, hr_images, G_result)
            D_train_loss            = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            D_train_loss.backward()

            D_optimizer.step()

            #-------------------------------------------------#
            #   Train the generator
            #-------------------------------------------------#
            G_optimizer.zero_grad()

            G_result                = G_model_train(lr_images)
            image_loss              = L1_loss(G_result, hr_images)

            D_result_r              = D_model_train(hr_images).squeeze()
            D_result_f              = D_model_train(G_result).squeeze()
            D_result_rf             = D_result_r - D_result_f.mean()
            D_result_fr             = D_result_f - D_result_r.mean()
            # Labels are swapped relative to the D step: the generator wants
            # fakes to look more real than the average real (RaGAN G objective).
            D_train_loss_rf         = BCEWithLogits_loss(D_result_rf, y_fake)
            D_train_loss_fr         = BCEWithLogits_loss(D_result_fr, y_real)
            adversarial_loss        = (D_train_loss_rf + D_train_loss_fr) / 2

            perception_loss         = L1_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))
            # Resample to 112x112 to match the face-recognition network input.
            G_result_face           = F.interpolate(G_result, size=(112, 112), mode='bicubic', align_corners=True)
            hr_images_face          = F.interpolate(hr_images, size=(112, 112), mode='bicubic', align_corners=True)
            # 1 - similarity: identical face embeddings give zero loss.
            face_loss               = torch.mean(1. - Face_loss(ResNeSt_model(G_result_face), ResNeSt_model(hr_images_face)))
            G_train_loss            = 3.0 * image_loss + 1.0 * adversarial_loss + 0.9 * perception_loss + 2.5 * face_loss

            G_train_loss.backward()
            G_optimizer.step()

            G_total_loss            += G_train_loss.item()
            D_total_loss            += D_train_loss.item()

            # Metrics only — no gradient tracking needed.
            with torch.no_grad():
                G_total_PSNR        += PSNR(G_result, hr_images).item()
                G_total_SSIM        += SSIM(G_result, hr_images).item()

            pbar.set_postfix(**{'G_loss'    : G_total_loss / (iteration + 1), 
                                'D_loss'    : D_total_loss / (iteration + 1), 
                                'G_PSNR'    : G_total_PSNR / (iteration + 1), 
                                'G_SSIM'    : G_total_SSIM / (iteration + 1), 
                                'lr'        : get_lr(G_optimizer)})
            pbar.update(1)

            if iteration % save_interval == 0:
                show_result(epoch + 1, G_model_train, lr_images, hr_images)
        writer.add_scalar('G_loss', G_total_loss / (iteration + 1), epoch + 1)
        writer.add_scalar('D_loss', D_total_loss / (iteration + 1), epoch + 1)
        writer.add_scalar('G_PSNR', G_total_PSNR / (iteration + 1), epoch + 1)
        writer.add_scalar('G_SSIM', G_total_SSIM / (iteration + 1), epoch + 1)
        writer.add_scalar('lr',     get_lr(G_optimizer),            epoch + 1)
    print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
    print('G Loss: %.4f || D Loss: %.4f ' % (G_total_loss / epoch_size, D_total_loss / epoch_size))
    print('Saving state, iter:', str(epoch+1))
    # Save the full model objects (not just state_dicts), matching the
    # project's existing checkpoint format.
    torch.save(G_model, 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))
    torch.save(D_model, 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth'%((epoch + 1), G_total_loss / epoch_size, D_total_loss / epoch_size))