Spaces:
Runtime error
Runtime error
gb-github-web
committed on
Commit
•
b456239
1
Parent(s):
5402e32
first try
Browse files- Photo_Style_Transfer.ipynb +71 -0
- colab_tools_2.py +87 -0
- predictor.py +51 -0
- stmodel.py +117 -0
- styles/.DS_Store +0 -0
Photo_Style_Transfer.ipynb
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "R7pPsDHPE_PF"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dabidou025/Live-Style-Transfer/blob/main/Photo_Style_Transfer.ipynb)"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"colab": {
|
17 |
+
"base_uri": "https://localhost:8080/"
|
18 |
+
},
|
19 |
+
"id": "pRMt1Ae6Asc7",
|
20 |
+
"outputId": "3b3868b0-e6c4-4a6d-93ae-bad6e0e31a02"
|
21 |
+
},
|
22 |
+
"outputs": [],
|
23 |
+
"source": [
|
24 |
+
"!pip install -q gradio==2.8.7\n",
|
25 |
+
"!git clone https://github.com/dabidou025/Live-Style-Transfer.git\n",
|
26 |
+
"%cd Live-Style-Transfer/"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"metadata": {
|
33 |
+
"colab": {
|
34 |
+
"base_uri": "https://localhost:8080/"
|
35 |
+
},
|
36 |
+
"id": "LBy4brOoF9xC",
|
37 |
+
"outputId": "989467b2-546b-4b80-9475-4ad71a54a937"
|
38 |
+
},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"from colab_tools_2 import *"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": null,
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"gradio_pls()"
|
51 |
+
]
|
52 |
+
}
|
53 |
+
],
|
54 |
+
"metadata": {
|
55 |
+
"accelerator": "GPU",
|
56 |
+
"colab": {
|
57 |
+
"collapsed_sections": [],
|
58 |
+
"name": "Photo_Style_Transfer.ipynb",
|
59 |
+
"provenance": []
|
60 |
+
},
|
61 |
+
"kernelspec": {
|
62 |
+
"display_name": "Python 3",
|
63 |
+
"name": "python3"
|
64 |
+
},
|
65 |
+
"language_info": {
|
66 |
+
"name": "python"
|
67 |
+
}
|
68 |
+
},
|
69 |
+
"nbformat": 4,
|
70 |
+
"nbformat_minor": 0
|
71 |
+
}
|
colab_tools_2.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import dependencies
|
2 |
+
from IPython.display import display, Javascript, Image
|
3 |
+
from google.colab.output import eval_js
|
4 |
+
from google.colab.patches import cv2_imshow
|
5 |
+
from base64 import b64decode, b64encode
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import io
|
10 |
+
import html
|
11 |
+
import time
|
12 |
+
import torch
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
from models.stmodel import STModel
|
17 |
+
from predictor import Predictor
|
18 |
+
import argparse
|
19 |
+
from glob import glob
|
20 |
+
import os
|
21 |
+
from ipywidgets import Box, Image
|
22 |
+
import gradio as gr
|
23 |
+
|
24 |
+
def predict_gradio(image):
    """Apply every available style to ``image`` and return the results.

    The number of styles is the count of ``*.jpg`` files found in
    ``./styles/``. The model weights are read from
    ``./models/st_model_512_80k_12.pth`` on every call.

    Args:
        image: Input picture; it is forwarded unchanged to
            ``Predictor.eval_image``.

    Returns:
        A list holding one generated image per style index, in index order.
    """
    img_size = 512
    load_model_path = "./models/st_model_512_80k_12.pth"
    styles_path = "./styles/"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_styles = len(glob(os.path.join(styles_path, '*.jpg')))

    st_model = STModel(n_styles)
    st_model.load_state_dict(torch.load(load_model_path, map_location=device))
    st_model = st_model.to(device)

    predictor = Predictor(st_model, device, img_size)

    # Inference only: disabling autograd avoids building a graph for each of
    # the forward passes below.
    with torch.no_grad():
        return [predictor.eval_image(image, s) for s in range(n_styles)]
|
43 |
+
|
44 |
+
def gradio_pls():
    """Build the "Photo Style Transfer" Gradio interface and launch it inline.

    The interface takes one PIL image, runs ``predict_gradio`` on it and shows
    the returned images in a carousel.

    Returns:
        The value returned by ``Interface.launch``.
    """
    description="""
Upload a photo and click on submit to see the 12 styles applied to your photo. \n
Keep in mind that for compatibility reasons your photo is cropped before the neural net applied the different styles.
<center>
<table><tr>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" width=100px></td>

<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" width=100px></td>

<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" width=100px></td>

<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" width=100px></td>

<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" width=100px></td>

<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" width=100px></td>

</tr>
</table>
</center>
"""
    # Input and output components are named first so the Interface call
    # itself stays short.
    photo_inputs = [gr.inputs.Image(type="pil", label="Image")]
    styled_outputs = [gr.outputs.Carousel("image", label="Style")]

    interface = gr.Interface(
        predict_gradio,
        photo_inputs,
        styled_outputs,
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never',
    )

    launched = interface.launch(inline=True, height=800, width=800)
    return launched
