File size: 1,390 Bytes
12b5a88 80a11ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from decoder import decoder as Decoder
from encoder import encoder as Encoder
from net import StyleTransfer
from PIL import Image
# Build the inference network once, at import time.
# The imported ``encoder``/``decoder`` objects are used as-is (they are not
# instantiated here); the module-level names are kept for other callers.
encoder = Encoder
decoder = Decoder
# map_location="cpu": ``cleanup`` feeds CPU tensors to the network, so the
# weights must live on the CPU even if a checkpoint was saved from a GPU.
encoder.load_state_dict(torch.load("./vgg_normalised.pth", map_location="cpu"))
# Keep only the first 31 child layers of the encoder (presumably truncating the
# normalised VGG at an intermediate feature layer — TODO confirm in encoder.py).
encoder = nn.Sequential(*list(encoder.children())[:31])
decoder.load_state_dict(
    torch.load("./saved-models/decoder_iter_1000.pth.tar", map_location="cpu")
)
net = StyleTransfer(encoder, decoder)
# eval() switches layers such as dropout/batch-norm to inference behaviour;
# it does NOT disable gradient tracking (see torch.no_grad in cleanup).
net.eval()
def train_transform(size=(512, 512)):
    """Build the preprocessing pipeline applied to content and style images.

    Despite its name, this pipeline is also used at inference time by
    ``cleanup``.

    Args:
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(512, 512)``, the previously hard-coded value.

    Returns:
        A ``transforms.Compose`` that resizes a PIL image to exactly ``size``
        (aspect ratio is not preserved) and converts it to a tensor via
        ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
    ])
def cleanup(input, style):
    """Stylize the ``input`` image with the style of ``style`` using ``net``.

    Args:
        input: Content image as delivered by the Gradio ``gr.Image`` input
            (presumably an H x W x C uint8 NumPy array — TODO confirm for the
            installed Gradio version). The name shadows the builtin but is
            kept for backward compatibility.
        style: Style image, in the same format as ``input``.

    Returns:
        A ``PIL.Image.Image`` holding the stylized result at the size produced
        by ``train_transform`` (512x512).
    """
    transform = train_transform()
    # Force 3 channels: RGBA or grayscale uploads would otherwise yield 4- or
    # 1-channel tensors (the encoder presumably expects RGB input).
    content_img = transform(Image.fromarray(input).convert("RGB")).unsqueeze(0)
    style_img = transform(Image.fromarray(style).convert("RGB")).unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = net(content_img, style_img)

    # Drop only the batch dim (a bare squeeze() would drop any size-1 dim) and
    # clamp: ToPILImage multiplies float tensors by 255 and casts to uint8, so
    # values outside [0, 1] would wrap around into visible artifacts.
    output = output.squeeze(0).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output)
def greet(name):
    """Return the greeting ``"Hello <name>!"`` for the given string.

    Not wired into the Gradio interface (which uses ``cleanup``).
    """
    opening, closing = "Hello ", "!"
    return opening + name + closing
# Two independent image inputs (content, then style), mapped positionally onto
# cleanup(input, style). Images are presumably resized to 224x224 by Gradio
# before being resized again to 512x512 by train_transform.
# NOTE(review): the `shape=` argument is only accepted by older Gradio
# releases — confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=cleanup,
    inputs=[gr.Image(shape=(224, 224)) for _ in range(2)],
    outputs="image",
)
demo.launch()