File size: 4,474 Bytes
905cd18
 
 
db5513e
905cd18
 
 
db5513e
905cd18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db5513e
905cd18
 
db5513e
 
 
 
 
 
 
905cd18
 
 
 
 
 
 
db5513e
905cd18
db5513e
905cd18
db5513e
 
 
 
 
 
 
905cd18
db5513e
905cd18
db5513e
905cd18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from tqdm import tqdm

from .utils import get_lr, show_result
from .utils_metrics import PSNR, SSIM


def _relativistic_adversarial_loss(D_model_train, BCEWithLogits_loss, real_images, fake_images, rf_label, fr_label):
    """Return the relativistic-average adversarial loss for one batch.

    The discriminator scores both image sets. Each set's logits are shifted by
    the mean logit of the other set before BCE-with-logits is applied:

        loss = (BCE(D(real) - mean(D(fake)), rf_label)
                + BCE(D(fake) - mean(D(real)), fr_label)) / 2

    Pass ``rf_label=1.0, fr_label=0.0`` for the discriminator update and the
    swapped labels for the generator update.

    The discriminator output is flattened with ``reshape(-1)``. The targets are
    built with ``torch.full_like``, so they always match the logits' shape,
    device and dtype, including for a final batch that is smaller than the
    nominal batch size or a batch of one image.
    """
    d_real = D_model_train(real_images).reshape(-1)
    d_fake = D_model_train(fake_images).reshape(-1)
    rf_logits = d_real - d_fake.mean()
    fr_logits = d_fake - d_real.mean()
    loss_rf = BCEWithLogits_loss(rf_logits, torch.full_like(rf_logits, rf_label))
    loss_fr = BCEWithLogits_loss(fr_logits, torch.full_like(fr_logits, fr_label))
    return (loss_rf + loss_fr) / 2


def fit_one_epoch(G_model_train, D_model_train, G_model, D_model, VGG_feature_model, G_optimizer, D_optimizer, BCEWithLogits_loss, L1_loss, epoch, epoch_size, gen, Epoch, cuda, batch_size, save_interval):
    """Train the generator/discriminator pair for one epoch.

    Each iteration first updates the discriminator with a relativistic-average
    BCE loss. It then updates the generator with
    ``L1(image) + 0.1 * adversarial + 0.1 * L1(VGG features)``.

    Args:
        G_model_train, D_model_train: modules used for the forward passes
            (possibly wrappers around ``G_model`` / ``D_model``; not visible here).
        G_model, D_model: modules whose ``state_dict`` is checkpointed.
        VGG_feature_model: feature extractor for the perceptual L1 loss.
        G_optimizer, D_optimizer: optimizers for the two networks.
        BCEWithLogits_loss: adversarial criterion applied to raw logits.
        L1_loss: criterion for the pixel and perceptual losses.
        epoch: zero-based index of the current epoch.
        epoch_size: maximum number of batches to consume from ``gen``.
        gen: iterable yielding ``(lr_images, hr_images)`` numpy-array pairs.
        Epoch: total number of epochs (used only for display).
        cuda: move each batch to the default CUDA device when true.
        batch_size: retained for API compatibility. Label tensors are now
            sized from the discriminator output, so it is no longer read.
        save_interval: call ``show_result`` every this many iterations.

    Side effects: steps both optimizers, prints a summary, and every 10th
    epoch writes G and D state dicts under ``logs/``.
    """
    G_total_loss = 0
    D_total_loss = 0
    G_total_PSNR = 0
    G_total_SSIM = 0
    num_batches = 0

    with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break

            lr_images, hr_images = batch
            lr_images = torch.from_numpy(lr_images).type(torch.FloatTensor)
            hr_images = torch.from_numpy(hr_images).type(torch.FloatTensor)
            if cuda:
                lr_images, hr_images = lr_images.cuda(), hr_images.cuda()

            #-------------------------------------------------#
            #   Train the discriminator
            #-------------------------------------------------#
            D_optimizer.zero_grad()

            # G's output is only an input to D in this step. Running G without
            # autograd avoids building and back-propagating through G's graph
            # for a loss that only updates D; D's gradients are unchanged.
            with torch.no_grad():
                G_result = G_model_train(lr_images)
            D_train_loss = _relativistic_adversarial_loss(
                D_model_train, BCEWithLogits_loss, hr_images, G_result, 1.0, 0.0)
            D_train_loss.backward()
            D_optimizer.step()

            #-------------------------------------------------#
            #   Train the generator
            #-------------------------------------------------#
            # zero_grad also discards any gradients D's backward left on G.
            G_optimizer.zero_grad()

            G_result = G_model_train(lr_images)
            image_loss = L1_loss(G_result, hr_images)
            # Same relativistic loss with swapped labels: G wants fakes to look
            # "more real" than reals.
            adversarial_loss = _relativistic_adversarial_loss(
                D_model_train, BCEWithLogits_loss, hr_images, G_result, 0.0, 1.0)
            perception_loss = L1_loss(VGG_feature_model(G_result), VGG_feature_model(hr_images))

            G_train_loss = image_loss + 1e-1 * adversarial_loss + 1e-1 * perception_loss

            G_train_loss.backward()
            G_optimizer.step()

            G_total_loss += G_train_loss.item()
            D_total_loss += D_train_loss.item()

            with torch.no_grad():
                G_total_PSNR += PSNR(G_result, hr_images).item()
                G_total_SSIM += SSIM(G_result, hr_images).item()

            num_batches = iteration + 1
            pbar.set_postfix(**{'G_loss'    : G_total_loss / num_batches,
                                'D_loss'    : D_total_loss / num_batches,
                                'G_PSNR'    : G_total_PSNR / num_batches,
                                'G_SSIM'    : G_total_SSIM / num_batches,
                                'lr'        : get_lr(G_optimizer)})
            pbar.update(1)

            if iteration % save_interval == 0:
                show_result(epoch + 1, G_model_train, lr_images, hr_images)

    # Average over the batches actually processed. This equals epoch_size when
    # the loader supplies at least that many, and avoids dividing by zero.
    steps = max(num_batches, 1)
    G_mean_loss = G_total_loss / steps
    D_mean_loss = D_total_loss / steps

    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('G Loss: %.4f || D Loss: %.4f ' % (G_mean_loss, D_mean_loss))
    print('Saving state, iter:', str(epoch + 1))

    if (epoch + 1) % 10 == 0:
        torch.save(G_model.state_dict(), 'logs/G_Epoch%d-GLoss%.4f-DLoss%.4f.pth' % ((epoch + 1), G_mean_loss, D_mean_loss))
        torch.save(D_model.state_dict(), 'logs/D_Epoch%d-GLoss%.4f-DLoss%.4f.pth' % ((epoch + 1), G_mean_loss, D_mean_loss))