import os

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from tqdm import tqdm

from nets.cyclegan import compute_gradient_penalty
from utils.utils import get_lr, show_result


def _ra_gan_loss(BCE_loss, pred_real, pred_fake, target_rf, target_fr):
    """Relativistic-average GAN loss (RaGAN).

    Each prediction is compared against the mean prediction of the opposite
    batch.  The generator and discriminator steps use the same formula with
    the two target tensors swapped.
    """
    pred_rf = pred_real - pred_fake.mean()
    pred_fr = pred_fake - pred_real.mean()
    return (BCE_loss(pred_rf, target_rf) + BCE_loss(pred_fr, target_fr)) / 2


def _face_identity_loss(Face_loss, ResNeSt_model, recovered, original):
    """Face-identity loss between a reconstructed image and its original.

    Both images are resized to the 112x112 input of the face-embedding
    network; the loss is the mean of (1 - similarity).
    """
    recovered_face = F.interpolate(recovered, size=(112, 112), mode='bicubic', align_corners=True)
    original_face = F.interpolate(original, size=(112, 112), mode='bicubic', align_corners=True)
    return torch.mean(1. - Face_loss(ResNeSt_model(recovered_face), ResNeSt_model(original_face)))


def _generator_loss(G_model_A2B_train, G_model_B2A_train, D_model_A_train, D_model_B_train,
                    VGG_feature_model, ResNeSt_model, BCE_loss, L1_loss, Face_loss,
                    images_A, images_B, y_real, y_fake):
    """Full generator objective for both directions (A2B and B2A).

    Combines identity, relativistic adversarial, VGG perceptual,
    cycle-consistency and face-identity terms.  Returns
    (G_loss, fake_A, fake_B) so the generated images can be reused
    (detached) when training the discriminators.

    NOTE(review): the original fp16 branch omitted the perceptual terms that
    the fp32 branch used; the objective is now identical in both precisions.
    """
    # Identity losses: a generator fed an image already in its target domain
    # should return it unchanged.
    loss_identity_B = L1_loss(G_model_A2B_train(images_B), images_B)
    loss_identity_A = L1_loss(G_model_B2A_train(images_A), images_A)

    # Adversarial losses — targets are swapped relative to the discriminator
    # step so the generator is rewarded for fooling the discriminator.
    fake_B = G_model_A2B_train(images_A)
    loss_GAN_A2B = _ra_gan_loss(BCE_loss, D_model_B_train(images_B), D_model_B_train(fake_B),
                                y_fake, y_real)
    fake_A = G_model_B2A_train(images_B)
    loss_GAN_B2A = _ra_gan_loss(BCE_loss, D_model_A_train(images_A), D_model_A_train(fake_A),
                                y_fake, y_real)

    # Cycle consistency plus perceptual (VGG feature) and face-identity terms.
    recovered_A = G_model_B2A_train(fake_B)
    loss_cycle_ABA = L1_loss(recovered_A, images_A)
    loss_per_ABA = L1_loss(VGG_feature_model(recovered_A), VGG_feature_model(images_A))
    loss_face_ABA = _face_identity_loss(Face_loss, ResNeSt_model, recovered_A, images_A)

    recovered_B = G_model_A2B_train(fake_A)
    loss_cycle_BAB = L1_loss(recovered_B, images_B)
    loss_per_BAB = L1_loss(VGG_feature_model(recovered_B), VGG_feature_model(images_B))
    loss_face_BAB = _face_identity_loss(Face_loss, ResNeSt_model, recovered_B, images_B)

    G_loss = loss_identity_A * 5.0 + loss_identity_B * 5.0 \
           + loss_GAN_A2B + loss_GAN_B2A \
           + loss_per_ABA * 2.5 + loss_per_BAB * 2.5 \
           + loss_cycle_ABA * 10.0 + loss_cycle_BAB * 10.0 \
           + loss_face_ABA * 5 + loss_face_BAB * 5
    return G_loss, fake_A, fake_B


def _discriminator_loss(D_model_train, real_images, fake_images, BCE_loss, y_real, y_fake):
    """Relativistic discriminator loss plus a gradient penalty (weight 10)."""
    # Detach so no gradient flows back into the generator during this step.
    fake_detached = fake_images.detach()
    adv_loss = _ra_gan_loss(BCE_loss, D_model_train(real_images), D_model_train(fake_detached),
                            y_real, y_fake)
    gradient_penalty = compute_gradient_penalty(D_model_train, real_images, fake_detached)
    return 10 * gradient_penalty + adv_loss


