import gradio as gr
import toml
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

import net
from function import *

cfg = toml.load("config.toml")  # static variables

# Setup device
if torch.cuda.is_available() and cfg["use_cuda"]:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load pretrained models (map_location keeps CPU-only machines working)
decoder = net.decoder
vgg = net.vgg
decoder.eval()
vgg.eval()
decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location=device))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location=device))
vgg = nn.Sequential(*list(vgg.children())[:31])
vgg = vgg.to(device)
decoder = decoder.to(device)


def transform(img, size, crop):
    """Build and apply a resize/crop/to-tensor pipeline; size <= 0 keeps the original resolution."""
    transform_list = []
    if size > 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    pipeline = transforms.Compose(transform_list)
    return pipeline(img)


@torch.inference_mode()
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize the content image with the chosen feature-matching method."""
    style_type = style_type.lower()

    # Step 1: convert images to PyTorch tensors
    if keep_resolution:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        style = style.resize(content.size, Image.LANCZOS)

    if style_type == "efdm" and not keep_resolution:
        content = transform(content, cfg["content_size"], cfg["crop"])
        style = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content = transform(content, -1, False)
        style = transform(style, -1, False)

    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    # Step 2: extract content and style features
    content_feat = vgg(content)
    style_feat = vgg(style)

    # Step 3: perform style transfer
    transfer = {
        "adain": adaptive_instance_normalization,
        "adamean": adaptive_mean_normalization,
        "adastd": adaptive_std_normalization,
        "efdm": exact_feature_distribution_matching,
        "hm": histogram_matching,
    }[style_type]
    feat = transfer(content_feat, style_feat)

    # Step 4: content-style trade-off
    feat = feat * alpha + content_feat * (1 - alpha)

    # Step 5: decode features back to an image
    output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
    output = transforms.ToPILImage()(output)

    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    return output


# Add image examples
example_img_pairs = {
    "examples/content/sailboat.jpg": "examples/style/sketch.png",
    "examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
    "examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
    "examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
    "examples/content/cornell.jpg": "examples/style/asheville.jpg",
}

# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer using exact feature distribution matching
"""
article = "Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization"
" content_input = gr.inputs.Image(label="Content Image", source="upload", type="pil") style_input = gr.inputs.Image(label="Style Image", source="upload", type="pil") style_type = gr.inputs.Radio( ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], label="Method" ) alpha_selector = gr.inputs.Slider( minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Content-Style trade-off" ) keep_resolution = gr.inputs.Checkbox( default=True, label="Keep content image resolution" ) iface = gr.Interface( fn=style_transfer, inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution], outputs=["image"], title=title, description=description, article=article, theme="huggingface", examples=[ [content, style, "EFDM", 1.0, True] for content, style in example_img_pairs.items() ], ) iface.launch(debug=False, enable_queue=True)