import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from decoder import decoder as Decoder
from encoder import encoder as Encoder
from net import StyleTransfer

# The encoder/decoder modules expose ready-built network instances.
encoder = Encoder
decoder = Decoder

# Load the pretrained VGG weights, then truncate the encoder after child 31
# (relu4_1 in the standard AdaIN setup), the feature depth the style-transfer
# network operates at. map_location keeps loading robust on CPU-only machines.
encoder.load_state_dict(torch.load("./vgg_normalised.pth", map_location="cpu"))
encoder = nn.Sequential(*list(encoder.children())[:31])
decoder.load_state_dict(
    torch.load("./saved-models/decoder_iter_1000.pth.tar", map_location="cpu")
)

net = StyleTransfer(encoder, decoder)
net.eval()  # inference mode: freezes dropout/batch-norm behavior

def train_transform():
    """Resize an image to 512x512 and convert it to a float tensor in [0, 1]."""
    transform_list = [
        transforms.Resize(size=(512, 512)),
        # transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)

def cleanup(content, style):
    """Stylize a content image with a style image; both arrive as NumPy arrays from Gradio."""
    transform = train_transform()
    content_img = transform(Image.fromarray(content)).unsqueeze(0)  # add batch dim
    style_img = transform(Image.fromarray(style)).unsqueeze(0)
    with torch.no_grad():  # pure inference, no gradient tracking needed
        output = net(content_img, style_img)
    # Clamp to [0, 1] so ToPILImage maps values to a valid 8-bit image.
    output = output.squeeze(0).clamp(0, 1)
    return transforms.ToPILImage()(output)
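
# A minimal smoke test without the web UI, sketched under the assumption that
# two local test images (hypothetical filenames) sit next to this script:
#
#   import numpy as np
#   content = np.array(Image.open("content.jpg").convert("RGB"))
#   style = np.array(Image.open("style.jpg").convert("RGB"))
#   cleanup(content, style).save("stylized.jpg")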

# Leftover from the Gradio quickstart template; not wired into the interface.
def greet(name):
    return "Hello " + name + "!"

demo = gr.Interface(
    fn=cleanup,
    inputs=[gr.Image(shape=(224, 224)), gr.Image(shape=(224, 224))],
    outputs="image",
)

if __name__ == "__main__":
    demo.launch()
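
# Note: demo.launch(share=True) would additionally expose a temporary public
# URL; the default call above serves only on localhost.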