# Hugging Face Space: neural style transfer demo
# (author: the-neural-networker, commit b314b18 "add application", 5.22 kB —
# web-page residue from the HF file viewer removed so the file parses as Python)
import gradio as gr
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torchvision.models import vgg19
from nst.train import train
from nst.models.vgg19 import VGG19
from nst.losses import ContentLoss, StyleLoss
from urllib.request import urlretrieve
def transfer(content, style, device="cpu"):
    """Run neural style transfer and return the stylized image.

    Args:
        content: PIL image providing the content.
        style: PIL image providing the style.
        device (str | torch.device): Device to run the optimization on.
            Defaults to "cpu" (the Space has no GPU).

    Returns:
        numpy.ndarray: Stylized image of shape (512, 512, 3), values in [0, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    # Add a batch dimension AND move the tensors to `device`.  The original
    # code left these on CPU while moving the model to `device`, which breaks
    # for any device other than "cpu".
    content = transform(content).unsqueeze(0).to(device)
    style = transform(style).unsqueeze(0).to(device)
    # Optimize starting from the content image (rather than random noise).
    x = content.clone()
    # ImageNet normalization statistics expected by VGG19.
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device)
    model = VGG19(mean=mean, std=std).to(device=device)
    model = load_vgg19_weights(model, device)
    # L-BFGS optimizer on the image pixels, as in Gatys et al. (2015).
    optimizer = optim.LBFGS([x.contiguous().requires_grad_()])
    # Fixed content and style feature representations.
    content_outputs = model(content)
    style_outputs = model(style)
    # Content loss from conv4; style losses from conv1..conv5.
    # (Index semantics of the per-layer output tuples are defined in nst.models
    # — presumably [0]/[1] select pre-/post-activation features; confirm there.)
    content_loss = ContentLoss(content_outputs["conv4"][1], device)
    style_losses = [
        StyleLoss(style_outputs[f"conv{i}"][0], device) for i in range(1, 6)
    ]
    # Run the style-transfer optimization.
    output = train(model, optimizer, content_loss, style_losses, x,
                   iterations=10, alpha=1, beta=1000000,
                   style_weight=1.0)
    # L-BFGS updates can push pixel values slightly outside [0, 1]; clamp so
    # the returned array is a valid image before converting to HWC numpy.
    output = output.detach().clamp_(0, 1).to("cpu")
    output = output[0].permute(1, 2, 0).numpy()
    return output
def load_vgg19_weights(model, device):
    """
    Copy ImageNet-pretrained VGG19 weights into the custom feature module.

    Args:
        model (nn.Module): VGG19 feature module with randomized weights.
        device (torch.device): The device to load the pretrained model on.

    Returns:
        model (nn.Module): VGG19 module with pretrained ImageNet weights loaded.
    """
    reference = vgg19(pretrained=True).features.to(device).eval()
    source_state = reference.state_dict()
    target_state = model.state_dict()
    # torchvision's VGG19 `features` is a flat nn.Sequential in which every
    # conv layer is followed by a ReLU, and each of the five blocks ends with
    # a max-pool.  Walk that layout (2, 2, 4, 4, 4 convs per block) to pair
    # our named parameters ("convB.convC.weight") with the sequential indices
    # ("IDX.weight") instead of spelling out all 32 key pairs by hand.
    seq_idx = 0
    for block_num, conv_count in enumerate((2, 2, 4, 4, 4), start=1):
        for conv_num in range(1, conv_count + 1):
            for param in ("weight", "bias"):
                key = f"conv{block_num}.conv{conv_num}.{param}"
                target_state[key] = source_state[f"{seq_idx}.{param}"]
            seq_idx += 2  # step over this conv and its ReLU
        seq_idx += 1  # step over the max-pool closing the block
    model.load_state_dict(target_state)
    return model
def main():
    """Build and launch the Gradio demo for neural style transfer."""
    title = "Neural Style Transfer Demo"
    # Fixed grammar: "an transfering style" -> "transferring style".
    description = "<p style='text-align: center'>Gradio demo for transferring style from a 'style' image onto a 'content' image. To use it, simply add your content and style images, or click one of the examples to load them. Since this demo is run on CPU only, please allow additional time for processing (~10 min). </p>"
    # Fixed link: the original pointed at an unrelated repository
    # (Nano1337/SpecLab); the example downloads below show this project lives
    # at the-neural-networker/neural-style-transfer.
    article = "<p style='text-align: center'><a href='https://github.com/the-neural-networker/neural-style-transfer'>Github Repo</a></p>"
    css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"
    # Download example images; use the "?raw=True" form of the GitHub URL
    # ("copy image address" when copying an image from GitHub).
    urlretrieve("https://github.com/the-neural-networker/neural-style-transfer/blob/main/images/content/dancing.jpg?raw=True", "dancing_content.jpg")
    urlretrieve("https://github.com/the-neural-networker/neural-style-transfer/blob/main/images/style/picasso.jpg?raw=True", "picasso_style.jpg")
    # NOTE: manually delete the example cache every time new examples are added.
    examples = [
        ["dancing_content.jpg", "picasso_style.jpg"],
    ]
    demo = gr.Interface(
        fn=transfer,
        title=title,
        description=description,
        article=article,
        inputs=[
            gr.Image(type="pil", elem_id=0, show_label=False),
            gr.Image(type="pil", elem_id=1, show_label=False),
        ],
        outputs=gr.Image(elem_id=2, show_label=False),
        css=css,
        examples=examples,
        cache_examples=True,
        allow_flagging='never',
    )
    demo.launch()
# Script entry point: build and launch the Gradio app.
if __name__ == "__main__":
    main()