|
predictor.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class Predictor:
    """Run a style-transfer model on single PIL images.

    The image is resized, centre-cropped to ``img_size``, converted to a
    tensor and normalised with ``self.mean`` / ``self.std`` before being fed
    to the model.
    """

    def __init__(self, st_model, device, img_size):
        """Move ``st_model`` to ``device``, switch it to eval mode and build
        the preprocessing pipeline.

        Args:
            st_model: Model called as ``st_model(batch, style_1, style_2, alpha)``.
            device: Device the model and the inputs are moved to.
            img_size: Size passed to both ``Resize`` and ``CenterCrop``.
        """
        self.device = device

        self.st_model = st_model.to(device)
        self.st_model.eval()

        # Per-channel statistics consumed by ``Normalize`` (three channels).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        self.transformer = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def eval_image(self, img, style_1, style_2=None, alpha=0.5):
        """Stylise ``img`` and return the result as a PIL image.

        Args:
            img: Input image; PIL images that are not in RGB mode are
                converted to RGB first.
            style_1: Style index forwarded to the model.
            style_2: Optional second style index forwarded to the model.
            alpha: Blend weight forwarded to the model.

        Returns:
            A ``PIL.Image`` built from the first item of the model output,
            scaled by 255 and cast to ``uint8``.
        """
        # ``Normalize`` is configured for three channels, so RGBA, greyscale
        # or palette images would otherwise fail with a channel mismatch.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            batch = self.transformer(img).to(self.device).unsqueeze(0)
            gen = self.st_model(batch, style_1, style_2, alpha)

        # Move the channel axis last, which is the layout ``fromarray`` expects.
        pixels = gen[0].cpu().detach().numpy() * 255.0
        return Image.fromarray(np.uint8(np.moveaxis(pixels, 0, 2)))
|
30 |
+
|
31 |
+
class WebcamPredictor:
    """Run a style-transfer model on raw NumPy frames.

    Unlike a torchvision pipeline, normalisation is done directly in NumPy
    with ``self.mean`` / ``self.std`` reshaped to ``(3, 1, 1)``.
    """

    def __init__(self, st_model, device):
        """Move ``st_model`` to ``device`` and switch it to eval mode.

        Args:
            st_model: Model called as ``st_model(batch, style_1, style_2, alpha)``.
            device: Device the model and the inputs are moved to.
        """
        self.device = device

        self.st_model = st_model.to(device)
        self.st_model.eval()

        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])

        # Shape (3, 1, 1) so the statistics broadcast over the two trailing axes.
        self.mean = np.expand_dims(self.mean, (1, 2))
        self.std = np.expand_dims(self.std, (1, 2))

    def eval_image(self, img, style_1, style_2=None, alpha=0.5):
        """Stylise one frame and return it as a ``uint8`` array.

        Args:
            img: NumPy array that broadcasts against the ``(3, 1, 1)``
                statistics (presumably a channel-first frame already scaled
                to [0, 1]; confirm with the caller).
            style_1: Style index forwarded to the model.
            style_2: Optional second style index forwarded to the model.
            alpha: Blend weight forwarded to the model.

        Returns:
            The first item of the model output multiplied by 255 and cast to
            ``numpy.uint8``; the axis order of the model output is kept.
        """
        img = (img - self.mean) / self.std
        img = torch.from_numpy(img).to(self.device)
        img = img.float()

        # Inference only: avoid building an autograd graph for every frame.
        with torch.no_grad():
            gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha)

        return np.uint8(gen[0].cpu().detach().numpy()*255.0)
|
stmodel.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
import time
|
8 |
+
|
9 |
+
class ConvCIN(nn.Module):
    """Reflection-padded convolution, instance norm, then a per-style scale
    and shift, with an optional activation.

    ``gamma`` and ``beta`` each hold one row per style and one column per
    output channel; the row(s) selected in ``forward`` scale and shift the
    normalised feature map.
    """

    def __init__(self, n_styles, C_in, C_out, kernel_size, padding, stride, activation=None):
        """
        Args:
            n_styles: Number of rows in the ``gamma`` / ``beta`` tables.
            C_in: Input channels of the convolution.
            C_out: Output channels of the convolution.
            kernel_size: Convolution kernel size.
            padding: Amount of reflection padding applied before the convolution.
            stride: Convolution stride.
            activation: ``'relu'``, ``'sigmoid'`` or anything else for no activation.
        """
        super().__init__()

        self.reflection = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(in_channels=C_in, out_channels=C_out, kernel_size=kernel_size, stride=stride)
        nn.init.normal_(self.conv.weight, mean=0, std=1e-2)

        # Built without affine parameters: scale and shift come from
        # ``gamma`` / ``beta`` instead.
        self.instnorm = nn.InstanceNorm2d(C_out)

        # Scales start close to 1 and shifts close to 0.
        self.gamma = torch.nn.Parameter(data=torch.randn(n_styles, C_out)*1e-2 + 1, requires_grad=True)
        self.beta = torch.nn.Parameter(data=torch.randn(n_styles, C_out)*1e-2, requires_grad=True)

        self.activation = activation

    def forward(self, x, style_1, style_2, alpha):
        """Apply the layer to ``x`` for one style or a blend of two.

        Args:
            x: 4-D input tensor.
            style_1: Index into ``gamma`` / ``beta``.
            style_2: Optional second index; when not ``None`` the two rows
                are mixed as ``alpha*style_1 + (1-alpha)*style_2``.
            alpha: Blend weight, only used when ``style_2`` is given.

        Returns:
            The transformed tensor.
        """
        x = self.reflection(x)
        x = self.conv(x)
        x = self.instnorm(x)

        if style_2 is not None:
            gamma = alpha*self.gamma[style_1] + (1-alpha)*self.gamma[style_2]
            beta = alpha*self.beta[style_1] + (1-alpha)*self.beta[style_2]
        else:
            gamma = self.gamma[style_1]
            beta = self.beta[style_1]

        # Two trailing singleton axes let the per-channel values broadcast
        # over both spatial dimensions without reshaping ``x``.
        x = x*gamma.unsqueeze(-1).unsqueeze(-1) + beta.unsqueeze(-1).unsqueeze(-1)

        if self.activation == 'relu':
            x = F.relu(x)
        elif self.activation == 'sigmoid':
            x = torch.sigmoid(x)

        return x
|
57 |
+
|
58 |
+
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvCIN`` layers whose output is added back onto the input."""

    def __init__(self, n_styles, C_in, C_out):
        super().__init__()

        # Both layers share the same geometry; only the first has an activation.
        geometry = dict(kernel_size=3, padding=1, stride=1)
        self.convcin1 = ConvCIN(n_styles, C_in, C_out, activation='relu', **geometry)
        self.convcin2 = ConvCIN(n_styles, C_in, C_out, **geometry)

    def forward(self, x, style_1, style_2, alpha):
        """Return ``x`` plus the result of the two stacked layers."""
        hidden = self.convcin1(x, style_1, style_2, alpha)
        return x + self.convcin2(hidden, style_1, style_2, alpha)
|
69 |
+
|
70 |
+
class UpSampling(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 ``ConvCIN`` with ReLU."""

    def __init__(self, n_styles, C_in, C_out):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.convcin = ConvCIN(
            n_styles,
            C_in,
            C_out,
            kernel_size=3,
            padding=1,
            stride=1,
            activation='relu',
        )

    def forward(self, x, style_1, style_2, alpha):
        """Upsample ``x`` and pass it through the styled convolution."""
        enlarged = self.upsample(x)
        return self.convcin(enlarged, style_1, style_2, alpha)
|
81 |
+
|
82 |
+
class STModel(nn.Module):
    """Style-transfer network: three ``ConvCIN`` layers (the last two with
    stride 2), five residual blocks, two upsampling stages and a final
    sigmoid ``ConvCIN`` producing three channels.
    """

    def __init__(self, n_styles):
        super().__init__()

        self.convcin1 = ConvCIN(n_styles, C_in=3, C_out=32, kernel_size=9, padding=4, stride=1, activation='relu')
        self.convcin2 = ConvCIN(n_styles, C_in=32, C_out=64, kernel_size=3, padding=1, stride=2, activation='relu')
        self.convcin3 = ConvCIN(n_styles, C_in=64, C_out=128, kernel_size=3, padding=1, stride=2, activation='relu')

        # Registered as rb1 .. rb5, in that order.
        for idx in range(1, 6):
            setattr(self, f"rb{idx}", ResidualBlock(n_styles, 128, 128))

        self.upsample1 = UpSampling(n_styles, 128, 64)
        self.upsample2 = UpSampling(n_styles, 64, 32)

        self.convcin4 = ConvCIN(n_styles, C_in=32, C_out=3, kernel_size=9, padding=4, stride=1, activation='sigmoid')

    def forward(self, x, style_1, style_2=None, alpha=0.5):
        """Run ``x`` through every stage in order with the same style arguments."""
        pipeline = (
            self.convcin1, self.convcin2, self.convcin3,
            self.rb1, self.rb2, self.rb3, self.rb4, self.rb5,
            self.upsample1, self.upsample2,
            self.convcin4,
        )
        for stage in pipeline:
            x = stage(x, style_1, style_2, alpha)
        return x
|
styles/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|