import gradio as gr
import toml
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

import net
from function import *  # feature-matching ops: AdaIN, AdaMean, AdaStd, EFDM, HM

cfg = toml.load("config.toml")  # static configuration
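
# For reference, a minimal sketch of the config.toml this app expects. The keys
# are inferred from how cfg is used below; the paths and values shown here are
# assumptions, not the repository's actual settings:
#
#   use_cuda = true
#   decoder_weight = "models/decoder.pth"       # hypothetical path
#   vgg_weight = "models/vgg_normalised.pth"    # hypothetical path
#   content_size = 512
#   style_size = 512
#   crop = true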

# Setup device
device = torch.device(
    "cuda" if torch.cuda.is_available() and cfg["use_cuda"] else "cpu"
)

# Load pretrained models
decoder = net.decoder
vgg = net.vgg
decoder.eval()
vgg.eval()

# map_location keeps GPU-trained checkpoints loadable on CPU-only hosts
decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location=device))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location=device))

# Keep only the VGG layers up to relu4_1, whose features are matched below
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg = vgg.to(device)
decoder = decoder.to(device)


def transform(img, size, crop):
    """Resize (and optionally center-crop) a PIL image, then convert to a tensor."""
    transform_list = []
    if size > 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)(img)


@torch.inference_mode()
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize a content image with a style image using the selected method."""
    style_type = style_type.lower()

    # Step 1: convert images to PyTorch tensors
    if keep_resolution:
        # Match the style image to the content resolution.
        # (Image.LANCZOS replaces Image.ANTIALIAS, which was removed in Pillow 10.)
        style = style.resize(content.size, Image.LANCZOS)
    if style_type == "efdm" and not keep_resolution:
        content = transform(content, cfg["content_size"], cfg["crop"])
        style = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content = transform(content, -1, False)
        style = transform(style, -1, False)
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    # Step 2: extract content and style features with the truncated VGG
    content_feat = vgg(content)
    style_feat = vgg(style)

    # Step 3: perform style transfer in feature space
    transfer = {
        "adain": adaptive_instance_normalization,  # match channel mean and std
        "adamean": adaptive_mean_normalization,  # match channel mean only
        "adastd": adaptive_std_normalization,  # match channel std only
        "efdm": exact_feature_distribution_matching,  # match the full distribution (see sketch below)
        "hm": histogram_matching,  # match per-channel histograms
    }[style_type]
    feat = transfer(content_feat, style_feat)

    # Step 4: content-style trade-off (alpha = 1 means fully stylized)
    feat = feat * alpha + content_feat * (1 - alpha)

    # Step 5: decode features back to an image
    output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
    output = transforms.ToPILImage()(output)

    # Release cached GPU memory between requests
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    return output
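

# For intuition, a minimal sketch of what the EFDM transfer step computes. The
# real implementation is imported from function.py; this helper is illustrative
# only, is not used by the app, and assumes content and style features share
# one shape. Per channel, EFDM sorts the flattened content and style
# activations and puts the k-th smallest style value where the k-th smallest
# content value sits, so the output matches the style's empirical distribution
# exactly while preserving the content's spatial rank order.
def _efdm_sketch(content_feat, style_feat):
    b, c, h, w = content_feat.size()
    content_flat = content_feat.view(b, c, -1)
    style_flat = style_feat.view(b, c, -1)
    # indices that sort the content activations ascending, per channel
    _, content_index = torch.sort(content_flat)
    # style activations sorted ascending, per channel
    style_sorted, _ = torch.sort(style_flat)
    # place the k-th smallest style value at the position of the
    # k-th smallest content value
    out = torch.empty_like(content_flat)
    out.scatter_(-1, content_index, style_sorted)
    return out.view(b, c, h, w)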


# Add image examples, keyed as {content_path: style_path}
example_img_pairs = {
    "examples/content/sailboat.jpg": "examples/style/sketch.png",
    "examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
    "examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
    "examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
    "examples/content/cornell.jpg": "examples/style/asheville.jpg",
}

# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer via exact feature distribution matching (EFDM)
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"

# NOTE: gr.inputs.* is the legacy Gradio (pre-4.0) components API
content_input = gr.inputs.Image(label="Content Image", source="upload", type="pil")
style_input = gr.inputs.Image(label="Style Image", source="upload", type="pil")
style_type = gr.inputs.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], label="Method"
)
alpha_selector = gr.inputs.Slider(
    minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.inputs.Checkbox(
    default=True, label="Keep content image resolution"
)

iface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
    outputs=["image"],
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    examples=[
        [content, style, "EFDM", 1.0, True]
        for content, style in example_img_pairs.items()
    ],
)

iface.launch(debug=False, enable_queue=True)