"""Gradio demo that stylizes an uploaded image with pretrained checkpoints.

At import time the script downloads four style checkpoints from the
``maze/FastStyleTransfer`` Hub repository, builds one ``TransformerNetwork``
and launches a Gradio interface whose callback is :func:`main`.
"""
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

_REPO_ID = "maze/FastStyleTransfer"

# Local file paths of the downloaded checkpoints (one per style).
Rain_Princess = hf_hub_download(repo_id=_REPO_ID, filename="Rain_Princess_512.pth")
The_Scream = hf_hub_download(repo_id=_REPO_ID, filename="Scream_512.pth")
The_Mosaic = hf_hub_download(repo_id=_REPO_ID, filename="Mosaic_512.pth")
Starry_Night = hf_hub_download(repo_id=_REPO_ID, filename="Starry_Night_512.pth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TransformerNetwork(nn.Module):
    """Image-to-image network: conv encoder, residual stack, deconv decoder.

    The attribute names (``ConvBlock``, ``ResidualBlock``, ``DeconvBlock``)
    and the layout of each ``nn.Sequential`` determine the ``state_dict``
    keys, so they must stay as they are for the checkpoints to load.

    Args:
        tanh_multiplier: When an ``int``, the output is passed through
            ``tanh`` and scaled by this value; otherwise the raw decoder
            output is returned.
    """

    def __init__(self, tanh_multiplier=None):
        super().__init__()
        # 3 -> 32 -> 64 -> 128 channels; the last two convs use stride 2.
        self.ConvBlock = nn.Sequential(
            ConvLayer(3, 32, 9, 1),
            nn.ReLU(),
            ConvLayer(32, 64, 3, 2),
            nn.ReLU(),
            ConvLayer(64, 128, 3, 2),
            nn.ReLU(),
        )
        self.ResidualBlock = nn.Sequential(
            ResidualLayer(128, 3),
            ResidualLayer(128, 3),
            ResidualLayer(128, 3),
            ResidualLayer(128, 3),
            ResidualLayer(128, 3),
        )
        # Two stride-2 transposed convs, then a final un-normalized 9x9 conv.
        self.DeconvBlock = nn.Sequential(
            DeconvLayer(128, 64, 3, 2, 1),
            nn.ReLU(),
            DeconvLayer(64, 32, 3, 2, 1),
            nn.ReLU(),
            ConvLayer(32, 3, 9, 1, norm="None"),
        )
        self.tanh_multiplier = tanh_multiplier

    def forward(self, x):
        """Run ``x`` through the encoder, residual stack and decoder."""
        x = self.ConvBlock(x)
        x = self.ResidualBlock(x)
        x = self.DeconvBlock(x)
        if isinstance(self.tanh_multiplier, int):
            # ``torch.tanh`` instead of ``F.tanh``: ``F`` was never imported,
            # so this branch used to raise NameError.
            x = self.tanh_multiplier * torch.tanh(x)
        return x


class ConvLayer(nn.Module):
    """Reflection padding, convolution, then an optional normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        stride: Convolution stride.
        norm: ``"instance"`` or ``"batch"`` selects the matching affine norm
            layer; any other value disables normalization.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, norm="instance"):
        super().__init__()
        padding_size = kernel_size // 2
        self.pad = nn.ReflectionPad2d(padding_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        if norm == "instance":
            self.norm = nn.InstanceNorm2d(out_channels, affine=True)
        elif norm == "batch":
            self.norm = nn.BatchNorm2d(out_channels, affine=True)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Apply pad -> conv -> norm."""
        x = self.pad(x)
        x = self.conv(x)
        x = self.norm(x)
        return x


class ResidualLayer(nn.Module):
    """Two stride-1 ``ConvLayer``s with a ReLU in between and a skip connection.

    Args:
        channels: Channel count of both the input and the output.
        kernel_size: Kernel size of both convolutions.
    """

    def __init__(self, channels=128, kernel_size=3):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size, stride=1)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size, stride=1)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        identity = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        out = out + identity
        return out


