# AdaIN / app.py — Gradio demo for AdaIN style transfer.
# Last commit: 80a11ed by vkganesan ("messed up share"); file size 1.39 kB.
import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from decoder import decoder as Decoder
from encoder import encoder as Encoder
from net import StyleTransfer
from PIL import Image
# Build the style-transfer network once, at import time, from pretrained weights.
# map_location="cpu" lets checkpoints that were saved from a GPU load on a
# CPU-only host; cleanup() feeds CPU tensors, so the weights must live on CPU.
encoder = Encoder
decoder = Decoder
encoder.load_state_dict(torch.load("./vgg_normalised.pth", map_location="cpu"))
# Keep only the first 31 child modules of the encoder (presumably VGG truncated
# at relu4_1, as in the AdaIN paper — confirm against encoder.py).
encoder = nn.Sequential(*list(encoder.children())[:31])
decoder.load_state_dict(
    torch.load("./saved-models/decoder_iter_1000.pth.tar", map_location="cpu")
)
net = StyleTransfer(encoder, decoder)
# Switch to inference mode (affects layers such as dropout/batch-norm, if any).
net.eval()
def train_transform(size=(512, 512)):
    """Return the preprocessing pipeline applied to content and style images.

    Args:
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(512, 512)``, the size the app has always used.

    Returns:
        A ``transforms.Compose`` that resizes a PIL image to *size* and then
        converts it to a tensor with ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
    ])
def cleanup(input, style):
    """Render the content image *input* in the style of *style*.

    Args:
        input: Content image as an array accepted by ``Image.fromarray``
            (gradio's ``gr.Image`` passes a numpy array by default — confirm
            if the component's ``type`` is changed).
        style: Style image, in the same format as *input*.

    Returns:
        A PIL image produced by the module-level ``net`` from both images
        after they are preprocessed by ``train_transform()``.
    """
    transform = train_transform()
    # Force 3 channels: RGBA or greyscale arrays would otherwise produce 4- or
    # 1-channel tensors (the encoder presumably expects RGB input).
    content_batch = transform(Image.fromarray(input).convert("RGB")).unsqueeze(0)
    style_batch = transform(Image.fromarray(style).convert("RGB")).unsqueeze(0)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        stylized = net(content_batch, style_batch)

    # Drop only the batch dimension, then clamp: ToPILImage scales float
    # tensors by 255 and casts to uint8, so out-of-range values would wrap.
    stylized = stylized.squeeze(0).clamp(0.0, 1.0)
    return transforms.ToPILImage()(stylized)
def greet(name):
    """Build the greeting ``"Hello <name>!"`` for the given string *name*."""
    return "".join(["Hello ", name, "!"])
# Web UI: two image inputs (content first, then style) are passed to cleanup(),
# and its return value is displayed as an image.
# NOTE(review): gr.Image(shape=(224, 224)) crops/resizes uploads to 224x224
# before train_transform() upsamples them to 512x512, which discards detail.
# `shape` was also removed in Gradio 4.x — revisit when upgrading.
content_input = gr.Image(shape=(224, 224))
style_input = gr.Image(shape=(224, 224))
demo = gr.Interface(
    fn=cleanup,
    inputs=[content_input, style_input],
    outputs="image",
)
demo.launch()