hieupt's picture
Update app.py
637b1dc
raw
history blame
2.65 kB
### 1. Imports ###
import os
from functools import lru_cache
from pathlib import Path
from typing import Tuple, Dict

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

import utils
from model import TransformerNet
# Trained style-transfer checkpoints, one per sub-directory of models/
# (models/<style_name>/<checkpoint>.pth.tar). Sorted so the dropdown order --
# and therefore the index transfer() receives -- is deterministic.
model_dir = 'models'
models = sorted(Path(model_dir).glob("*/*.pth.tar"))

# Reference style images returned next to each stylized result.
# Path.glob("*") also matches dotfiles and directories; those are skipped so a
# stray file such as .DS_Store cannot shift the index pairing with `models`.
# NOTE(review): transfer() pairs models[i] with style_list[i], which assumes the
# style image file names sort in the same order as the model directory names --
# confirm when adding a new style.
style_dir = 'style_images'
style_list = sorted(
    path
    for path in Path(style_dir).glob("*")
    if path.is_file() and not path.name.startswith(".")
)

# Example inputs: every regular, non-hidden file in examples/, in sorted order.
# Each example row supplies only the first interface input (the image).
example_list = [
    [str(path)]
    for path in sorted(Path("examples").iterdir())
    if path.is_file() and not path.name.startswith(".")
]
# Inference device for every style model.
_DEVICE = 'cpu'
# Inputs wider or taller than these limits are shrunk before stylizing;
# thumbnail() keeps the aspect ratio and fits the image inside _THUMBNAIL_SIZE.
_MAX_INPUT_WIDTH = 750
_MAX_INPUT_HEIGHT = 500
_THUMBNAIL_SIZE = (712, 474)

# PIL image -> float tensor; ToTensor() scales 8-bit pixels to [0, 1] and the
# multiplication brings them back to the 0-255 range.
_content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255)),
])


@lru_cache(maxsize=None)
def _load_style_model(index):
    """Return the network for checkpoint ``models[index]``, in eval mode on _DEVICE.

    Results are cached per index, so each checkpoint is read from disk and
    turned into a module only once per process instead of on every request.
    """
    style_model = TransformerNet()
    # NOTE(review): torch.load unpickles the file -- only ship trusted checkpoints.
    checkpoint = torch.load(models[index], map_location=torch.device(_DEVICE))
    style_model.load_state_dict(checkpoint["state_dict"])
    style_model.to(_DEVICE)
    style_model.eval()
    return style_model


def transfer(image, model):
    """Stylize ``image`` with the checkpoint selected in the style dropdown.

    Args:
        image: PIL image from the "Input" component, or None if nothing was
            uploaded.
        model: position of the chosen checkpoint in ``models`` (the dropdown is
            built with ``type='index'``), or None if no style is selected.

    Returns:
        A tuple ``(stylized PIL image, path of style_list[model])``.

    Raises:
        gr.Error: if no image was uploaded or no style was chosen.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if model is None:
        raise gr.Error("Please choose a style.")
    index = int(model)

    # convert() always returns a new image, so the in-place thumbnail() below
    # never mutates the caller's object. RGB also keeps the tensor at three
    # channels for RGBA/grayscale uploads (the network presumably expects 3
    # input channels -- TODO confirm against model.TransformerNet).
    image = image.convert("RGB")
    if image.width > _MAX_INPUT_WIDTH or image.height > _MAX_INPUT_HEIGHT:
        image.thumbnail(_THUMBNAIL_SIZE)  # resizes in place, returns None

    style_model = _load_style_model(index)
    content_image = _content_transform(image).unsqueeze(0).to(_DEVICE)  # add batch dim
    with torch.no_grad():
        output = style_model(content_image).cpu()
    img = Image.fromarray(utils.deprocess(output[0]))
    return img, style_list[index]
# Title, description and article text displayed around the demo.
title = "Image Style Transfer"
description = (
    "Upload the image you want to stylize and choose a style. "
    "The app applies the selected style and returns the stylized image "
    "next to the style reference."
)
article = (
    "The models are based on the paper "
    "[Perceptual Losses for Real-Time Style Transfer and Super-Resolution]"
    "(https://arxiv.org/pdf/1603.08155v1.pdf)."
)

# Output components, in the order transfer() returns its two values.
image_output_1 = gr.Image(label='Transfer')  # stylized result
image_output_2 = gr.Image(label='Style Image')  # style reference image
# Dropdown entries: one per checkpoint, labelled with its style directory name.
# type='index' makes Gradio pass the selected position in `models` to transfer().
style_names = [path.parent.name for path in models]
_DEFAULT_STYLE = 'rain_princess'
# Fall back to the first available style when the preferred default is not
# installed, so the dropdown's initial value is always one of its choices.
default_style = (
    _DEFAULT_STYLE if _DEFAULT_STYLE in style_names
    else (style_names[0] if style_names else None)
)

# Create the Gradio demo.
demo = gr.Interface(
    fn=transfer,
    inputs=[
        gr.Image(type="pil", label='Input'),
        gr.Dropdown(
            choices=style_names,
            value=default_style,
            type='index',
            label="Style",
            info="Choose the style to apply",
        ),
    ],
    outputs=[image_output_1, image_output_2],  # transfer() returns two images
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    # Launch only when run as a script (e.g. `python app.py`), so importing this
    # module does not start a server.
    demo.launch()