| from torchvision import transforms |
| import torch |
| import torch.nn as nn |
| from torchvision.utils import make_grid |
| from torch.utils.data import DataLoader |
| import matplotlib.pyplot as plt |
| import glob |
| import os |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
def show_tensor_images(image_tensor, epoch, step, num_images=25, size=(1, 28, 28)):
    """Save a grid of up to ``num_images`` images to outputs/Epoch{epoch}/step_{step}.

    Args:
        image_tensor: batch of images in [-1, 1] (tanh output); any shape
            reshapeable to (-1, *size).
        epoch: current epoch number (names the output directory).
        step: current training step (names the output file).
        num_images: max number of images shown in the grid.
        size: per-image (channels, height, width) used for the reshape.
    """
    # Map from [-1, 1] back to [0, 1] for display.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # Bug fix: the original created an absolute "/outputs/Epoch{epoch}" dir
    # but saved to a relative "outputs/..." path outside it. Use one relative
    # directory consistently and save the figure inside it.
    out_dir = f"outputs/Epoch{epoch}"
    os.makedirs(out_dir, exist_ok=True)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.savefig(os.path.join(out_dir, f"step_{step}"))
    plt.close()
|
|
|
|
class ImageDataset(Dataset):
    """Unaligned two-domain image dataset (e.g. horse2zebra for CycleGAN).

    Loads image paths from ``<root>/<mode>A`` and ``<root>/<mode>B``. The
    shorter file list is kept in ``files_A`` so every A image is used once
    per epoch, and a fresh random pairing against domain B is drawn each
    epoch so A/B pairs are not fixed.
    """

    def __init__(self, root, transform=None, mode='train'):
        # Bug fix: the original stored transform=None and called it
        # unconditionally in __getitem__, crashing with TypeError. Default
        # to ToTensor(), which yields the [0, 1] tensor that the
        # (item - 0.5) * 2 rescale in __getitem__ expects.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        # Keep the smaller domain in files_A.
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        self.new_perm()
        assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!"

    def new_perm(self):
        # Random subset of B indices, one per A image, redrawn every epoch.
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def __getitem__(self, index):
        """Return one (A, B) pair of tensors, each rescaled to [-1, 1]."""
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
        item_B = self.transform(Image.open(self.files_B[self.randperm[index]]))
        # Grayscale images come out with 1 channel; tile to 3 channels.
        if item_A.shape[0] != 3:
            item_A = item_A.repeat(3, 1, 1)
        if item_B.shape[0] != 3:
            item_B = item_B.repeat(3, 1, 1)
        # Re-shuffle the A/B pairing once the epoch's last item is served.
        if index == len(self) - 1:
            self.new_perm()
        # Map [0, 1] -> [-1, 1] to match tanh generator outputs.
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))
| |
|
|
def weights_init(m):
    """DCGAN-style weight initialization: N(0, 0.02) for conv and norm layers.

    Apply with ``model.apply(weights_init)``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # Bug fix: the original used normal_(1.0, 0.2). The standard
        # DCGAN/CycleGAN init (matching the BatchNorm branch below) is
        # mean 0.0, std 0.02 — mean-1 conv weights derail training.
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
|
|
|
|
def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    """Discriminator loss: push predictions to 1 on real, 0 on fake.

    Args:
        real_X: batch of real images from domain X.
        fake_X: generator output for domain X (its graph is detached here).
        disc_X: discriminator for domain X.
        adv_criterion: adversarial criterion, e.g. MSELoss (LSGAN) or BCE.

    Returns:
        Mean of the real and fake adversarial losses (scalar tensor).
    """
    real_pred = disc_X(real_X.detach())
    disc_real_loss = adv_criterion(real_pred, torch.ones_like(real_pred))
    # Bug fixes: ".deatch()" typo crashed with AttributeError, and the
    # original detached fake_pred before the criterion, which cut the
    # discriminator's own gradient. Detach the generated IMAGE (so the
    # generator gets no gradient), not the prediction.
    fake_pred = disc_X(fake_X.detach())
    disc_fake_loss = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
    disc_loss = (disc_real_loss + disc_fake_loss) / 2
    return disc_loss
|
|
|
|
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    """Adversarial term of the generator loss for the X -> Y direction.

    Translates ``real_X`` with ``gen_XY`` and scores it with ``disc_Y``;
    the generator is rewarded when the discriminator predicts "real" (1).

    Returns:
        (adversarial loss, generated fake_Y batch).
    """
    fake_Y = gen_XY(real_X.detach())
    prediction = disc_Y(fake_Y)
    real_target = torch.ones_like(prediction)
    adversarial_loss = adv_criterion(prediction, real_target)
    return adversarial_loss, fake_Y
|
|
def get_identity_loss(real_X, gen_YX, identity_criterion):
    """Identity term: ``gen_YX`` fed an image already in X should change nothing.

    Returns:
        (identity loss, gen_YX(real_X)).
    """
    mapped = gen_YX(real_X)
    return identity_criterion(mapped, real_X), mapped
|
|
|
|
|
|
def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    """Cycle term: mapping the fake back with ``gen_YX`` should recover ``real_X``.

    Returns:
        (cycle-consistency loss, reconstructed X batch).
    """
    reconstructed_X = gen_YX(fake_Y)
    return cycle_criterion(reconstructed_X, real_X), reconstructed_X
|
|
|
|
|
|
def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_B, disc_A, adv_criterion,
                 cycle_criterion, identity_criterion, lambda_identity=0.2, lambda_cycle=10):
    """Total CycleGAN generator loss for both translation directions.

    Combines adversarial, identity, and cycle-consistency terms:
    ``lambda_identity * identity + lambda_cycle * cycle + adversarial``.

    Returns:
        (total generator loss, fake_A batch, fake_B batch).
    """
    # Adversarial: each generator tries to fool the opposite discriminator.
    adv_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    total_adversarial = adv_BA + adv_AB

    # Identity: a generator fed an image of its own output domain
    # should behave like the identity map.
    id_loss_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    id_loss_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)
    total_identity = id_loss_A + id_loss_B

    # Cycle consistency: translating A -> B -> A (and B -> A -> B)
    # should reconstruct the original image.
    cyc_BA, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cyc_AB, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    total_cycle = cyc_BA + cyc_AB

    gen_loss = lambda_identity * total_identity + lambda_cycle * total_cycle + total_adversarial
    return gen_loss, fake_A, fake_B