# Hugging Face Spaces status at time of capture: Runtime error
### 1. Imports and setup ###
import functools
import os
from pathlib import Path
from typing import Dict, Tuple

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

import utils
from model import TransformerNet
def sorted_visible_files(directory, pattern="*"):
    """Return the regular, non-hidden files under ``directory`` matching ``pattern``, sorted.

    Hidden entries (such as ``.DS_Store``) and sub-directories are skipped so they
    cannot shift the index-based model/style pairing or appear as broken examples.
    A missing ``directory`` yields an empty list instead of raising.
    """
    return sorted(
        path
        for path in Path(directory).glob(pattern)
        if path.is_file() and not path.name.startswith(".")
    )


# Get model: pretrained checkpoints, one "<style name>/<file>.pth.tar" per sub-folder.
# Sorted so that dropdown indices are stable between runs.
model_dir = 'models'
models = sorted_visible_files(model_dir, "*/*.pth.tar")

# Get style image: reference images displayed next to each result.
style_dir = 'style_images'
style_list = sorted_visible_files(style_dir)

# Get examples: one single-item row (the input image path) per file in examples/.
# Directory listing order is arbitrary, so sort for a stable example gallery.
example_list = [[str(path)] for path in sorted_visible_files("examples")]
@functools.lru_cache(maxsize=None)
def _load_style_model(index):
    """Build the TransformerNet for ``models[index]``, load its weights and cache it.

    The checkpoint is read from disk only on the first request for a given style;
    later requests reuse the same module. ``lru_cache`` does not cache exceptions,
    so an out-of-range index raises ``IndexError`` on every call.
    """
    style_model = TransformerNet()
    # Inference runs on CPU, so map every stored tensor to the CPU while loading.
    state_dict = torch.load(models[index], map_location=torch.device('cpu'))
    style_model.load_state_dict(state_dict["state_dict"])
    style_model.eval()
    return style_model


def _style_image_for(index):
    """Return the reference style image shown for ``models[index]``.

    Prefers the file in ``style_list`` whose stem equals the model's folder name,
    because ``models`` and ``style_list`` are sorted independently and their
    positions need not line up (e.g. "a-b.jpg" sorts before "a.jpg", while the
    folder "a" sorts before "a-b"). Falls back to positional pairing when no
    name matches.
    """
    model_name = models[index].parent.name
    for style_path in style_list:
        if style_path.stem == model_name:
            return style_path
    return style_list[index]


def transfer(image, model):
    """Apply the selected pretrained style to ``image``.

    Args:
        image: Input ``PIL.Image`` from the ``gr.Image(type="pil")`` component;
            may be ``None`` when the user submits without uploading an image.
        model: Index into ``models`` from the style dropdown (created with
            ``type='index'``); may be ``None`` if no style is selected.

    Returns:
        ``(stylized_image, style_image_path)``: the stylized ``PIL.Image`` and the
        path of the reference image for the chosen style.

    Raises:
        gr.Error: If no input image or no style was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if model is None:
        raise gr.Error("Please choose a style.")
    index = int(model)
    device = 'cpu'

    # convert() returns a new image, so the in-place thumbnail() below does not
    # mutate the caller's object. Forcing RGB also avoids 4-channel (RGBA) or
    # 1-channel (greyscale/palette) tensors; the network presumably expects 3
    # input channels (TransformerNet is defined in model.py).
    image = image.convert('RGB')
    if image.width > 750 or image.height > 500:
        # thumbnail() resizes in place, keeps the aspect ratio and returns None;
        # presumably done to keep CPU inference time reasonable.
        image.thumbnail((712, 474))

    content_transform = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor() scales pixels to [0, 1]; scale back so the network sees [0, 255].
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    # unsqueeze(0) adds a batch dimension of size 1.
    content_image = content_transform(image).unsqueeze(0).to(device)

    style_model = _load_style_model(index).to(device)
    with torch.no_grad():
        output = style_model(content_image).cpu()

    # utils.deprocess (defined in utils.py) converts the single output tensor into
    # something Image.fromarray accepts — presumably an HxWx3 uint8 array.
    img = Image.fromarray(utils.deprocess(output[0]))
    return img, _style_image_for(index)
# Title, description and article text displayed by the Gradio interface.
title = "Image Style Transfer"
description = "Choose an image and the style you want to apply to it. The app will return a stylized version of your image."
article = "The model is based on the paper [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/pdf/1603.08155v1.pdf)."
# Output components: the stylized result and the reference image of the chosen style.
image_output_1 = gr.Image(label='Transfer')  # stylized output image
image_output_2 = gr.Image(label='Style Image')  # reference style image
# Dropdown entries are the model folder names; with type='index' the handler
# receives the position in ``models`` rather than the name itself.
style_names = [path.parent.name for path in models]
if 'rain_princess' in style_names:
    default_style = 'rain_princess'
else:
    # Fall back to the first available style so the default is always a valid choice.
    default_style = style_names[0] if style_names else None

# Create the Gradio demo
demo = gr.Interface(
    fn=transfer,  # maps (input image, style index) -> (stylized image, style image)
    inputs=[
        gr.Image(type="pil", label='Input'),
        gr.Dropdown(
            choices=style_names,
            value=default_style,
            type='index',
            label="Style",
            info="Choose the style to apply",
        ),
    ],
    outputs=[image_output_1, image_output_2],  # transfer() returns two values
    # Pass None rather than an empty list when examples/ has no files.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch()