# HW_DL_GAN / app.py — CycleGAN emoji-translation demo (Gradio app).
import torch
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from PIL import Image
from torchvision import transforms
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as FF
import torch.nn as nn
import torch.nn.functional as F
import torch
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
if hasattr(m, "bias") and m.bias is not None:
torch.nn.init.constant_(m.bias.data, 0.0)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
##############################
# RESNET
##############################
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
)
def forward(self, x):
return x + self.block(x)
class GeneratorResNet(nn.Module):
def __init__(self, input_shape, num_residual_blocks):
super(GeneratorResNet, self).__init__()
channels = input_shape[0]
# Initial convolution block
out_features = 64
model = [
nn.ReflectionPad2d(channels),
nn.Conv2d(channels, out_features, 7),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Downsampling
for _ in range(2):
out_features *= 2
model += [
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Residual blocks
for _ in range(num_residual_blocks):
model += [ResidualBlock(out_features)]
# Upsampling
for _ in range(2):
out_features //= 2
model += [
nn.Upsample(scale_factor=2),
nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Output layer
model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
def __init__(self, input_shape):
super(Discriminator, self).__init__()
channels, height, width = input_shape
# Calculate output shape of image discriminator (PatchGAN)
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
def discriminator_block(in_filters, out_filters, normalize=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalize:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(channels, 64, normalize=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1)
)
def forward(self, img):
return self.model(img)
plt.rcParams["savefig.bbox"] = 'tight'
input_shape = (3, 32, 32)
n_residual_blocks=9
G_AB = GeneratorResNet(input_shape, n_residual_blocks)
G_BA = GeneratorResNet(input_shape, n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
G_AB.load_state_dict(torch.load("G_AB_30.pth",map_location=torch.device('cpu')))
def sample_images(img):
"""Saves a generated sample from the test set"""
imgs = Image.open(img).convert('RGB')
convert_tensor =transforms.Compose ([
transforms.Resize(int(32 * 1.12), Image.BICUBIC),
transforms.RandomCrop((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
imgss=convert_tensor(imgs)
real_A = imgss.reshape([1,3,32,32])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G_AB.to(device)
real_A=real_A.to(device)
fake_B = G_AB(real_A)
if not isinstance(fake_B[0], list):
imgs = [fake_B[0]]
fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.detach()
img = FF.to_pil_image(img)
OO=np.asarray(img)
#real_B = Variable(imgs["B"].type(Tensor))
#fake_A = G_BA(real_B)
# Arange images along x-axis
#real_A = make_grid(real_A, nrow=5, normalize=True)
#real_B = make_grid(real_B, nrow=5, normalize=True)
#fake_A = make_grid(fake_A, nrow=5, normalize=True)
#fake_B = make_grid(fake_B, nrow=1, normalize=True)
#fake_B=fake_B.to('cpu').detach().numpy()
#fake_B=(fake_B * 127.5 + 127.5).astype(np.uint8)
# Arange images along y-axis
return OO
import gradio as gr
image = gr.inputs.Image(type="filepath")
op = gr.outputs.Image(type="numpy")
iface = gr.Interface(
sample_images,
image,
op,
title="CycleGAN",
description='Implementation of CycleGAN model using Apple emojis to Windows emojis',
examples=["U+1F446.png"],
)
iface.launch()