gb-github-web committed on
Commit
b456239
1 Parent(s): 5402e32
Files changed (5)
  1. Photo_Style_Transfer.ipynb +71 -0
  2. colab_tools_2.py +87 -0
  3. predictor.py +51 -0
  4. stmodel.py +117 -0
  5. styles/.DS_Store +0 -0
Photo_Style_Transfer.ipynb ADDED
@@ -0,0 +1,71 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "R7pPsDHPE_PF"
+ },
+ "source": [
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dabidou025/Live-Style-Transfer/blob/main/Photo_Style_Transfer.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "pRMt1Ae6Asc7",
+ "outputId": "3b3868b0-e6c4-4a6d-93ae-bad6e0e31a02"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q gradio==2.8.7\n",
+ "!git clone https://github.com/dabidou025/Live-Style-Transfer.git\n",
+ "%cd Live-Style-Transfer/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LBy4brOoF9xC",
+ "outputId": "989467b2-546b-4b80-9475-4ad71a54a937"
+ },
+ "outputs": [],
+ "source": [
+ "from colab_tools_2 import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gradio_pls()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Photo_Style_Transfer.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
colab_tools_2.py ADDED
@@ -0,0 +1,87 @@
+ # import dependencies
+ from IPython.display import display, Javascript, Image
+ from google.colab.output import eval_js
+ from google.colab.patches import cv2_imshow
+ from base64 import b64decode, b64encode
+ import cv2
+ import numpy as np
+ import PIL
+ import io
+ import html
+ import time
+ import torch
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image
+ from models.stmodel import STModel
+ from predictor import Predictor
+ import argparse
+ from glob import glob
+ import os
+ from ipywidgets import Box, Image
+ import gradio as gr
+
+ def predict_gradio(image):
+     img_size = 512
+     load_model_path = "./models/st_model_512_80k_12.pth"
+     styles_path = "./styles/"
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_styles = len(glob(os.path.join(styles_path, '*.jpg')))
+     st_model = STModel(n_styles)
+     if True:
+         st_model.load_state_dict(torch.load(load_model_path, map_location=device))
+     st_model = st_model.to(device)
+
+     predictor = Predictor(st_model, device, img_size)
+
+     list_gen = []
+     for s in range(n_styles):
+         gen = predictor.eval_image(image, s)
+         list_gen.append(gen)
+     return list_gen
+
+ def gradio_pls():
+     description = """
+ Upload a photo and click on submit to see the 12 styles applied to your photo. \n
+ Keep in mind that for compatibility reasons your photo is cropped before the neural net applies the different styles.
+ <center>
+ <table><tr>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" width=100px></td>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" width=100px></td>
+
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" width=100px></td>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" width=100px></td>
+
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" width=100px></td>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" width=100px></td>
+
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" width=100px></td>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" width=100px></td>
+
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" width=100px></td>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" width=100px></td>
+
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" width=100px></td>
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" width=100px></td>
+
+ </tr>
+ </table>
+ </center>
+ """
+     iface = gr.Interface(
+         predict_gradio,
+         [
+             gr.inputs.Image(type="pil", label="Image"),
+         ],
+         [
+             gr.outputs.Carousel("image", label="Style"),
+         ],
+         layout="unaligned",
+         title="Photo Style Transfer",
+         description=description,
+         theme="grass",
+         allow_flagging='never'
+     )
+
+     return iface.launch(inline=True, height=800, width=800)
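Usage note (not part of the commit): predict_gradio can also be called directly, without launching the Gradio interface. A minimal sketch, assuming the repository has been cloned, the checkpoint models/st_model_512_80k_12.pth and the styles/ folder are present, and that it runs inside Colab, since colab_tools_2.py imports google.colab modules; the input and output file names are hypothetical:

    from PIL import Image as PILImage
    from colab_tools_2 import predict_gradio

    # Apply every available style to one photo and save each result.
    photo = PILImage.open("my_photo.jpg").convert("RGB")   # hypothetical input path
    stylized = predict_gradio(photo)                       # list of PIL images, one per style
    for i, img in enumerate(stylized):
        img.save(f"stylized_{i}.jpg")                      # hypothetical output names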
predictor.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from torchvision import transforms
+
+ from PIL import Image
+
+ import numpy as np
+
+ class Predictor:
+     def __init__(self, st_model, device, img_size):
+         self.device = device
+
+         self.st_model = st_model.to(device)
+         self.st_model.eval()
+
+         self.mean = [0.485, 0.456, 0.406]
+         self.std = [0.229, 0.224, 0.225]
+
+         self.transformer = transforms.Compose([
+             transforms.Resize(img_size),
+             transforms.CenterCrop(img_size),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=self.mean, std=self.std)
+         ])
+
+     def eval_image(self, img, style_1, style_2=None, alpha=0.5):
+         img = self.transformer(img).to(self.device)
+         gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha)
+
+         return Image.fromarray(np.uint8(np.moveaxis(gen[0].cpu().detach().numpy()*255.0, 0, 2)))
+
+ class WebcamPredictor:
+     def __init__(self, st_model, device):
+         self.device = device
+
+         self.st_model = st_model.to(device)
+         self.st_model.eval()
+
+         self.mean = np.array([0.485, 0.456, 0.406])
+         self.std = np.array([0.229, 0.224, 0.225])
+
+         self.mean = np.expand_dims(self.mean, (1,2))
+         self.std = np.expand_dims(self.std, (1,2))
+
+     def eval_image(self, img, style_1, style_2=None, alpha=0.5):
+         img = (img - self.mean) / self.std
+         img = torch.from_numpy(img).to(self.device)
+         img = img.float()
+
+         gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha)
+
+         return np.uint8(gen[0].cpu().detach().numpy()*255.0)
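Usage note (not part of the commit): Predictor wraps the ImageNet-normalized preprocessing and returns a PIL image, and eval_image optionally blends two styles via style_2 and alpha. A minimal sketch, assuming the trained checkpoint path used in colab_tools_2.py and a hypothetical input.jpg; the commit adds stmodel.py at the repository root, so adjust the import if your copy keeps it under models/:

    import torch
    from PIL import Image
    from stmodel import STModel        # or: from models.stmodel import STModel
    from predictor import Predictor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = STModel(n_styles=12)
    model.load_state_dict(torch.load("models/st_model_512_80k_12.pth", map_location=device))

    predictor = Predictor(model, device, img_size=512)

    img = Image.open("input.jpg").convert("RGB")                           # hypothetical input path
    single = predictor.eval_image(img, style_1=3)                          # one style
    blended = predictor.eval_image(img, style_1=3, style_2=7, alpha=0.5)   # 50/50 blend of two styles
    single.save("out_style3.jpg")
    blended.save("out_blend_3_7.jpg")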
stmodel.py ADDED
@@ -0,0 +1,117 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import numpy as np
+ import math
+ import time
+
+ class ConvCIN(nn.Module):
+     def __init__(self, n_styles, C_in, C_out, kernel_size, padding, stride, activation=None):
+         super(ConvCIN, self).__init__()
+
+         self.reflection = nn.ReflectionPad2d(padding)
+         self.conv = nn.Conv2d(in_channels=C_in, out_channels=C_out, kernel_size=kernel_size, stride=stride)
+         nn.init.normal_(self.conv.weight, mean=0, std=1e-2)
+
+         self.instnorm = nn.InstanceNorm2d(C_out)  #, affine=True)
+         #nn.init.normal_(self.instnorm.weight, mean=1, std=1e-2)
+         #nn.init.normal_(self.instnorm.bias, mean=0, std=1e-2)
+
+         self.gamma = torch.nn.Parameter(data=torch.randn(n_styles, C_out)*1e-2 + 1, requires_grad=True)
+         #self.gamma.data.uniform_(1.0, 1.0)
+
+         self.beta = torch.nn.Parameter(data=torch.randn(n_styles, C_out)*1e-2, requires_grad=True)
+         #self.beta.data.uniform_(0, 0)
+
+         self.activation = activation
+
+     def forward(self, x, style_1, style_2, alpha):
+         x = self.reflection(x)
+         x = self.conv(x)
+
+         x = self.instnorm(x)
+
+         if style_2 is not None:
+             gamma = alpha*self.gamma[style_1] + (1-alpha)*self.gamma[style_2]
+             beta = alpha*self.beta[style_1] + (1-alpha)*self.beta[style_2]
+         else:
+             gamma = self.gamma[style_1]
+             beta = self.beta[style_1]
+
+         b, d, w, h = x.size()
+         x = x.view(b, d, w*h)
+
+         x = (x*gamma.unsqueeze(-1) + beta.unsqueeze(-1)).view(b, d, w, h)
+
+         if self.activation == 'relu':
+             x = F.relu(x)
+         elif self.activation == 'sigmoid':
+             x = torch.sigmoid(x)
+
+         return x
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, n_styles, C_in, C_out):
+         super(ResidualBlock, self).__init__()
+
+         self.convcin1 = ConvCIN(n_styles, C_in, C_out, kernel_size=3, padding=1, stride=1, activation='relu')
+         self.convcin2 = ConvCIN(n_styles, C_in, C_out, kernel_size=3, padding=1, stride=1)
+
+     def forward(self, x, style_1, style_2, alpha):
+         out = self.convcin1(x, style_1, style_2, alpha)
+         out = self.convcin2(out, style_1, style_2, alpha)
+         return x + out
+
+ class UpSampling(nn.Module):
+     def __init__(self, n_styles, C_in, C_out):
+         super(UpSampling, self).__init__()
+
+         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+         self.convcin = ConvCIN(n_styles, C_in, C_out, kernel_size=3, padding=1, stride=1, activation='relu')
+
+     def forward(self, x, style_1, style_2, alpha):
+         x = self.upsample(x)
+         x = self.convcin(x, style_1, style_2, alpha)
+         return x
+
+ class STModel(nn.Module):
+     def __init__(self, n_styles):
+         super(STModel, self).__init__()
+
+         self.convcin1 = ConvCIN(n_styles, C_in=3, C_out=32, kernel_size=9, padding=4, stride=1, activation='relu')
+         self.convcin2 = ConvCIN(n_styles, C_in=32, C_out=64, kernel_size=3, padding=1, stride=2, activation='relu')
+         self.convcin3 = ConvCIN(n_styles, C_in=64, C_out=128, kernel_size=3, padding=1, stride=2, activation='relu')
+
+         self.rb1 = ResidualBlock(n_styles, 128, 128)
+         self.rb2 = ResidualBlock(n_styles, 128, 128)
+         self.rb3 = ResidualBlock(n_styles, 128, 128)
+         self.rb4 = ResidualBlock(n_styles, 128, 128)
+         self.rb5 = ResidualBlock(n_styles, 128, 128)
+
+         self.upsample1 = UpSampling(n_styles, 128, 64)
+         self.upsample2 = UpSampling(n_styles, 64, 32)
+
+         self.convcin4 = ConvCIN(n_styles, C_in=32, C_out=3, kernel_size=9, padding=4, stride=1, activation='sigmoid')
+
+     def forward(self, x, style_1, style_2=None, alpha=0.5):
+         x = self.convcin1(x, style_1, style_2, alpha)
+         x = self.convcin2(x, style_1, style_2, alpha)
+         x = self.convcin3(x, style_1, style_2, alpha)
+
+         x = self.rb1(x, style_1, style_2, alpha)
+         x = self.rb2(x, style_1, style_2, alpha)
+         x = self.rb3(x, style_1, style_2, alpha)
+         x = self.rb4(x, style_1, style_2, alpha)
+         x = self.rb5(x, style_1, style_2, alpha)
+
+         x = self.upsample1(x, style_1, style_2, alpha)
+         x = self.upsample2(x, style_1, style_2, alpha)
+
+         x = self.convcin4(x, style_1, style_2, alpha)
+
+         return x
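Usage note (not part of the commit): STModel follows the usual fast-style-transfer layout (two strided downsampling convolutions, five residual blocks, two nearest-neighbour upsampling blocks), with per-style gamma/beta applied after each instance norm so a single network serves several styles and can interpolate between two of them. A minimal shape check on random input, assuming stmodel.py is importable from the repository root; the sizes and style indices are illustrative:

    import torch
    from stmodel import STModel

    model = STModel(n_styles=12).eval()

    x = torch.rand(1, 3, 256, 256)    # dummy batch; any spatial size divisible by 4 round-trips exactly
    with torch.no_grad():
        y_single = model(x, style_1=0)                       # one style
        y_blend = model(x, style_1=0, style_2=5, alpha=0.3)  # 30% style 0, 70% style 5
    print(y_single.shape, y_blend.shape)  # both torch.Size([1, 3, 256, 256]), values in (0, 1) from the sigmoid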
styles/.DS_Store ADDED
Binary file (6.15 kB).