"""Gradio demo for fast neural style transfer.

Loads one of several pre-trained style checkpoints from ``./models/`` into a
shared ``TransformerNet`` and applies it to an uploaded content image.
"""
from functools import lru_cache
from pydoc import describe
import re

import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr

from model import TransformerNet

# One network instance is shared by every request; `run` swaps in the weights
# of the selected style before each forward pass.
style_model = TransformerNet()
device = torch.device("cpu")

# Dropdown label -> checkpoint file name under ./models/.
styles_map = {
    "Kandinsky, Several circles": "kand_circles.model",
    "Haring, Dance": "haring_dance.model",
    "Picasso, The weeping woman": "picasso_weeping.model",
    "Van Gogh, Wheatfield with crows": "vangogh_crows.model",
}

# ToTensor scales 8-bit pixel values to [0, 1]; multiply back to the 0-255 range.
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255)),
])

# Checkpoint keys that are dropped before loading (see `_load_state_dict`).
_RUNNING_STATS_KEY = re.compile(r'in\d+\.running_(mean|var)$')


@lru_cache(maxsize=None)
def _load_state_dict(model_file):
    """Read ``./models/<model_file>`` once and return its filtered state dict.

    The result is cached per file name (the key space is bounded by
    ``styles_map``), so repeated requests for a style do not hit the disk.
    """
    # map_location keeps loading working when the checkpoint was saved from a
    # different device than the one this app runs on.
    state_dict = torch.load(f"./models/{model_file}", map_location=device)
    # Drop the `in<N>.running_mean` / `in<N>.running_var` entries; presumably
    # these are stale normalisation statistics that the current TransformerNet
    # no longer declares, which a strict load_state_dict would reject.
    for key in list(state_dict.keys()):
        if _RUNNING_STATS_KEY.search(key):
            del state_dict[key]
    return state_dict


def run(content_image, style):
    """Stylise ``content_image`` with the checkpoint selected by ``style``.

    Args:
        content_image: PIL image supplied by the Gradio image input.
        style: One of the keys of ``styles_map``.

    Returns:
        A new PIL image holding the network output, clamped to 0-255.

    Raises:
        ValueError: If no image was supplied.
        KeyError: If ``style`` is not a key of ``styles_map``.
    """
    if content_image is None:
        raise ValueError("No content image provided.")

    # convert() returns a new image, so the in-place thumbnail() below does not
    # resize the caller's object. It also guarantees three channels; the
    # network presumably expects RGB input - confirm against TransformerNet.
    content_image = content_image.convert("RGB")
    content_image.thumbnail((1080, 1080))
    img = content_transform(content_image).unsqueeze(0).to(device)

    style_model.load_state_dict(_load_state_dict(styles_map[style]))
    style_model.to(device)
    style_model.eval()  # inference only: disable any training-mode layer behaviour
    with torch.no_grad():
        output = style_model(img)

    # .cpu() is a no-op on the CPU device but keeps this valid if `device` changes.
    img = output[0].clone().clamp(0, 255).cpu().numpy()
    # Channels-first -> channels-last, as expected by Image.fromarray.
    img = img.transpose(1, 2, 0).astype("uint8")
    return Image.fromarray(img)


content_image_input = gr.inputs.Image(label="Content Image", type="pil")
style_input = gr.inputs.Dropdown(
    list(styles_map.keys()),
    type="value",
    default="Kandinsky, Several circles",
    label="Style",
)

description = "Fast Neural Style Transfer demo (trained from scratch!). Upload a content image. Select an artwork. Enjoy."

article = """ **References**\n\n You can find here a post I put together describing the approach I used to train models and deploy them on visualneurons.com using AWS Lambda. \n Here is instead the Jupyter notebook with the training logic. \n

**Kandinsky, Several circles**
**Haring, Dance**
**Picasso, The weeping woman**
**Van Gogh, Wheatfield with crows** """

example = ["dog.jpeg", "Kandinsky, Several circles"]

app_interface = gr.Interface(
    fn=run,
    inputs=[content_image_input, style_input],
    outputs="image",
    title="Fast Neural Style Transfer",
    description=description,
    examples=[example],
    article=article,
)
app_interface.launch()