def fit_one_epoch(G_model_A2B_train, G_model_B2A_train, D_model_A_train, D_model_B_train,
                  G_model_A2B, G_model_B2A, D_model_A, D_model_B, VGG_feature_model,
                  ResNeSt_model, loss_history, G_optimizer, D_optimizer_A, D_optimizer_B,
                  BCE_loss, L1_loss, Face_loss, epoch, epoch_step, gen, Epoch, cuda,
                  fp16, scaler, save_period, save_dir, photo_save_step, local_rank=0):
    """Train the CycleGAN (both generators and both discriminators) for one epoch.

    Args:
        G_model_*_train / D_model_*_train: wrapped (possibly DDP) modules used
            for forward/backward; G_model_* / D_model_* are the underlying
            modules used for sample generation and checkpointing.
        VGG_feature_model: feature extractor for the perceptual loss.
        ResNeSt_model: face-embedding network for the face-identity loss.
        loss_history: object with ``append_loss`` for logging.
        G_optimizer / D_optimizer_A / D_optimizer_B: optimizers.
        BCE_loss / L1_loss / Face_loss: loss callables.
        epoch, epoch_step, gen, Epoch: current epoch, steps per epoch,
            data loader, total epochs.
        cuda, fp16, scaler: device flag, AMP flag, GradScaler (used iff fp16).
        save_period, save_dir, photo_save_step: checkpoint/preview cadence.
        local_rank: process rank; only rank 0 logs, previews, and saves.
    """
    G_total_loss = 0
    D_total_loss_A = 0
    D_total_loss_B = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)

    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        images_A, images_B = batch[0], batch[1]
        batch_size = images_A.size()[0]
        # Hard real/fake targets for the relativistic BCE terms.
        y_real = torch.ones(batch_size)
        y_fake = torch.zeros(batch_size)
        with torch.no_grad():
            if cuda:
                images_A = images_A.cuda(local_rank)
                images_B = images_B.cuda(local_rank)
                y_real = y_real.cuda(local_rank)
                y_fake = y_fake.cuda(local_rank)

        if not fp16:
            #---------------------------------#
            #   Train generators A2B and B2A
            #---------------------------------#
            G_optimizer.zero_grad()
            G_loss, fake_A, fake_B = _generator_loss(
                G_model_A2B_train, G_model_B2A_train, D_model_A_train, D_model_B_train,
                VGG_feature_model, ResNeSt_model, BCE_loss, L1_loss, Face_loss,
                images_A, images_B, y_real, y_fake)
            G_loss.backward()
            G_optimizer.step()

            #---------------------------------#
            #   Train discriminator A
            #---------------------------------#
            D_optimizer_A.zero_grad()
            D_loss_A = _discriminator_loss(D_model_A_train, images_A, fake_A,
                                           BCE_loss, y_real, y_fake)
            D_loss_A.backward()
            D_optimizer_A.step()

            #---------------------------------#
            #   Train discriminator B
            #---------------------------------#
            D_optimizer_B.zero_grad()
            D_loss_B = _discriminator_loss(D_model_B_train, images_B, fake_B,
                                           BCE_loss, y_real, y_fake)
            D_loss_B.backward()
            D_optimizer_B.step()
        else:
            #---------------------------------#
            #   Same three steps under autocast,
            #   with GradScaler handling the scaling.
            #---------------------------------#
            G_optimizer.zero_grad()
            with autocast():
                G_loss, fake_A, fake_B = _generator_loss(
                    G_model_A2B_train, G_model_B2A_train, D_model_A_train, D_model_B_train,
                    VGG_feature_model, ResNeSt_model, BCE_loss, L1_loss, Face_loss,
                    images_A, images_B, y_real, y_fake)
            scaler.scale(G_loss).backward()
            scaler.step(G_optimizer)
            scaler.update()

            D_optimizer_A.zero_grad()
            with autocast():
                D_loss_A = _discriminator_loss(D_model_A_train, images_A, fake_A,
                                               BCE_loss, y_real, y_fake)
            scaler.scale(D_loss_A).backward()
            scaler.step(D_optimizer_A)
            scaler.update()

            D_optimizer_B.zero_grad()
            with autocast():
                D_loss_B = _discriminator_loss(D_model_B_train, images_B, fake_B,
                                               BCE_loss, y_real, y_fake)
            scaler.scale(D_loss_B).backward()
            scaler.step(D_optimizer_B)
            scaler.update()

        G_total_loss += G_loss.item()
        D_total_loss_A += D_loss_A.item()
        D_total_loss_B += D_loss_B.item()

        if local_rank == 0:
            pbar.set_postfix(**{'G_loss'    : G_total_loss / (iteration + 1),
                                'D_loss_A'  : D_total_loss_A / (iteration + 1),
                                'D_loss_B'  : D_total_loss_B / (iteration + 1),
                                'lr'        : get_lr(G_optimizer)})
            pbar.update(1)

            # Periodically dump sample translations for visual inspection.
            if iteration % photo_save_step == 0:
                show_result(epoch + 1, G_model_A2B, G_model_B2A, images_A, images_B)

    G_total_loss = G_total_loss / epoch_step
    D_total_loss_A = D_total_loss_A / epoch_step
    D_total_loss_B = D_total_loss_B / epoch_step

    if local_rank == 0:
        pbar.close()
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('G Loss: %.4f || D Loss A: %.4f || D Loss B: %.4f ' % (G_total_loss, D_total_loss_A, D_total_loss_B))
        loss_history.append_loss(epoch + 1, G_total_loss=G_total_loss,
                                 D_total_loss_A=D_total_loss_A, D_total_loss_B=D_total_loss_B)

        #-----------------------------------------------#
        #   Save weights
        #-----------------------------------------------#
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(G_model_A2B.state_dict(), os.path.join(save_dir, 'G_model_A2B_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(G_model_B2A.state_dict(), os.path.join(save_dir, 'G_model_B2A_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(D_model_A.state_dict(), os.path.join(save_dir, 'D_model_A_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))
            torch.save(D_model_B.state_dict(), os.path.join(save_dir, 'D_model_B_Epoch%d-GLoss%.4f-DALoss%.4f-DBLoss%.4f.pth' % (epoch + 1, G_total_loss, D_total_loss_A, D_total_loss_B)))

        # Always refresh the rolling "last epoch" checkpoints.
        torch.save(G_model_A2B.state_dict(), os.path.join(save_dir, "G_model_A2B_last_epoch_weights.pth"))
        torch.save(G_model_B2A.state_dict(), os.path.join(save_dir, "G_model_B2A_last_epoch_weights.pth"))
        torch.save(D_model_A.state_dict(), os.path.join(save_dir, "D_model_A_last_epoch_weights.pth"))
        torch.save(D_model_B.state_dict(), os.path.join(save_dir, "D_model_B_last_epoch_weights.pth"))