File size: 13,154 Bytes
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from nets.cyclegan import compute_gradient_penalty
from utils.utils import get_lr, show_result


def fit_one_epoch(G_model_A2B_train, G_model_B2A_train, D_model_A_train, D_model_B_train, G_model_A2B, G_model_B2A, D_model_A, D_model_B, VGG_feature_model, ResNeSt_model, loss_history, 
                G_optimizer, D_optimizer_A, D_optimizer_B, BCE_loss, L1_loss, Face_loss, epoch, epoch_step, gen, Epoch, cuda, fp16, scaler, save_period, save_dir, photo_save_step, local_rank=0):
    """Run one training epoch of a CycleGAN with relativistic-average (RaGAN) losses.

    Per batch, the generator pair (A2B, B2A) is updated once on a combined
    identity + adversarial + cycle(+ perceptual, fp32 path only) + face-identity
    loss, then each discriminator is updated on a RaGAN BCE loss plus a
    gradient penalty. Two execution paths exist: plain fp32, and mixed
    precision (``fp16``) driven by ``scaler`` (torch.cuda.amp.GradScaler).

    Args:
        G_model_A2B_train, G_model_B2A_train: generators in train mode
            (possibly DDP-wrapped) used for forward/backward.
        D_model_A_train, D_model_B_train: discriminators in train mode.
        G_model_A2B, G_model_B2A, D_model_A, D_model_B: unwrapped models used
            for visualization and for saving state_dicts.
        VGG_feature_model: feature extractor for the perceptual L1 loss
            (used on the fp32 path only — see NOTE(review) below).
        ResNeSt_model: face-embedding network; inputs are resized to 112x112.
        loss_history: logger with an ``append_loss`` method.
        G_optimizer, D_optimizer_A, D_optimizer_B: optimizers for G pair / D_A / D_B.
        BCE_loss: BCEWithLogits-style criterion applied to relativistic logits.
        L1_loss: criterion for identity / cycle / perceptual terms.
        Face_loss: similarity function on face embeddings; loss is mean(1 - sim).
        epoch, Epoch: current epoch index and total epoch count.
        epoch_step: number of batches to consume from ``gen`` this epoch.
        gen: iterable yielding (images_A, images_B) batches.
        cuda: move tensors to GPU ``local_rank`` when True.
        fp16: use the autocast/GradScaler path when True.
        scaler: GradScaler instance (only used when ``fp16``).
        save_period: save numbered checkpoints every this many epochs.
        save_dir: directory for checkpoint files.
        photo_save_step: call ``show_result`` every this many iterations.
        local_rank: process rank; only rank 0 prints, draws, and saves.
    """
    # Running sums over the epoch; averaged at the end.
    G_total_loss    = 0
    D_total_loss_A  = 0
    D_total_loss_B  = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        
        images_A, images_B = batch[0], batch[1]
        batch_size  = images_A.size()[0]
        # RaGAN BCE targets: all-ones / all-zeros vectors sized to the batch.
        y_real      = torch.ones(batch_size)
        y_fake      = torch.zeros(batch_size)
        
        # no_grad only wraps the device transfer; it does not affect training.
        with torch.no_grad():
            if cuda:
                images_A, images_B, y_real, y_fake  = images_A.cuda(local_rank), images_B.cuda(local_rank), y_real.cuda(local_rank), y_fake.cuda(local_rank)

        if not fp16:
            #---------------------------------#
            #   Train generators A2B and B2A
            #---------------------------------#
            G_optimizer.zero_grad()
            
            # Identity losses: feeding a generator an image already in its
            # target domain should return it unchanged.
            Same_B          = G_model_A2B_train(images_B)
            loss_identity_B = L1_loss(Same_B, images_B)
            
            Same_A          = G_model_B2A_train(images_A)
            loss_identity_A = L1_loss(Same_A, images_A)
            
            # Relativistic-average adversarial loss for A->B. Note the
            # targets are swapped vs. the discriminator step (y_fake for
            # real-minus-fake, y_real for fake-minus-real): the generator
            # tries to make fakes look *more* real than reals.
            fake_B          = G_model_A2B_train(images_A)
            pred_real       = D_model_B_train(images_B)
            pred_fake       = D_model_B_train(fake_B)
            pred_rf         = pred_real - pred_fake.mean()
            pred_fr         = pred_fake - pred_real.mean()
            D_train_loss_rf = BCE_loss(pred_rf, y_fake)
            D_train_loss_fr = BCE_loss(pred_fr, y_real)
            loss_GAN_A2B    = (D_train_loss_rf + D_train_loss_fr) / 2

            # Same relativistic adversarial loss for B->A.
            fake_A          = G_model_B2A_train(images_B)
            pred_real       = D_model_A_train(images_A)
            pred_fake       = D_model_A_train(fake_A)
            pred_rf         = pred_real - pred_fake.mean()
            pred_fr         = pred_fake - pred_real.mean()
            D_train_loss_rf = BCE_loss(pred_rf, y_fake)
            D_train_loss_fr = BCE_loss(pred_fr, y_real)
            loss_GAN_B2A    = (D_train_loss_rf + D_train_loss_fr) / 2
            
            # Cycle consistency A -> B -> A.
            recovered_A     = G_model_B2A_train(fake_B)
            loss_cycle_ABA  = L1_loss(recovered_A, images_A)

            # Perceptual loss in VGG feature space for the recovered image.
            loss_per_ABA    = L1_loss(VGG_feature_model(recovered_A), VGG_feature_model(images_A))

            # Face-identity loss: compare face embeddings at the embedding
            # network's expected 112x112 input size.
            recovered_A_face  = F.interpolate(recovered_A, size=(112, 112), mode='bicubic', align_corners=True)
            images_A_face     = F.interpolate(images_A, size=(112, 112), mode='bicubic', align_corners=True)
            loss_face_ABA     = torch.mean(1. - Face_loss(ResNeSt_model(recovered_A_face), ResNeSt_model(images_A_face)))

            # Cycle consistency B -> A -> B (mirror of the above).
            recovered_B     = G_model_A2B_train(fake_A)
            loss_cycle_BAB  = L1_loss(recovered_B, images_B)

            loss_per_BAB    = L1_loss(VGG_feature_model(recovered_B), VGG_feature_model(images_B))
            
            recovered_B_face  = F.interpolate(recovered_B, size=(112, 112), mode='bicubic', align_corners=True)
            images_B_face     = F.interpolate(images_B, size=(112, 112), mode='bicubic', align_corners=True)
            loss_face_BAB     = torch.mean(1. - Face_loss(ResNeSt_model(recovered_B_face), ResNeSt_model(images_B_face)))

            # Weighted sum: identity x5, cycle x10, perceptual x2.5, face x5,
            # adversarial x1.
            G_loss = loss_identity_A * 5.0 + loss_identity_B * 5.0 + loss_GAN_A2B + loss_GAN_B2A  + loss_per_ABA * 2.5 \
                   + loss_per_BAB *2.5 + loss_cycle_ABA * 10.0 + loss_cycle_BAB * 10.0 + loss_face_ABA * 5 + loss_face_BAB * 5
            G_loss.backward()
            G_optimizer.step()
                
            #---------------------------------#
            #   Train discriminator A
            #---------------------------------#
            D_optimizer_A.zero_grad()
            # fake_A is detached so D's backward pass does not touch the
            # generator's parameters.
            pred_real   = D_model_A_train(images_A)
            pred_fake   = D_model_A_train(fake_A.detach())
            pred_rf     = pred_real - pred_fake.mean()
            pred_fr     = pred_fake - pred_real.mean()
            # Discriminator targets are the "natural" orientation
            # (real-minus-fake -> 1, fake-minus-real -> 0).
            D_train_loss_rf  = BCE_loss(pred_rf, y_real)
            D_train_loss_fr  = BCE_loss(pred_fr, y_fake)
            gradient_penalty = compute_gradient_penalty(D_model_A_train, images_A, fake_A.detach())

            # Gradient penalty weighted by 10 (the standard WGAN-GP weight).
            D_loss_A    = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            D_loss_A.backward()
            D_optimizer_A.step()
            
            #---------------------------------#
            #   Train discriminator B
            #---------------------------------#
            D_optimizer_B.zero_grad()

            pred_real   = D_model_B_train(images_B)
            pred_fake   = D_model_B_train(fake_B.detach())
            pred_rf     = pred_real - pred_fake.mean()
            pred_fr     = pred_fake - pred_real.mean()
            D_train_loss_rf  = BCE_loss(pred_rf, y_real)
            D_train_loss_fr  = BCE_loss(pred_fr, y_fake)
            gradient_penalty = compute_gradient_penalty(D_model_B_train, images_B, fake_B.detach())

            D_loss_B    = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            D_loss_B.backward()
            D_optimizer_B.step()

        else:
            # NOTE(review): this import runs on every iteration; Python caches
            # modules so it is cheap, but it belongs at module top level.
            from torch.cuda.amp import autocast

            #---------------------------------#
            #   Train generators A2B and B2A
            #---------------------------------#
            # Mixed-precision mirror of the fp32 generator step above.
            with autocast():
                G_optimizer.zero_grad()
                Same_B          = G_model_A2B_train(images_B)
                loss_identity_B = L1_loss(Same_B, images_B)
                
                Same_A          = G_model_B2A_train(images_A)
                loss_identity_A = L1_loss(Same_A, images_A)
                
                fake_B          = G_model_A2B_train(images_A)
                pred_real       = D_model_B_train(images_B)
                pred_fake       = D_model_B_train(fake_B)
                pred_rf         = pred_real - pred_fake.mean()
                pred_fr         = pred_fake - pred_real.mean()
                D_train_loss_rf = BCE_loss(pred_rf, y_fake)
                D_train_loss_fr = BCE_loss(pred_fr, y_real)
                loss_GAN_A2B    = (D_train_loss_rf + D_train_loss_fr) / 2

                fake_A          = G_model_B2A_train(images_B)
                pred_real       = D_model_A_train(images_A)
                pred_fake       = D_model_A_train(fake_A)
                pred_rf         = pred_real - pred_fake.mean()
                pred_fr         = pred_fake - pred_real.mean()
                D_train_loss_rf = BCE_loss(pred_rf, y_fake)
                D_train_loss_fr = BCE_loss(pred_fr, y_real)
                loss_GAN_B2A    = (D_train_loss_rf + D_train_loss_fr) / 2
                
                recovered_A     = G_model_B2A_train(fake_B)
                loss_cycle_ABA  = L1_loss(recovered_A, images_A)
                recovered_A_face  = F.interpolate(recovered_A, size=(112, 112), mode='bicubic', align_corners=True)
                images_A_face     = F.interpolate(images_A, size=(112, 112), mode='bicubic', align_corners=True)
                loss_face_ABA     = torch.mean(1. - Face_loss(ResNeSt_model(recovered_A_face), ResNeSt_model(images_A_face)))

                recovered_B     = G_model_A2B_train(fake_A)
                loss_cycle_BAB  = L1_loss(recovered_B, images_B)
                recovered_B_face  = F.interpolate(recovered_B, size=(112, 112), mode='bicubic', align_corners=True)
                images_B_face     = F.interpolate(images_B, size=(112, 112), mode='bicubic', align_corners=True)
                loss_face_BAB     = torch.mean(1. - Face_loss(ResNeSt_model(recovered_B_face), ResNeSt_model(images_B_face)))

                # NOTE(review): unlike the fp32 path, the VGG perceptual
                # terms (loss_per_ABA / loss_per_BAB) are neither computed
                # nor added here — confirm whether this divergence between
                # the two paths is intentional.
                G_loss = loss_identity_A * 5.0 + loss_identity_B * 5.0 + loss_GAN_A2B + loss_GAN_B2A \
                    + loss_cycle_ABA * 10.0 + loss_cycle_BAB * 10.0 + loss_face_ABA * 5 + loss_face_BAB * 5
            #----------------------#
            #   Backpropagation
            #----------------------#
            # Scaled backward + step; scaler.update() adjusts the loss scale.
            scaler.scale(G_loss).backward()
            scaler.step(G_optimizer)
            scaler.update()
            
            #---------------------------------#
            #   Train discriminator A
            #---------------------------------#
            with autocast():
                D_optimizer_A.zero_grad()
                pred_real   = D_model_A_train(images_A)
                pred_fake   = D_model_A_train(fake_A.detach())
                pred_rf     = pred_real - pred_fake.mean()
                pred_fr     = pred_fake - pred_real.mean()
                D_train_loss_rf  = BCE_loss(pred_rf, y_real)
                D_train_loss_fr  = BCE_loss(pred_fr, y_fake)
                gradient_penalty = compute_gradient_penalty(D_model_A_train, images_A, fake_A.detach())

                D_loss_A    = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            #----------------------#
            #   Backpropagation
            #----------------------#
            scaler.scale(D_loss_A).backward()
            scaler.step(D_optimizer_A)
            scaler.update()
            
            #---------------------------------#
            #   Train discriminator B
            #---------------------------------#
            with autocast():
                D_optimizer_B.zero_grad()

                pred_real   = D_model_B_train(images_B)
                pred_fake   = D_model_B_train(fake_B.detach())
                pred_rf     = pred_real - pred_fake.mean()
                pred_fr     = pred_fake - pred_real.mean()
                D_train_loss_rf  = BCE_loss(pred_rf, y_real)
                D_train_loss_fr  = BCE_loss(pred_fr, y_fake)
                gradient_penalty = compute_gradient_penalty(D_model_B_train, images_B, fake_B.detach())

                D_loss_B    = 10 * gradient_penalty + (D_train_loss_rf + D_train_loss_fr) / 2
            #----------------------#
            #   Backpropagation
            #----------------------#
            scaler.scale(D_loss_B).backward()
            scaler.step(D_optimizer_B)
            scaler.update()
                
        G_total_loss    += G_loss.item()
        D_total_loss_A  += D_loss_A.item()
        D_total_loss_B  += D_loss_B.item()

        if local_rank == 0:
            # Progress bar shows running means, not per-batch losses.
            pbar.set_postfix(**{'G_loss'    : G_total_loss / (iteration + 1), 
                                'D_loss_A'  : D_total_loss_A / (iteration + 1), 
                                'D_loss_B'  : D_total_loss_B / (iteration + 1), 
                                'lr'        : get_lr(G_optimizer)})
            pbar.update(1)

            # Periodically render sample translations for visual inspection.
            if iteration % photo_save_step == 0:
                show_result(epoch + 1, G_model_A2B, G_model_B2A, images_A, images_B)

    # Epoch means over the fixed number of steps.
    G_total_loss    = G_total_loss / epoch_step
    D_total_loss_A  = D_total_loss_A / epoch_step
    D_total_loss_B  = D_total_loss_B / epoch_step
    
    if local_rank == 0:
        pbar.close()
        print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
        print('G Loss: %.4f || D Loss A: %.4f || D Loss B: %.4f  ' % (G_total_loss, D_total_loss_A, D_total_loss_B))
        loss_history.append_loss(epoch + 1, G_total_loss = G_total_loss, D_total_loss_A = D_total_loss_A, D_total_loss_B = D_total_loss_B)

        #-----------------------------------------------#
        #   Save weights
        #-----------------------------------------------#
        # Numbered checkpoints every save_period epochs and on the last epoch.
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(G_model_A2B.state_dict(), os.path.join(save_dir, 'G_model_A2B_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth'%(epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(G_model_B2A.state_dict(), os.path.join(save_dir, 'G_model_B2A_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth'%(epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(D_model_A.state_dict(), os.path.join(save_dir, 'D_model_A_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth'%(epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(D_model_B.state_dict(), os.path.join(save_dir, 'D_model_B_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth'%(epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))

        # "last epoch" snapshots are overwritten every epoch for easy resume.
        torch.save(G_model_A2B.state_dict(), os.path.join(save_dir, "G_model_A2B_last_epoch_weights.pth"))
        torch.save(G_model_B2A.state_dict(), os.path.join(save_dir, "G_model_B2A_last_epoch_weights.pth"))
        torch.save(D_model_A.state_dict(), os.path.join(save_dir, "D_model_A_last_epoch_weights.pth"))
        torch.save(D_model_B.state_dict(), os.path.join(save_dir, "D_model_B_last_epoch_weights.pth"))