the-neural-networker committed
Commit: b314b18
Parent: a926168

add application

Files changed (1): app.py (+149, -0)
app.py ADDED
import gradio as gr
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
from torchvision.models import vgg19
from nst.train import train
from nst.models.vgg19 import VGG19
from nst.losses import ContentLoss, StyleLoss

from urllib.request import urlretrieve


def transfer(content, style, device="cpu"):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])

    content = transform(content).unsqueeze(0)
    style = transform(style).unsqueeze(0)

    # optimize a copy of the content image (rather than random noise)
    x = content.clone()

    # ImageNet mean and std used to normalize inputs for VGG19
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    # VGG19 feature extractor with pretrained ImageNet weights
    model = VGG19(mean=mean, std=std).to(device=device)
    model = load_vgg19_weights(model, device)

    # L-BFGS optimizer, as in the Gatys et al. paper
    optimizer = optim.LBFGS([x.contiguous().requires_grad_()])

    # compute content and style representations
    content_outputs = model(content)
    style_outputs = model(style)

    # define content loss (on conv4) and style losses (on conv1 through conv5)
    content_loss = ContentLoss(content_outputs["conv4"][1], device)
    style_losses = []
    for i in range(1, 6):
        style_losses.append(StyleLoss(style_outputs[f"conv{i}"][0], device))

    # run style transfer
    output = train(model, optimizer, content_loss, style_losses, x,
                   iterations=10, alpha=1, beta=1000000,
                   style_weight=1.0)

    output = output.detach().to("cpu")
    output = output[0].permute(1, 2, 0).numpy()

    return output
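ContentLoss and StyleLoss are imported from the repo's nst.losses module, which is not part of this commit. For orientation, here is a minimal sketch of the Gatys et al. formulation they presumably implement: content loss is a mean-squared error between feature maps, and style loss is a mean-squared error between Gram matrices of feature maps. The callable-class shape and the Gram normalization below are assumptions, not the repo's actual code.

import torch.nn.functional as F

class ContentLoss:
    """Hypothetical sketch: MSE between current and target feature maps."""
    def __init__(self, target, device):
        self.target = target.detach().to(device)

    def __call__(self, features):
        return F.mse_loss(features, self.target)

def gram_matrix(features):
    # (batch, channels, h, w) -> (batch*channels, batch*channels),
    # normalized by the total number of elements
    b, c, h, w = features.size()
    flat = features.view(b * c, h * w)
    return flat @ flat.t() / (b * c * h * w)

class StyleLoss:
    """Hypothetical sketch: MSE between Gram matrices of current and target features."""
    def __init__(self, target, device):
        self.target_gram = gram_matrix(target.detach()).to(device)

    def __call__(self, features):
        return F.mse_loss(gram_matrix(features), self.target_gram)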
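Similarly, the train function from nst.train is not in this commit. PyTorch's L-BFGS optimizer requires a closure that re-evaluates the loss on every step, so the loop presumably looks something like the sketch below. The signature mirrors the call in transfer; the loss-combination details are assumptions.

def train(model, optimizer, content_loss, style_losses, x,
          iterations=10, alpha=1, beta=1000000, style_weight=1.0):
    # hypothetical sketch; the real nst.train.train may differ
    for _ in range(iterations):
        def closure():
            optimizer.zero_grad()
            outputs = model(x)
            # content target: conv4_2; style targets: conv1_1 .. conv5_1
            c_loss = content_loss(outputs["conv4"][1])
            s_loss = sum(style_weight * sl(outputs[f"conv{i + 1}"][0])
                         for i, sl in enumerate(style_losses))
            loss = alpha * c_loss + beta * s_loss
            loss.backward()
            return loss

        optimizer.step(closure)
        with torch.no_grad():
            x.clamp_(0, 1)  # commonly done to keep pixel values in range
    return x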
def load_vgg19_weights(model, device):
    """
    Loads VGG19 pretrained weights from ImageNet for style transfer.

    Args:
        model (nn.Module): VGG19 feature module with randomly initialized weights.
        device (torch.device): The device to load the model on.

    Returns:
        model (nn.Module): VGG19 module with pretrained ImageNet weights loaded.
    """
    # note: newer torchvision versions replace pretrained=True with
    # weights=VGG19_Weights.IMAGENET1K_V1
    pretrained_model = vgg19(pretrained=True).features.to(device).eval()

    # map the custom model's block-structured keys (convN.convM) to the flat
    # indices of torchvision's vgg19().features Sequential
    matching_keys = {
        "conv1.conv1.weight": "0.weight",
        "conv1.conv1.bias": "0.bias",
        "conv1.conv2.weight": "2.weight",
        "conv1.conv2.bias": "2.bias",

        "conv2.conv1.weight": "5.weight",
        "conv2.conv1.bias": "5.bias",
        "conv2.conv2.weight": "7.weight",
        "conv2.conv2.bias": "7.bias",

        "conv3.conv1.weight": "10.weight",
        "conv3.conv1.bias": "10.bias",
        "conv3.conv2.weight": "12.weight",
        "conv3.conv2.bias": "12.bias",
        "conv3.conv3.weight": "14.weight",
        "conv3.conv3.bias": "14.bias",
        "conv3.conv4.weight": "16.weight",
        "conv3.conv4.bias": "16.bias",

        "conv4.conv1.weight": "19.weight",
        "conv4.conv1.bias": "19.bias",
        "conv4.conv2.weight": "21.weight",
        "conv4.conv2.bias": "21.bias",
        "conv4.conv3.weight": "23.weight",
        "conv4.conv3.bias": "23.bias",
        "conv4.conv4.weight": "25.weight",
        "conv4.conv4.bias": "25.bias",

        "conv5.conv1.weight": "28.weight",
        "conv5.conv1.bias": "28.bias",
        "conv5.conv2.weight": "30.weight",
        "conv5.conv2.bias": "30.bias",
        "conv5.conv3.weight": "32.weight",
        "conv5.conv3.bias": "32.bias",
        "conv5.conv4.weight": "34.weight",
        "conv5.conv4.bias": "34.bias",
    }

    pretrained_dict = pretrained_model.state_dict()
    model_dict = model.state_dict()

    for key, value in matching_keys.items():
        model_dict[key] = pretrained_dict[value]

    model.load_state_dict(model_dict)

    return model
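The left-hand keys in matching_keys reveal the structure of the custom VGG19 wrapper from nst.models.vgg19 (also not in this commit): five blocks named conv1 through conv5, each containing conv submodules conv1..conv4, with the forward pass returning every block's per-layer activations. Note that transfer reads outputs["conv4"][1] (conv4_2) for content and outputs[f"conv{i}"][0] (conv1_1 through conv5_1) for style, exactly the layers used by Gatys et al. Below is a hypothetical sketch consistent with those keys; the normalization placement and the pooling choice are assumptions.

import torch.nn as nn

class _Block(nn.Module):
    """Hypothetical conv block; submodule names conv1..convN match matching_keys."""
    def __init__(self, in_ch, out_ch, n_convs):
        super().__init__()
        self.n_convs = n_convs
        for i in range(n_convs):
            setattr(self, f"conv{i + 1}",
                    nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, 3, padding=1))
        self.pool = nn.AvgPool2d(2)  # Gatys et al. prefer average pooling

    def forward(self, x):
        outs = []
        for i in range(self.n_convs):
            x = torch.relu(getattr(self, f"conv{i + 1}")(x))
            outs.append(x)
        return outs, self.pool(x)

class VGG19(nn.Module):
    """Hypothetical sketch; the real nst.models.vgg19.VGG19 may differ."""
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean.view(1, 3, 1, 1)
        self.std = std.view(1, 3, 1, 1)
        self.conv1 = _Block(3, 64, 2)
        self.conv2 = _Block(64, 128, 2)
        self.conv3 = _Block(128, 256, 4)
        self.conv4 = _Block(256, 512, 4)
        self.conv5 = _Block(512, 512, 4)

    def forward(self, x):
        x = (x - self.mean) / self.std
        outputs = {}
        for name in ["conv1", "conv2", "conv3", "conv4", "conv5"]:
            outs, x = getattr(self, name)(x)
            outputs[name] = outs
        return outputs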
def main():
    # define app features and run
    title = "Neural Style Transfer Demo"
    description = "<p style='text-align: center'>Gradio demo for transferring style from a 'style' image onto a 'content' image. To use it, simply add your content and style images, or click one of the examples to load them. Since this demo runs on CPU only, please allow additional time for processing (~10 min).</p>"
    article = "<p style='text-align: center'><a href='https://github.com/the-neural-networker/neural-style-transfer'>GitHub Repo</a></p>"
    css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"

    # make sure to use "Copy image address" when copying an image from GitHub
    urlretrieve("https://github.com/the-neural-networker/neural-style-transfer/blob/main/images/content/dancing.jpg?raw=True", "dancing_content.jpg")
    urlretrieve("https://github.com/the-neural-networker/neural-style-transfer/blob/main/images/style/picasso.jpg?raw=True", "picasso_style.jpg")
    examples = [  # the example cache must be deleted manually every time new examples are added
        ["dancing_content.jpg", "picasso_style.jpg"]
    ]

    demo = gr.Interface(
        fn=transfer,
        title=title,
        description=description,
        article=article,
        inputs=[
            gr.Image(type="pil", elem_id=0, show_label=False),
            gr.Image(type="pil", elem_id=1, show_label=False)
        ],
        outputs=gr.Image(elem_id=2, show_label=False),
        css=css,
        examples=examples,
        cache_examples=True,
        allow_flagging="never"
    )
    demo.launch()


if __name__ == "__main__":
    main()
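As wired up above, Gradio invokes transfer(content, style) with only two arguments, so the default device="cpu" is always used; the description's CPU-only warning reflects that. A hypothetical tweak to run on a GPU would bind the device before handing the function to Gradio:

from functools import partial

# hypothetical: bind the device so Gradio's two-argument call runs on the GPU
fn = partial(transfer, device="cuda")  # then pass fn to gr.Interface as before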