import gradio as gr
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torchvision.models import vgg19
from urllib.request import urlretrieve

from nst.train import train
from nst.models.vgg19 import VGG19
from nst.losses import ContentLoss, StyleLoss


def transfer(content, style, device="cpu"):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    content = transform(content).unsqueeze(0)
    style = transform(style).unsqueeze(0)

    # the generated image is initialized from the content image
    x = content.clone()

    # ImageNet channel-wise mean and std used to normalize VGG19 inputs
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    # VGG19 feature extractor with pretrained ImageNet weights
    model = VGG19(mean=mean, std=std).to(device=device)
    model = load_vgg19_weights(model, device)

    # L-BFGS optimizer over the generated image itself, as in the paper
    # (Gatys et al., 2016)
    optimizer = optim.LBFGS([x.contiguous().requires_grad_()])

    # compute the target content and style representations
    content_outputs = model(content)
    style_outputs = model(style)

    # content loss on a conv4 activation; style losses on the first
    # activation of each of the five conv blocks
    content_loss = ContentLoss(content_outputs["conv4"][1], device)
    style_losses = []
    for i in range(1, 6):
        style_losses.append(StyleLoss(style_outputs[f"conv{i}"][0], device))

    # run style transfer
    output = train(
        model,
        optimizer,
        content_loss,
        style_losses,
        x,
        iterations=10,
        alpha=1,
        beta=1000000,
        style_weight=1.0,
    )

    # convert the optimized tensor to an HWC numpy array for Gradio
    output = output.detach().to("cpu")
    output = output[0].permute(1, 2, 0).numpy()
    return output


def load_vgg19_weights(model, device):
    """
    Loads pretrained ImageNet weights into the custom VGG19 used for style transfer.

    Args:
        model (nn.Module): VGG19 feature module with randomly initialized weights.
        device (torch.device): The device to load the model on.

    Returns:
        model (nn.Module): VGG19 module with pretrained ImageNet weights loaded.
    """
    # note: torchvision >= 0.13 prefers vgg19(weights="IMAGENET1K_V1") over pretrained=True
    pretrained_model = vgg19(pretrained=True).features.to(device).eval()

    # maps parameter names in the custom VGG19 to the corresponding layer
    # indices in torchvision's vgg19().features Sequential module
    matching_keys = {
        "conv1.conv1.weight": "0.weight",
        "conv1.conv1.bias": "0.bias",
        "conv1.conv2.weight": "2.weight",
        "conv1.conv2.bias": "2.bias",
        "conv2.conv1.weight": "5.weight",
        "conv2.conv1.bias": "5.bias",
        "conv2.conv2.weight": "7.weight",
        "conv2.conv2.bias": "7.bias",
        "conv3.conv1.weight": "10.weight",
        "conv3.conv1.bias": "10.bias",
        "conv3.conv2.weight": "12.weight",
        "conv3.conv2.bias": "12.bias",
        "conv3.conv3.weight": "14.weight",
        "conv3.conv3.bias": "14.bias",
        "conv3.conv4.weight": "16.weight",
        "conv3.conv4.bias": "16.bias",
        "conv4.conv1.weight": "19.weight",
        "conv4.conv1.bias": "19.bias",
        "conv4.conv2.weight": "21.weight",
        "conv4.conv2.bias": "21.bias",
        "conv4.conv3.weight": "23.weight",
        "conv4.conv3.bias": "23.bias",
        "conv4.conv4.weight": "25.weight",
        "conv4.conv4.bias": "25.bias",
        "conv5.conv1.weight": "28.weight",
        "conv5.conv1.bias": "28.bias",
        "conv5.conv2.weight": "30.weight",
        "conv5.conv2.bias": "30.bias",
        "conv5.conv3.weight": "32.weight",
        "conv5.conv3.bias": "32.bias",
        "conv5.conv4.weight": "34.weight",
        "conv5.conv4.bias": "34.bias",
    }
    pretrained_dict = pretrained_model.state_dict()
    model_dict = model.state_dict()
    for key, value in matching_keys.items():
        model_dict[key] = pretrained_dict[value]
    model.load_state_dict(model_dict)
    return model
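
# For reference, a minimal sketch of what `nst.train.train` is expected to do;
# the real implementation lives in the `nst` package, so the body below is an
# assumption (a hypothetical helper, never called by this app). The key PyTorch
# detail is that L-BFGS may re-evaluate the objective several times per step,
# so the loss must be computed inside a closure passed to `optimizer.step`.
def _reference_train(model, optimizer, content_loss, style_losses, x,
                     iterations, alpha, beta, style_weight):
    for _ in range(iterations):
        def closure():
            optimizer.zero_grad()
            outputs = model(x)
            # mirror how the targets were extracted in transfer():
            # conv4 for content, the first activation of conv1..conv5 for style
            c_loss = content_loss(outputs["conv4"][1])
            s_loss = sum(sl(outputs[f"conv{i}"][0])
                         for i, sl in enumerate(style_losses, start=1))
            loss = alpha * c_loss + beta * style_weight * s_loss
            loss.backward()
            return loss
        optimizer.step(closure)
    return x
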
def main():
    # define app features and run
    title = "Neural Style Transfer Demo"
    description = (
        "Gradio demo for transferring the style of a 'style' image onto a "
        "'content' image. To use it, simply upload your content and style "
        "images, or click one of the examples to load them. Since this demo "
        "runs on CPU only, please allow additional time for processing "
        "(~10 min)."
    )
" article = "" css = "#0 {object-fit: contain;} #1 {object-fit: contain;}" urlretrieve("https://github.com/the-neural-networker/neural-style-transfer/blob/main/images/content/dancing.jpg?raw=True", "dancing_content.jpg") # make sure to use "copy image address when copying image from Github" urlretrieve("https://github.com/the-neural-networker/neural-style-transfer/blob/main/images/style/picasso.jpg?raw=True", "picasso_style.jpg") examples = [ # need to manually delete cache everytime new examples are added ['dancing_content.jpg', "picasso_style.jpg"] ] demo = gr.Interface( fn=transfer, title=title, description=description, article=article, inputs=[ gr.Image(type="pil", elem_id=0, show_label=False), gr.Image(type="pil", elem_id=1, show_label=False) ], outputs=gr.Image(elem_id=2, show_label=False), css=css, examples=examples, cache_examples=True, allow_flagging='never' ) demo.launch() if __name__ == "__main__": main()