class DeconvLayer(nn.Module):
    """Transposed convolution followed by an optional normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size; padding is ``kernel_size // 2``.
        stride: Stride of the transposed convolution.
        output_padding: Extra size added to one side of the output.
        norm: ``"instance"`` or ``"batch"`` selects the matching affine norm
            layer; any other value disables normalization.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm="instance"):
        super().__init__()
        padding_size = kernel_size // 2
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding_size, output_padding
        )
        if norm == "instance":
            self.norm = nn.InstanceNorm2d(out_channels, affine=True)
        elif norm == "batch":
            self.norm = nn.BatchNorm2d(out_channels, affine=True)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Apply conv_transpose -> norm."""
        x = self.conv_transpose(x)
        out = self.norm(x)
        return out


# Per-channel statistics used to normalize the input and de-normalize the output.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Single shared model instance; ``main`` swaps its weights per request.
transformer = TransformerNetwork().to(device)
transformer.eval()

transform = transforms.Compose([
    transforms.Resize(512),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Inverse of ``Normalize(mean, std)``: (x - (-m/s)) / (1/s) == x * s + m.
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(mean, std)],
    std=[1 / s for s in std],
)

tensor2Image = transforms.ToPILImage()


@torch.no_grad()
def process(image, model):
    """Stylize a PIL image with ``model`` and return the result as a PIL image.

    The image is resized/normalized by ``transform``, run through ``model`` as
    a batch of one on ``device``, de-normalized, clamped to ``[0, 1]`` and
    converted back with ``tensor2Image``.
    """
    batch = transform(image).to(device).unsqueeze(dim=0)
    styled = denormalize(model(batch)).cpu()
    styled = torch.clamp(styled.squeeze(dim=0), 0, 1)
    return tensor2Image(styled)


# Style label shown in the UI -> checkpoint path.
_STYLE_CHECKPOINTS = {
    "The Scream": The_Scream,
    "Rain Princess": Rain_Princess,
    "The Mosaic": The_Mosaic,
    "Starry Night": Starry_Night,
}


@functools.lru_cache(maxsize=None)
def _load_style_weights(checkpoint_path):
    """Read a checkpoint from disk once and keep its state dict in memory.

    The cache is bounded in practice by the handful of checkpoint paths above.
    ``load_state_dict`` copies values into the model, so sharing the cached
    dict between requests is safe.
    """
    # NOTE(review): torch.load unpickles the file; the checkpoints come from a
    # remote Hub repository, so only keep this if that repository is trusted.
    return torch.load(checkpoint_path, map_location=torch.device("cpu"))


def main(image, backbone, style):
    """Gradio callback: stylize ``image`` with the checkpoint chosen by ``style``.

    Args:
        image: Input image as an array accepted by ``Image.fromarray``.
        backbone: Label selected in the "Backbone" radio; only echoed in the
            returned message.
        style: Style label; unknown values fall back to "Rain Princess".

    Returns:
        A ``(stylized PIL image, HTML message)`` tuple.
    """
    checkpoint = _STYLE_CHECKPOINTS.get(style, Rain_Princess)
    transformer.load_state_dict(_load_style_weights(checkpoint))

    image = Image.fromarray(image)
    isize = image.size
    image = process(image, transformer)
    # The message is rendered by an HTML output, so the line break is a <br>
    # tag; a raw newline inside a single-quoted f-string is a syntax error.
    s = (
        f"The output image {str(image.size)} is processed by {backbone} "
        f"based on input image {str(isize)}.<br>"
        "Please rate the generated image through the Flag button below!"
    )
    print(s)
    return image, s


# "Standard ResNet50", "VGG19"
gr.Interface(
    title="Stylize",
    description="Image generated based on Fast Style Transfer",
    fn=main,
    inputs=[
        gr.inputs.Image(),
        gr.inputs.Radio(["Robust ResNet50"], label="Backbone"),
        gr.inputs.Dropdown(
            ["The Scream", "Rain Princess", "Starry Night", "The Mosaic"],
            type="value",
            default="Rain Princess",
            label="style",
        ),
    ],
    outputs=[gr.outputs.Image(label="Stylized"), gr.outputs.HTML(label="Comment")],
    # examples = [
    #     []
    # ],
    # live = True,  # the interface will recalculate as soon as the user input changes.
    allow_flagging="manual",
    # NOTE(review): "Excellect" looks like a typo for "Excellent"; left as-is
    # because the label may already be persisted in existing flag logs.
    flagging_options=["Excellect", "Moderate", "Bad"],
    flagging_dir="flagged",
    allow_screenshot=False,
).launch()
# iface.launch(enable_queue=True, cache_examples=True, debug=True)