# Spaces: Runtime error
import gradio as gr | |
import toml | |
import torch | |
from PIL import Image | |
from torch import nn | |
from torchvision import transforms | |
import net | |
from function import * | |
# Runtime settings (weight paths, resize/crop sizes, GPU opt-in) live in config.toml.
cfg = toml.load("config.toml")

# Select the GPU only when CUDA is present AND the config opts in via `use_cuda`.
# `torch.cuda.is_available()` is evaluated first, so `cfg["use_cuda"]` is only
# read when a GPU actually exists.
device = torch.device(
    "cuda" if torch.cuda.is_available() and cfg["use_cuda"] else "cpu"
)
# Load pretrained models.
# `net.decoder` / `net.vgg` are module-level model objects from the project-local
# `net` module; their architectures are defined there, not here.
decoder = net.decoder
vgg = net.vgg

decoder.eval()
vgg.eval()

# Deserialize the checkpoints onto the CPU first: weights saved from a GPU
# session would otherwise fail to load on a CPU-only host (or when `use_cuda`
# is disabled). The models are moved to `device` below.
decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location="cpu"))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location="cpu"))

# Keep only the first 31 child layers of the VGG encoder as the feature extractor.
# NOTE(review): presumably this truncates the network at relu4_1, as in the
# reference AdaIN implementation — confirm against the layer list in `net.vgg`.
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg = vgg.to(device)
decoder = decoder.to(device)
def transform(img, size, crop):
    """Convert a PIL image to a tensor, optionally resizing and center-cropping it.

    Args:
        img: Input PIL image.
        size: Target size passed to ``transforms.Resize``; an int resizes the
            shorter edge to ``size`` while preserving the aspect ratio. Any value
            ``<= 0`` disables both resizing and cropping.
        crop: If true (and ``size > 0``), center-crop the resized image to a
            ``size`` x ``size`` square.

    Returns:
        The tensor produced by ``transforms.ToTensor()`` (C x H x W; uint8
        images are scaled to [0, 1]).
    """
    steps = []
    if size > 0:
        steps.append(transforms.Resize(size))
        # Cropping only makes sense with a positive target size; CenterCrop
        # with the "no resize" sentinel (<= 0) would request an invalid crop.
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)(img)
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize ``content`` with the feature statistics of ``style``.

    Args:
        content: Content image (PIL).
        style: Style image (PIL).
        style_type: Method name, case-insensitive: "EFDM", "AdaIN", "AdaMean",
            "AdaStd" or "HM".
        alpha: Content-style trade-off; 1.0 uses the fully transferred
            features, 0.0 reproduces the content features.
        keep_resolution: If true, process at the content image's native
            resolution (the style image is resized to match it). If false and
            the method is EFDM, both images are resized/cropped according to
            ``cfg["content_size"]``, ``cfg["style_size"]`` and ``cfg["crop"]``;
            other methods use the images' native sizes.

    Returns:
        The stylized image as a PIL image.

    Raises:
        ValueError: If an image is missing or ``style_type`` is not a known method.
    """
    transfer_fns = {
        "adain": adaptive_instance_normalization,
        "adamean": adaptive_mean_normalization,
        "adastd": adaptive_std_normalization,
        "efdm": exact_feature_distribution_matching,
        "hm": histogram_matching,
    }

    # Validate inputs up front, before any expensive network work.
    if content is None or style is None:
        raise ValueError("Please provide both a content image and a style image.")
    method = (style_type or "").lower()
    if method not in transfer_fns:
        raise ValueError(
            f"Unknown style transfer method {style_type!r}; "
            f"expected one of {sorted(transfer_fns)}"
        )
    transfer_fn = transfer_fns[method]

    # ToTensor keeps every channel of the PIL image (1 for "L", 4 for "RGBA");
    # the VGG encoder presumably expects 3-channel input, so normalize to RGB.
    content = content.convert("RGB")
    style = style.convert("RGB")

    # Step 1: convert images to batched tensors on the target device.
    if keep_resolution:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        style = style.resize(content.size, Image.LANCZOS)
    # NOTE(review): only EFDM is resized to the configured sizes here; presumably
    # because its feature matching needs equally sized inputs — confirm in `function`.
    if method == "efdm" and not keep_resolution:
        content_tensor = transform(content, cfg["content_size"], cfg["crop"])
        style_tensor = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content_tensor = transform(content, -1, False)
        style_tensor = transform(style, -1, False)
    content_tensor = content_tensor.to(device).unsqueeze(0)
    style_tensor = style_tensor.to(device).unsqueeze(0)

    try:
        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            # Step 2: extract content and style features.
            content_feat = vgg(content_tensor)
            style_feat = vgg(style_tensor)
            # Step 3: perform style transfer in feature space.
            feat = transfer_fn(content_feat, style_feat)
            # Step 4: content-style trade-off.
            feat = feat * alpha + content_feat * (1 - alpha)
            # Step 5: decode back to an image tensor in [0, 1].
            output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
        return transforms.ToPILImage()(output)
    finally:
        # Release cached GPU memory between requests, even if inference failed.
        if device.type == "cuda":
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
# Demo examples: each content-image path maps to the style-image path it is
# paired with (insertion order is the order shown in the UI).
example_img_pairs = dict(
    [
        ("examples/content/sailboat.jpg", "examples/style/sketch.png"),
        ("examples/content/granatum.jpg", "examples/style/flowers_in_a_turquoise_vase.jpg"),
        ("examples/content/einstein.jpeg", "examples/style/polasticot2.jpeg"),
        ("examples/content/paris.jpeg", "examples/style/vangogh.jpeg"),
        ("examples/content/cornell.jpg", "examples/style/asheville.jpg"),
    ]
)
# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer using exact feature distribution matching
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"

content_input = gr.inputs.Image(label="Content Image", source="upload", type="pil")
style_input = gr.inputs.Image(label="Style Image", source="upload", type="pil")
# Pre-select EFDM so the callback never receives an empty (None) method when the
# user submits without touching the radio buttons.
style_type = gr.inputs.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], default="EFDM", label="Method"
)
alpha_selector = gr.inputs.Slider(
    minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.inputs.Checkbox(
    default=True, label="Keep content image resolution"
)

# Input order must match the positional parameters of `style_transfer`.
iface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
    outputs=["image"],
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    examples=[
        [content, style, "EFDM", 1.0, True]
        for content, style in example_img_pairs.items()
    ],
)

# Only start the server when run as a script, so the module can be imported
# (e.g. for testing) without launching the app.
if __name__ == "__main__":
    iface.launch(debug=False, enable_queue=True)