NikeZoldyck committed
Commit 4158574
1 Parent(s): 4ac8bc1

adding the gradio app code
app.py ADDED
@@ -0,0 +1,152 @@
+ from pathlib import Path
+ import numpy as np
+ import gradio as gr
+ import utils.shared_utils as st
+
+ import torch
+ from torch import autocast
+ import torchvision.transforms as T
+ from contextlib import nullcontext
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # Mixed-precision autocast on GPU; a no-op context manager on CPU.
+ context = autocast if device == "cuda" else nullcontext
+
+
+ def select_input(input_img, webcm_img):
+     # Prefer the uploaded image; fall back to the webcam capture.
+     if input_img is None:
+         img = webcm_img
+     else:
+         img = input_img
+     return img
+
+
+ def infer(prompt, samples):
+     # Call the Stable Diffusion API once per requested sample.
+     images = []
+     selections = ["Img_{}".format(str(i + 1).zfill(2)) for i in range(samples)]  # currently unused
+     with context(device):
+         for _ in range(samples):
+             back_img = st.stableDiffusionAPICall(prompt)
+             images.append(back_img)
+     return images
+
+
+ def change_bg_option(choice):
+     if choice == "I have an Image":
+         return gr.Image(shape=(800, 800))
+     elif choice == "Generate one for me":
+         return gr.update(lines=8, visible=True, value="Please enter a text prompt")
+     else:
+         return gr.update(visible=False)
+
+
+ # TEXT
+ title = "FSDL - One-Shot, Green-Screen, Composition-Transfer"
+ DEFAULT_TEXT = "Photorealistic scenery of bookshelf in a room"
+ description = """
+ <center><a href="https://docs.google.com/document/d/1fde8XKIMT1nNU72859ytd2c58LFBxepS3od9KFBrJbM/edit?usp=sharing">[PAPER]</a> <a href="https://github.com/snknitin/FSDL-Project/blob/main/src/utils/shared_utils.py">[CODE]</a></center>
+ <details>
+ <summary><b>Instructions</b></summary>
+ <p style="margin-top: -3px;">With this app, you can generate a suitable background image to overlay your portrait!<br />You have several ways to set how your final auto-edited image will look:<br /></p>
+ <ul style="margin-top: -20px;margin-bottom: -15px;">
+ <li style="margin-bottom: -10px;margin-left: 20px;">Use the "<i>Inputs</i>" tab to either upload an image from your device or allow the use of your webcam to capture one</li>
+ <li style="margin-left: 20px;">Use the "<i>Background Image Inputs</i>" to upload your own background</li>
+ <li style="margin-left: 20px;">Use the "<i>Text prompt</i>" tab to generate a satisfactory background image.</li>
+ </ul>
+ <p>After customization, just hit "<i>Edit</i>" and wait a few seconds.<br />The final image will be available for download.<br /><b>Enjoy!</b></p>
+ </details>
+ """
+
+ running = """
+ ### Instructions for running the 3 S's in sequence
+
+ * **Superimpose** - Isolates the foreground from your image and overlays it on the background, removing the original background via alpha matting.
+ * **Style-Transfer** - Transfers the style from your original image to re-map your new background realistically. Uses NVIDIA FastPhotoStyle.
+ * **Smoothing** - Since image resolution and clarity can be an issue, this step makes your final image crisp after the stylization transfer. Fair warning: this last process can take 5-10 minutes.
+ """
+
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+     with gr.Box():
+         gr.Markdown(description)
+     # First row - inputs
+     with gr.Row(scale=1):
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.TabItem("Upload"):
+                     input_img = gr.Image(shape=(800, 800), interactive=True, label="You")
+                 with gr.TabItem("Webcam Capture"):
+                     webcm_img = gr.Image(source="webcam", streaming=True, shape=(800, 800), interactive=True)
+             inp_select_btn = gr.Button("Select")
+
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.TabItem("Upload"):
+                     bgm_img = gr.Image(shape=(800, 800), type="pil", interactive=True, label="The Background")
+                     bgm_select_btn = gr.Button("Select")
+
+                 with gr.TabItem("Generate via Text Prompt"):
+                     with gr.Box():
+                         with gr.Row().style(mobile_collapse=False, equal_height=True):
+                             text = gr.Textbox(lines=7,
+                                               placeholder="Enter your prompt to generate a background image... something like - Photorealistic scenery of bookshelf in a room")
+
+                             samples = gr.Slider(label="Number of Images", minimum=1, maximum=5, value=2, step=1)
+                             btn = gr.Button("Generate images", variant="primary").style(
+                                 margin=False,
+                                 rounded=(False, True, True, False),
+                             )
+
+                     gallery = gr.Gallery(label="Generated images", show_label=True).style(grid=(1, 3), height="auto")
+                     text.submit(infer, inputs=[text, samples], outputs=gallery)
+                     btn.click(infer, inputs=[text, samples], outputs=gallery, show_progress=True, status_tracker=None)
+
+     # Second row - the selected foreground and background
+     with gr.Row(scale=1):
+         with gr.Column():
+             final_input_img = gr.Image(shape=(800, 800), type="pil", label="Foreground")
+
+         with gr.Column():
+             final_back_img = gr.Image(shape=(800, 800), type="pil", label="Background", interactive=True)
+
+     bgm_select_btn.click(fn=lambda x: x, inputs=bgm_img, outputs=final_back_img)
+     inp_select_btn.click(select_input, [input_img, webcm_img], final_input_img)
+
+     with gr.Row(scale=1):
+         with gr.Box():
+             gr.Markdown(running)
+
+     with gr.Row(scale=1):
+         with gr.Column(scale=1):
+             supimp_btn = gr.Button("SuperImpose")
+             overlay_img = gr.Image(shape=(800, 800), label="Overlay", type="pil")
+
+         with gr.Column(scale=1):
+             style_btn = gr.Button("Composition-Transfer", variant="primary")
+             style_img = gr.Image(shape=(800, 800), label="Style-Transfer Image", type="pil")
+
+         with gr.Column(scale=1):
+             submit_btn = gr.Button("Smoothen", variant="primary")
+             output_img = gr.Image(shape=(800, 800), label="Final Smoothened Image", type="pil")
+
+     supimp_btn.click(fn=st.superimpose, inputs=[final_input_img, final_back_img], outputs=[overlay_img])
+     style_btn.click(fn=st.style_transfer, inputs=[overlay_img, final_input_img], outputs=[style_img])
+     submit_btn.click(fn=st.smoother, inputs=[style_img, overlay_img], outputs=[output_img])
+
+ demo.queue()
+ demo.launch()
+
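The three action buttons above simply chain the helpers in `utils.shared_utils`. A minimal headless sketch of the same pipeline (assuming the repo root is on `PYTHONPATH`; `fg.png` and `bg.png` are hypothetical 800x800 test images):

```python
# Run the SuperImpose -> Composition-Transfer -> Smoothen chain without the UI.
from PIL import Image
import utils.shared_utils as st

fg = Image.open("fg.png").resize((800, 800))  # foreground portrait (hypothetical path)
bg = Image.open("bg.png").resize((800, 800))  # background image (hypothetical path)

overlay = st.superimpose(fg, bg)          # what the "SuperImpose" button calls
styled = st.style_transfer(overlay, fg)   # what the "Composition-Transfer" button calls
final = st.smoother(styled, overlay)      # what the "Smoothen" button calls
final.save("final.png")
```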
models/__init__.py ADDED
File without changes
models/components/__init__.py ADDED
File without changes
models/components/photo_wct.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bedc114a83833de79e92b7166b37bc522db71a30bbfa13d0c4f36387789c8af5
+ size 33410469
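The `.pth` checkpoint is stored with Git LFS, so the commit records only the three-line pointer above (spec version, SHA-256 object id, size in bytes) rather than the ~33 MB of weights. A small sketch (a hypothetical helper, not part of this repo) for reading those fields from an un-smudged checkout:

```python
# Parse the "key value" lines of a Git LFS pointer file into a dict.
from pathlib import Path

def read_lfs_pointer(path):
    lines = Path(path).read_text().splitlines()
    return dict(line.split(" ", 1) for line in lines if line.strip())

# read_lfs_pointer("models/components/photo_wct.pth")
# -> {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'sha256:...', 'size': '33410469'}
```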
models/models.py ADDED
@@ -0,0 +1,297 @@
+ """
+ Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+ """
+ import torch.nn as nn
+
+
+ class VGGEncoder(nn.Module):
+     def __init__(self, level):
+         super(VGGEncoder, self).__init__()
+         self.level = level
+
+         # 224 x 224
+         self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
+
+         self.pad1_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+         # 226 x 226
+         self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 0)
+         self.relu1_1 = nn.ReLU(inplace=True)
+         # 224 x 224
+
+         if level < 2: return
+
+         self.pad1_2 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
+         self.relu1_2 = nn.ReLU(inplace=True)
+         # 224 x 224
+         self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
+         # 112 x 112
+
+         self.pad2_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 0)
+         self.relu2_1 = nn.ReLU(inplace=True)
+         # 112 x 112
+
+         if level < 3: return
+
+         self.pad2_2 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
+         self.relu2_2 = nn.ReLU(inplace=True)
+         # 112 x 112
+
+         self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
+         # 56 x 56
+
+         self.pad3_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 0)
+         self.relu3_1 = nn.ReLU(inplace=True)
+         # 56 x 56
+
+         if level < 4: return
+
+         self.pad3_2 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
+         self.relu3_2 = nn.ReLU(inplace=True)
+         # 56 x 56
+
+         self.pad3_3 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
+         self.relu3_3 = nn.ReLU(inplace=True)
+         # 56 x 56
+
+         self.pad3_4 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
+         self.relu3_4 = nn.ReLU(inplace=True)
+         # 56 x 56
+
+         self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
+         # 28 x 28
+
+         self.pad4_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 0)
+         self.relu4_1 = nn.ReLU(inplace=True)
+         # 28 x 28
+
+     def forward(self, x):
+         out = self.conv0(x)
+
+         out = self.pad1_1(out)
+         out = self.conv1_1(out)
+         out = self.relu1_1(out)
+
+         if self.level < 2:
+             return out
+
+         out = self.pad1_2(out)
+         out = self.conv1_2(out)
+         pool1 = self.relu1_2(out)
+
+         out, pool1_idx = self.maxpool1(pool1)
+
+         out = self.pad2_1(out)
+         out = self.conv2_1(out)
+         out = self.relu2_1(out)
+
+         if self.level < 3:
+             return out, pool1_idx, pool1.size()
+
+         out = self.pad2_2(out)
+         out = self.conv2_2(out)
+         pool2 = self.relu2_2(out)
+
+         out, pool2_idx = self.maxpool2(pool2)
+
+         out = self.pad3_1(out)
+         out = self.conv3_1(out)
+         out = self.relu3_1(out)
+
+         if self.level < 4:
+             return out, pool1_idx, pool1.size(), pool2_idx, pool2.size()
+
+         out = self.pad3_2(out)
+         out = self.conv3_2(out)
+         out = self.relu3_2(out)
+
+         out = self.pad3_3(out)
+         out = self.conv3_3(out)
+         out = self.relu3_3(out)
+
+         out = self.pad3_4(out)
+         out = self.conv3_4(out)
+         pool3 = self.relu3_4(out)
+         out, pool3_idx = self.maxpool3(pool3)
+
+         out = self.pad4_1(out)
+         out = self.conv4_1(out)
+         out = self.relu4_1(out)
+
+         return out, pool1_idx, pool1.size(), pool2_idx, pool2.size(), pool3_idx, pool3.size()
+
+     def forward_multiple(self, x):
+         out = self.conv0(x)
+
+         out = self.pad1_1(out)
+         out = self.conv1_1(out)
+         out = self.relu1_1(out)
+
+         if self.level < 2: return out
+
+         out1 = out
+
+         out = self.pad1_2(out)
+         out = self.conv1_2(out)
+         pool1 = self.relu1_2(out)
+
+         out, pool1_idx = self.maxpool1(pool1)
+
+         out = self.pad2_1(out)
+         out = self.conv2_1(out)
+         out = self.relu2_1(out)
+
+         if self.level < 3: return out, out1
+
+         out2 = out
+
+         out = self.pad2_2(out)
+         out = self.conv2_2(out)
+         pool2 = self.relu2_2(out)
+
+         out, pool2_idx = self.maxpool2(pool2)
+
+         out = self.pad3_1(out)
+         out = self.conv3_1(out)
+         out = self.relu3_1(out)
+
+         if self.level < 4: return out, out2, out1
+
+         out3 = out
+
+         out = self.pad3_2(out)
+         out = self.conv3_2(out)
+         out = self.relu3_2(out)
+
+         out = self.pad3_3(out)
+         out = self.conv3_3(out)
+         out = self.relu3_3(out)
+
+         out = self.pad3_4(out)
+         out = self.conv3_4(out)
+         pool3 = self.relu3_4(out)
+         out, pool3_idx = self.maxpool3(pool3)
+
+         out = self.pad4_1(out)
+         out = self.conv4_1(out)
+         out = self.relu4_1(out)
+
+         return out, out3, out2, out1
+
+
+ class VGGDecoder(nn.Module):
+     def __init__(self, level):
+         super(VGGDecoder, self).__init__()
+         self.level = level
+
+         if level > 3:
+             self.pad4_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv4_1 = nn.Conv2d(512, 256, 3, 1, 0)
+             self.relu4_1 = nn.ReLU(inplace=True)
+             # 28 x 28
+
+             self.unpool3 = nn.MaxUnpool2d(kernel_size=2, stride=2)
+             # 56 x 56
+
+             self.pad3_4 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
+             self.relu3_4 = nn.ReLU(inplace=True)
+             # 56 x 56
+
+             self.pad3_3 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
+             self.relu3_3 = nn.ReLU(inplace=True)
+             # 56 x 56
+
+             self.pad3_2 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
+             self.relu3_2 = nn.ReLU(inplace=True)
+             # 56 x 56
+
+         if level > 2:
+             self.pad3_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv3_1 = nn.Conv2d(256, 128, 3, 1, 0)
+             self.relu3_1 = nn.ReLU(inplace=True)
+             # 56 x 56
+
+             self.unpool2 = nn.MaxUnpool2d(kernel_size=2, stride=2)
+             # 112 x 112
+
+             self.pad2_2 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
+             self.relu2_2 = nn.ReLU(inplace=True)
+             # 112 x 112
+
+         if level > 1:
+             self.pad2_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv2_1 = nn.Conv2d(128, 64, 3, 1, 0)
+             self.relu2_1 = nn.ReLU(inplace=True)
+             # 112 x 112
+
+             self.unpool1 = nn.MaxUnpool2d(kernel_size=2, stride=2)
+             # 224 x 224
+
+             self.pad1_2 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
+             self.relu1_2 = nn.ReLU(inplace=True)
+             # 224 x 224
+
+         if level > 0:
+             self.pad1_1 = nn.ReflectionPad2d((1, 1, 1, 1))
+             self.conv1_1 = nn.Conv2d(64, 3, 3, 1, 0)
+
+     def forward(self, x, pool1_idx=None, pool1_size=None, pool2_idx=None, pool2_size=None, pool3_idx=None,
+                 pool3_size=None):
+         out = x
+
+         if self.level > 3:
+             out = self.pad4_1(out)
+             out = self.conv4_1(out)
+             out = self.relu4_1(out)
+             out = self.unpool3(out, pool3_idx, output_size=pool3_size)
+
+             out = self.pad3_4(out)
+             out = self.conv3_4(out)
+             out = self.relu3_4(out)
+
+             out = self.pad3_3(out)
+             out = self.conv3_3(out)
+             out = self.relu3_3(out)
+
+             out = self.pad3_2(out)
+             out = self.conv3_2(out)
+             out = self.relu3_2(out)
+
+         if self.level > 2:
+             out = self.pad3_1(out)
+             out = self.conv3_1(out)
+             out = self.relu3_1(out)
+             out = self.unpool2(out, pool2_idx, output_size=pool2_size)
+
+             out = self.pad2_2(out)
+             out = self.conv2_2(out)
+             out = self.relu2_2(out)
+
+         if self.level > 1:
+             out = self.pad2_1(out)
+             out = self.conv2_1(out)
+             out = self.relu2_1(out)
+             out = self.unpool1(out, pool1_idx, output_size=pool1_size)
+
+             out = self.pad1_2(out)
+             out = self.conv1_2(out)
+             out = self.relu1_2(out)
+
+         if self.level > 0:
+             out = self.pad1_1(out)
+             out = self.conv1_1(out)
+
+         return out
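Each `VGGEncoder` level returns not just features but also the max-pool indices and pre-pool sizes, which the matching `VGGDecoder` needs for exact unpooling. A quick shape round-trip sketch (assuming `models.models` is importable):

```python
# Level-4 encoder/decoder round trip: unpooling restores the input resolution.
import torch
from models.models import VGGEncoder, VGGDecoder

enc, dec = VGGEncoder(4), VGGDecoder(4)
x = torch.randn(1, 3, 224, 224)
feat, i1, s1, i2, s2, i3, s3 = enc(x)   # feat: (1, 512, 28, 28) plus pooling bookkeeping
y = dec(feat, i1, s1, i2, s2, i3, s3)   # y: (1, 3, 224, 224)
print(feat.shape, y.shape)
```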
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch
+ diffusers
+ transformers
+ scipy
+ ftfy
+ gradio
+ torchvision
+ scikit-image
+ rembg
+ replicate
+ requests
+ Pillow
+ numpy
+ pyrootutils
+ pynvrtc
+ cupy
utils/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # NOTE: these imports assume the original FSDL project's `src/` package layout.
+ from src.utils.pylogger import get_pylogger
+ from src.utils.rich_utils import enforce_tags, print_config_tree
+ from src.utils.utils import (
+     close_loggers,
+     extras,
+     get_metric_value,
+     instantiate_callbacks,
+     instantiate_loggers,
+     log_hyperparameters,
+     save_file,
+     task_wrapper,
+ )
utils/photo_smooth.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+ """
+ from __future__ import division
+ import torch.nn as nn
+ import numpy as np
+ import scipy.sparse
+ import scipy.sparse.linalg as linalg
+ from numpy.lib.stride_tricks import as_strided
+ from PIL import Image
+
+
+ class Propagator(nn.Module):
+     def __init__(self, beta=0.9999):
+         super(Propagator, self).__init__()
+         self.beta = beta
+
+     def process(self, initImg, contentImg):
+         # Accept either file paths or in-memory images for both arguments.
+         # (scipy.misc.imread was removed in SciPy >= 1.2, so PIL is used instead.)
+         if isinstance(contentImg, str):
+             content = np.asarray(Image.open(contentImg).convert('RGB'))
+         else:
+             content = contentImg.copy()
+
+         if isinstance(initImg, str):
+             B = np.asarray(Image.open(initImg).convert('RGB')).astype(np.float64) / 255
+         else:
+             B = np.asarray(initImg).astype(np.float64) / 255
+
+         h1, w1, k = B.shape
+         h = h1 - 4
+         w = w1 - 4
+         B = B[int((h1-h)/2):int((h1-h)/2+h), int((w1-w)/2):int((w1-w)/2+w), :]
+         # PIL's resize takes (width, height), so the target is (w, h).
+         content = np.asarray(Image.fromarray(np.array(content)).resize((w, h), Image.BICUBIC))
+         B = self.__replication_padding(B, 2)
+         content = self.__replication_padding(content, 2)
+         content = content.astype(np.float64) / 255
+         B = np.reshape(B, (h1*w1, k))
+         W = self.__compute_laplacian(content)
+         W = W.tocsc()
+         dd = W.sum(0)
+         dd = np.sqrt(np.power(dd, -1))
+         dd = dd.A.squeeze()
+         D = scipy.sparse.csc_matrix((dd, (np.arange(0, w1*h1), np.arange(0, w1*h1))))
+         S = D.dot(W).dot(D)
+         A = scipy.sparse.identity(w1*h1) - self.beta*S
+         A = A.tocsc()
+         solver = linalg.factorized(A)
+         V = np.zeros((h1*w1, k))
+         V[:, 0] = solver(B[:, 0])
+         V[:, 1] = solver(B[:, 1])
+         V[:, 2] = solver(B[:, 2])
+         V = V*(1-self.beta)
+         V = V.reshape(h1, w1, k)
+         V = V[2:2+h, 2:2+w, :]
+
+         img = Image.fromarray(np.uint8(np.clip(V * 255., 0, 255.)))
+         return img
+
+     # Returns the sparse matting Laplacian.
+     # The implementation of the function is heavily borrowed from
+     # https://github.com/MarcoForte/closed-form-matting/blob/master/closed_form_matting.py
+     # We thank Marco Forte for sharing his code.
+     def __compute_laplacian(self, img, eps=10**(-7), win_rad=1):
+         win_size = (win_rad*2+1)**2
+         h, w, d = img.shape
+         c_h, c_w = h - 2*win_rad, w - 2*win_rad
+         win_diam = win_rad*2+1
+         indsM = np.arange(h*w).reshape((h, w))
+         ravelImg = img.reshape(h*w, d)
+         win_inds = self.__rolling_block(indsM, block=(win_diam, win_diam))
+         win_inds = win_inds.reshape(c_h, c_w, win_size)
+         winI = ravelImg[win_inds]
+         win_mu = np.mean(winI, axis=2, keepdims=True)
+         win_var = np.einsum('...ji,...jk ->...ik', winI, winI)/win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu)
+         inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(3))
+         X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv)
+         vals = (1/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu))
+         nz_indsCol = np.tile(win_inds, win_size).ravel()
+         nz_indsRow = np.repeat(win_inds, win_size).ravel()
+         nz_indsVal = vals.ravel()
+         L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w))
+         return L
+
+     def __replication_padding(self, arr, pad):
+         h, w, c = arr.shape
+         ans = np.zeros((h+pad*2, w+pad*2, c))
+         for i in range(c):
+             ans[:, :, i] = np.pad(arr[:, :, i], pad_width=(pad, pad), mode='edge')
+         return ans
+
+     def __rolling_block(self, A, block=(3, 3)):
+         shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block
+         strides = (A.strides[0], A.strides[1]) + A.strides
+         return as_strided(A, shape=shape, strides=strides)
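In matrix form, `Propagator.process` solves one sparse linear system per RGB channel (my reading of the code above, stated in standard notation): with $W$ the matting-Laplacian affinities from `__compute_laplacian`, $D = \mathrm{diag}(W\mathbf{1})$ its degree matrix, $Y$ the padded, flattened stylized image, and $\beta = 0.9999$,

$$
S = D^{-1/2} W D^{-1/2}, \qquad R = (1-\beta)\,(I - \beta S)^{-1} Y,
$$

i.e. each output channel is the stylized channel diffused along the edges of the guide (content) image.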
utils/photo_wct.py ADDED
@@ -0,0 +1,171 @@
+ """
+ Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+ """
+
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ from models.models import VGGEncoder, VGGDecoder
+
+
+ class PhotoWCT(nn.Module):
+     def __init__(self):
+         super(PhotoWCT, self).__init__()
+         self.e1 = VGGEncoder(1)
+         self.d1 = VGGDecoder(1)
+         self.e2 = VGGEncoder(2)
+         self.d2 = VGGDecoder(2)
+         self.e3 = VGGEncoder(3)
+         self.d3 = VGGDecoder(3)
+         self.e4 = VGGEncoder(4)
+         self.d4 = VGGDecoder(4)
+
+     def transform(self, cont_img, styl_img, cont_seg, styl_seg):
+         self.__compute_label_info(cont_seg, styl_seg)
+
+         sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)
+
+         # Stylize coarse-to-fine: WCT at level 4, decode, re-encode at level 3, and so on.
+         cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
+         sF4 = sF4.data.squeeze(0)
+         cF4 = cF4.data.squeeze(0)
+         csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
+         Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)
+
+         cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
+         sF3 = sF3.data.squeeze(0)
+         cF3 = cF3.data.squeeze(0)
+         csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
+         Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)
+
+         cF2, cpool_idx, cpool = self.e2(Im3)
+         sF2 = sF2.data.squeeze(0)
+         cF2 = cF2.data.squeeze(0)
+         csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
+         Im2 = self.d2(csF2, cpool_idx, cpool)
+
+         cF1 = self.e1(Im2)
+         sF1 = sF1.data.squeeze(0)
+         cF1 = cF1.data.squeeze(0)
+         csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
+         Im1 = self.d1(csF1)
+         return Im1
+
+     def __compute_label_info(self, cont_seg, styl_seg):
+         if cont_seg.size == 0 or styl_seg.size == 0:
+             return
+         max_label = np.max(cont_seg) + 1
+         self.label_set = np.unique(cont_seg)
+         self.label_indicator = np.zeros(max_label)
+         for l in self.label_set:
+             # A label is usable when both masks are big enough and comparable in size.
+             is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
+             o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
+             o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
+             self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)
+
+     def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
+         cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
+         styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
+         cont_feat_view = cont_feat.view(cont_c, -1).clone()
+         styl_feat_view = styl_feat.view(styl_c, -1).clone()
+
+         if cont_seg.size == 0 or styl_seg.size == 0:
+             # No segmentation maps: apply WCT globally.
+             target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
+         else:
+             target_feature = cont_feat.view(cont_c, -1).clone()
+             if len(cont_seg.shape) == 2:
+                 t_cont_seg = np.asarray(Image.fromarray(cont_seg).resize((cont_w, cont_h), Image.NEAREST))
+             else:
+                 t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
+             if len(styl_seg.shape) == 2:
+                 t_styl_seg = np.asarray(Image.fromarray(styl_seg).resize((styl_w, styl_h), Image.NEAREST))
+             else:
+                 t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))
+
+             for l in self.label_set:
+                 if self.label_indicator[l] == 0:
+                     continue
+                 cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
+                 styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
+                 if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
+                     continue
+
+                 cont_indi = torch.LongTensor(cont_mask[0])
+                 styl_indi = torch.LongTensor(styl_mask[0])
+                 if self.is_cuda:
+                     cont_indi = cont_indi.cuda(0)
+                     styl_indi = styl_indi.cuda(0)
+
+                 cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
+                 sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
+                 tmp_target_feature = self.__wct_core(cFFG, sFFG)
+                 if torch.__version__ >= "0.4.0":
+                     # index_copy_ along dim 1 misbehaved in PyTorch 0.4.0, so copy along dim 0 instead.
+                     new_target_feature = torch.transpose(target_feature, 1, 0)
+                     new_target_feature.index_copy_(0, cont_indi,
+                                                    torch.transpose(tmp_target_feature, 1, 0))
+                     target_feature = torch.transpose(new_target_feature, 1, 0)
+                 else:
+                     target_feature.index_copy_(1, cont_indi, tmp_target_feature)
+
+         target_feature = target_feature.view_as(cont_feat)
+         ccsF = target_feature.float().unsqueeze(0)
+         return ccsF
+
+     def __wct_core(self, cont_feat, styl_feat):
+         cFSize = cont_feat.size()
+         c_mean = torch.mean(cont_feat, 1)  # c x (h x w)
+         c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
+         cont_feat = cont_feat - c_mean
+
+         iden = torch.eye(cFSize[0])  # .double()
+         if self.is_cuda:
+             iden = iden.cuda()
+
+         contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
+         c_u, c_e, c_v = torch.svd(contentConv, some=False)
+
+         # Keep only eigenvalues above a small threshold.
+         k_c = cFSize[0]
+         for i in range(cFSize[0] - 1, -1, -1):
+             if c_e[i] >= 0.00001:
+                 k_c = i + 1
+                 break
+
+         sFSize = styl_feat.size()
+         s_mean = torch.mean(styl_feat, 1)
+         styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat)
+         styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1)
+         s_u, s_e, s_v = torch.svd(styleConv, some=False)
+
+         k_s = sFSize[0]
+         for i in range(sFSize[0] - 1, -1, -1):
+             if s_e[i] >= 0.00001:
+                 k_s = i + 1
+                 break
+
+         # Whitening: remove the content covariance.
+         c_d = (c_e[0:k_c]).pow(-0.5)
+         step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
+         step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
+         whiten_cF = torch.mm(step2, cont_feat)
+
+         # Coloring: impose the style covariance, then add back the style mean.
+         s_d = (s_e[0:k_s]).pow(0.5)
+         targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
+         targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
+         return targetFeature
+
+     @property
+     def is_cuda(self):
+         return next(self.parameters()).is_cuda
+
+     def forward(self, *input):
+         pass
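`__wct_core` is the classic whitening-and-coloring transform. In standard notation (matching the SVD steps above; $f_c, f_s$ are the mean-centered content and style features with covariance eigendecompositions $E_c \Lambda_c E_c^\top$ and $E_s \Lambda_s E_s^\top$, truncated at eigenvalues below $10^{-5}$):

$$
\hat{f}_c = E_c \Lambda_c^{-1/2} E_c^\top f_c,
\qquad
f_{cs} = E_s \Lambda_s^{1/2} E_s^\top \hat{f}_c + \mu_s,
$$

so the stylized features carry the style's second-order statistics while preserving the content's spatial structure.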
utils/shared_utils.py ADDED
@@ -0,0 +1,136 @@
+ from pathlib import Path
+ import io
+
+ import numpy as np
+ import requests
+ import replicate
+ import pyrootutils
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ import torchvision.utils as utils
+ from PIL import Image
+ from rembg import remove
+
+ from utils.photo_wct import PhotoWCT
+ from utils.photo_smooth import Propagator
+
+ # Locate the repo root and load the pretrained models once at import time.
+ root = pyrootutils.setup_root(Path.cwd(), pythonpath=True)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ p_wct = PhotoWCT()
+ p_wct.load_state_dict(torch.load(root/"models/components/photo_wct.pth"))
+ p_pro = Propagator()
+ stylization_module = p_wct
+ smoothing_module = p_pro
+
+
+ # Dependencies - to be installed:
+ #   pip install replicate
+ # The API token must be declared as an environment variable before launch:
+ #   export REPLICATE_API_TOKEN=<your-token>
+
+
+ def stableDiffusionAPICall(text_prompt):
+     # The replicate client reads REPLICATE_API_TOKEN from the environment.
+     model = replicate.models.get("stability-ai/stable-diffusion")
+     # e.g. text_prompt = 'photorealistic, elf fighting Sauron'
+     gen_bg_img = model.predict(prompt=text_prompt)[0]
+     # The API returns a URL; fetch it and decode the bytes into a PIL image.
+     img_data = requests.get(gen_bg_img).content
+     stream = io.BytesIO(img_data)
+     img = Image.open(stream)
+     del img_data
+
+     return img
+
+
+ def memory_limit_image_resize(cont_img):
+     # Prevent too small or too big images (resized in place, preserving aspect ratio).
+     MINSIZE = 400
+     MAXSIZE = 800
+     orig_width = cont_img.width
+     orig_height = cont_img.height
+     if max(cont_img.width, cont_img.height) < MINSIZE:
+         if cont_img.width > cont_img.height:
+             cont_img.thumbnail((int(cont_img.width*1.0/cont_img.height*MINSIZE), MINSIZE), Image.BICUBIC)
+         else:
+             cont_img.thumbnail((MINSIZE, int(cont_img.height*1.0/cont_img.width*MINSIZE)), Image.BICUBIC)
+     if min(cont_img.width, cont_img.height) > MAXSIZE:
+         if cont_img.width > cont_img.height:
+             cont_img.thumbnail((MAXSIZE, int(cont_img.height*1.0/cont_img.width*MAXSIZE)), Image.BICUBIC)
+         else:
+             cont_img.thumbnail((int(cont_img.width*1.0/cont_img.height*MAXSIZE), MAXSIZE), Image.BICUBIC)
+     print("Resize image: (%d,%d)->(%d,%d)" % (orig_width, orig_height, cont_img.width, cont_img.height))
+     return cont_img.width, cont_img.height
+
+
+ def superimpose(input_img, back_img):
+     # Matte out the foreground with rembg, then paste it onto the background
+     # using its own alpha channel as the mask.
+     matte_img = remove(input_img)
+     back_img.paste(matte_img, (0, 0), matte_img)
+     return back_img
+
+
+ def style_transfer(cont_img, styl_img):
+     with torch.no_grad():
+         new_cw, new_ch = memory_limit_image_resize(cont_img)
+         new_sw, new_sh = memory_limit_image_resize(styl_img)
+         cont_pilimg = cont_img.copy()
+         cw = cont_pilimg.width
+         ch = cont_pilimg.height
+         cont_img = transforms.ToTensor()(cont_img).unsqueeze(0)
+         styl_img = transforms.ToTensor()(styl_img).unsqueeze(0)
+
+         # No segmentation maps are used, so WCT is applied globally.
+         cont_seg = []
+         styl_seg = []
+
+         if device == 'cuda':
+             cont_img = cont_img.to(device)
+             styl_img = styl_img.to(device)
+             stylization_module.to(device)
+         cont_seg = np.asarray(cont_seg)
+         styl_seg = np.asarray(styl_seg)
+
+         stylized_img = stylization_module.transform(cont_img, styl_img, cont_seg, styl_seg)
+         if ch != new_ch or cw != new_cw:
+             stylized_img = nn.functional.interpolate(stylized_img, size=(ch, cw), mode='bilinear')
+         grid = utils.make_grid(stylized_img.data, nrow=1, padding=0)
+         ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+         stylized_img = Image.fromarray(ndarr)
+         # final_img = smooth_filter(stylized_img, cont_pilimg, f_radius=15, f_edge=1e-1)
+     return stylized_img
+
+
+ def smoother(stylized_img, over_img):
+     final_img = smoothing_module.process(stylized_img, over_img)
+     return final_img
+
+
+ if __name__ == "__main__":
+     root = pyrootutils.setup_root(__file__, pythonpath=True)
+     fg_path = root/"notebooks/profile_new.png"
+     bg_path = root/"notebooks/back_img.png"
+     ckpt_path = root/"src/models/MODNet/pretrained/modnet_photographic_portrait_matting.ckpt"
+
+     # stableDiffusionAPICall("Photorealistic scenery of a concert")
+     fg_img = Image.open(fg_path).resize((800, 800))
+     bg_img = Image.open(bg_path).resize((800, 800))
+     # img = combined_display(fg_img, bg_img, ckpt_path)
+     img = superimpose(fg_img, bg_img)
+     img.save(root/"notebooks/overlay.png")
+     # bg_img.paste(img, (0, 0), img)
+     # bg_img.save(root/"notebooks/check.png")
utils/smooth_filter.py ADDED
@@ -0,0 +1,405 @@
+ """
+ Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+ """
+ src = '''
+ #include "/usr/local/cuda/include/math_functions.h"
+ #define TB 256
+ #define EPS 1e-7
+
+ // Invert a 4x4 matrix; returns false if the determinant is (near) zero.
+ __device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
+     double m[16], inv[16];
+     for (int i = 0; i < 4; i++) {
+         for (int j = 0; j < 4; j++) {
+             m[i * 4 + j] = m_in[i][j];
+         }
+     }
+
+     inv[0] = m[5] * m[10] * m[15] -
+              m[5] * m[11] * m[14] -
+              m[9] * m[6] * m[15] +
+              m[9] * m[7] * m[14] +
+              m[13] * m[6] * m[11] -
+              m[13] * m[7] * m[10];
+
+     inv[4] = -m[4] * m[10] * m[15] +
+              m[4] * m[11] * m[14] +
+              m[8] * m[6] * m[15] -
+              m[8] * m[7] * m[14] -
+              m[12] * m[6] * m[11] +
+              m[12] * m[7] * m[10];
+
+     inv[8] = m[4] * m[9] * m[15] -
+              m[4] * m[11] * m[13] -
+              m[8] * m[5] * m[15] +
+              m[8] * m[7] * m[13] +
+              m[12] * m[5] * m[11] -
+              m[12] * m[7] * m[9];
+
+     inv[12] = -m[4] * m[9] * m[14] +
+               m[4] * m[10] * m[13] +
+               m[8] * m[5] * m[14] -
+               m[8] * m[6] * m[13] -
+               m[12] * m[5] * m[10] +
+               m[12] * m[6] * m[9];
+
+     inv[1] = -m[1] * m[10] * m[15] +
+              m[1] * m[11] * m[14] +
+              m[9] * m[2] * m[15] -
+              m[9] * m[3] * m[14] -
+              m[13] * m[2] * m[11] +
+              m[13] * m[3] * m[10];
+
+     inv[5] = m[0] * m[10] * m[15] -
+              m[0] * m[11] * m[14] -
+              m[8] * m[2] * m[15] +
+              m[8] * m[3] * m[14] +
+              m[12] * m[2] * m[11] -
+              m[12] * m[3] * m[10];
+
+     inv[9] = -m[0] * m[9] * m[15] +
+              m[0] * m[11] * m[13] +
+              m[8] * m[1] * m[15] -
+              m[8] * m[3] * m[13] -
+              m[12] * m[1] * m[11] +
+              m[12] * m[3] * m[9];
+
+     inv[13] = m[0] * m[9] * m[14] -
+               m[0] * m[10] * m[13] -
+               m[8] * m[1] * m[14] +
+               m[8] * m[2] * m[13] +
+               m[12] * m[1] * m[10] -
+               m[12] * m[2] * m[9];
+
+     inv[2] = m[1] * m[6] * m[15] -
+              m[1] * m[7] * m[14] -
+              m[5] * m[2] * m[15] +
+              m[5] * m[3] * m[14] +
+              m[13] * m[2] * m[7] -
+              m[13] * m[3] * m[6];
+
+     inv[6] = -m[0] * m[6] * m[15] +
+              m[0] * m[7] * m[14] +
+              m[4] * m[2] * m[15] -
+              m[4] * m[3] * m[14] -
+              m[12] * m[2] * m[7] +
+              m[12] * m[3] * m[6];
+
+     inv[10] = m[0] * m[5] * m[15] -
+               m[0] * m[7] * m[13] -
+               m[4] * m[1] * m[15] +
+               m[4] * m[3] * m[13] +
+               m[12] * m[1] * m[7] -
+               m[12] * m[3] * m[5];
+
+     inv[14] = -m[0] * m[5] * m[14] +
+               m[0] * m[6] * m[13] +
+               m[4] * m[1] * m[14] -
+               m[4] * m[2] * m[13] -
+               m[12] * m[1] * m[6] +
+               m[12] * m[2] * m[5];
+
+     inv[3] = -m[1] * m[6] * m[11] +
+              m[1] * m[7] * m[10] +
+              m[5] * m[2] * m[11] -
+              m[5] * m[3] * m[10] -
+              m[9] * m[2] * m[7] +
+              m[9] * m[3] * m[6];
+
+     inv[7] = m[0] * m[6] * m[11] -
+              m[0] * m[7] * m[10] -
+              m[4] * m[2] * m[11] +
+              m[4] * m[3] * m[10] +
+              m[8] * m[2] * m[7] -
+              m[8] * m[3] * m[6];
+
+     inv[11] = -m[0] * m[5] * m[11] +
+               m[0] * m[7] * m[9] +
+               m[4] * m[1] * m[11] -
+               m[4] * m[3] * m[9] -
+               m[8] * m[1] * m[7] +
+               m[8] * m[3] * m[5];
+
+     inv[15] = m[0] * m[5] * m[10] -
+               m[0] * m[6] * m[9] -
+               m[4] * m[1] * m[10] +
+               m[4] * m[2] * m[9] +
+               m[8] * m[1] * m[6] -
+               m[8] * m[2] * m[5];
+
+     double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
+
+     if (abs(det) < 1e-9) {
+         return false;
+     }
+
+     det = 1.0 / det;
+
+     for (int i = 0; i < 4; i++) {
+         for (int j = 0; j < 4; j++) {
+             inv_out[i][j] = inv[i * 4 + j] * det;
+         }
+     }
+
+     return true;
+ }
+
+ extern "C"
+ __global__ void best_local_affine_kernel(
+     float *output, float *input, float *affine_model,
+     int h, int w, float epsilon, int kernel_radius
+ )
+ {
+     int size = h * w;
+     int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+     if (id < size) {
+         int x = id % w, y = id / w;
+
+         double Mt_M[4][4] = {}; // 4x4
+         double invMt_M[4][4] = {};
+         double Mt_S[3][4] = {}; // RGB -> 1x4
+         double A[3][4] = {};
+         for (int i = 0; i < 4; i++)
+             for (int j = 0; j < 4; j++) {
+                 Mt_M[i][j] = 0, invMt_M[i][j] = 0;
+                 if (i != 3) {
+                     Mt_S[i][j] = 0, A[i][j] = 0;
+                     if (i == j)
+                         Mt_M[i][j] = 1e-3;
+                 }
+             }
+
+         for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
+             for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
+                 int xx = x + dx, yy = y + dy;
+                 int id2 = yy * w + xx;
+
+                 if (0 <= xx && xx < w && 0 <= yy && yy < h) {
+                     Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
+                     Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
+                     Mt_M[0][2] += input[id2 + 2*size] * input[id2];
+                     Mt_M[0][3] += input[id2 + 2*size];
+
+                     Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
+                     Mt_M[1][1] += input[id2 + size] * input[id2 + size];
+                     Mt_M[1][2] += input[id2 + size] * input[id2];
+                     Mt_M[1][3] += input[id2 + size];
+
+                     Mt_M[2][0] += input[id2] * input[id2 + 2*size];
+                     Mt_M[2][1] += input[id2] * input[id2 + size];
+                     Mt_M[2][2] += input[id2] * input[id2];
+                     Mt_M[2][3] += input[id2];
+
+                     Mt_M[3][0] += input[id2 + 2*size];
+                     Mt_M[3][1] += input[id2 + size];
+                     Mt_M[3][2] += input[id2];
+                     Mt_M[3][3] += 1;
+
+                     Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
+                     Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
+                     Mt_S[0][2] += input[id2] * output[id2 + 2*size];
+                     Mt_S[0][3] += output[id2 + 2*size];
+
+                     Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
+                     Mt_S[1][1] += input[id2 + size] * output[id2 + size];
+                     Mt_S[1][2] += input[id2] * output[id2 + size];
+                     Mt_S[1][3] += output[id2 + size];
+
+                     Mt_S[2][0] += input[id2 + 2*size] * output[id2];
+                     Mt_S[2][1] += input[id2 + size] * output[id2];
+                     Mt_S[2][2] += input[id2] * output[id2];
+                     Mt_S[2][3] += output[id2];
+                 }
+             }
+         }
+
+         bool success = InverseMat4x4(Mt_M, invMt_M);
+
+         for (int i = 0; i < 3; i++) {
+             for (int j = 0; j < 4; j++) {
+                 for (int k = 0; k < 4; k++) {
+                     A[i][j] += invMt_M[j][k] * Mt_S[i][k];
+                 }
+             }
+         }
+
+         for (int i = 0; i < 3; i++) {
+             for (int j = 0; j < 4; j++) {
+                 int affine_id = i * 4 + j;
+                 affine_model[12 * id + affine_id] = A[i][j];
+             }
+         }
+     }
+     return ;
+ }
+
+ extern "C"
+ __global__ void bilateral_smooth_kernel(
+     float *affine_model, float *filtered_affine_model, float *guide,
+     int h, int w, int kernel_radius, float sigma1, float sigma2
+ )
+ {
+     int id = blockIdx.x * blockDim.x + threadIdx.x;
+     int size = h * w;
+     if (id < size) {
+         int x = id % w;
+         int y = id / w;
+
+         double sum_affine[12] = {};
+         double sum_weight = 0;
+         for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
+             for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
+                 int yy = y + dy, xx = x + dx;
+                 int id2 = yy * w + xx;
+                 if (0 <= xx && xx < w && 0 <= yy && yy < h) {
+                     float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
+                     float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
+                     float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
+                     float color_diff_sqr =
+                         (color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
+
+                     float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
+                     float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
+                     float weight = v1 * v2;
+
+                     for (int i = 0; i < 3; i++) {
+                         for (int j = 0; j < 4; j++) {
+                             int affine_id = i * 4 + j;
+                             sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
+                         }
+                     }
+                     sum_weight += weight;
+                 }
+             }
+         }
+
+         for (int i = 0; i < 3; i++) {
+             for (int j = 0; j < 4; j++) {
+                 int affine_id = i * 4 + j;
+                 filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
+             }
+         }
+     }
+     return ;
+ }
+
+ extern "C"
+ __global__ void reconstruction_best_kernel(
+     float *input, float *filtered_affine_model, float *filtered_best_output,
+     int h, int w
+ )
+ {
+     int id = blockIdx.x * blockDim.x + threadIdx.x;
+     int size = h * w;
+     if (id < size) {
+         double out1 =
+             input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
+             input[id + size] * filtered_affine_model[id*12 + 1] +   // A[0][1] +
+             input[id] * filtered_affine_model[id*12 + 2] +          // A[0][2] +
+             filtered_affine_model[id*12 + 3];                       // A[0][3];
+         double out2 =
+             input[id + 2*size] * filtered_affine_model[id*12 + 4] + // A[1][0] +
+             input[id + size] * filtered_affine_model[id*12 + 5] +   // A[1][1] +
+             input[id] * filtered_affine_model[id*12 + 6] +          // A[1][2] +
+             filtered_affine_model[id*12 + 7];                       // A[1][3];
+         double out3 =
+             input[id + 2*size] * filtered_affine_model[id*12 + 8] + // A[2][0] +
+             input[id + size] * filtered_affine_model[id*12 + 9] +   // A[2][1] +
+             input[id] * filtered_affine_model[id*12 + 10] +         // A[2][2] +
+             filtered_affine_model[id*12 + 11];                      // A[2][3];
+
+         filtered_best_output[id] = out1;
+         filtered_best_output[id + size] = out2;
+         filtered_best_output[id + 2*size] = out3;
+     }
+     return ;
+ }
+ '''
+
+ import torch
+ import numpy as np
+ from PIL import Image
+ from cupy.cuda import function
+ from pynvrtc.compiler import Program
+ from collections import namedtuple
+
+
+ def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
+     # Compile the CUDA source above with NVRTC and load it through CuPy.
+     program = Program(src, 'best_local_affine_kernel.cu')
+     ptx = program.compile(['-I/usr/local/cuda/include'])
+     m = function.Module()
+     m.load(bytes(ptx.encode()))
+
+     _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel')
+     _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel')
+     _best_local_affine_kernel = m.get_function('best_local_affine_kernel')
+     Stream = namedtuple('Stream', ['ptr'])
+     s = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+
+     filter_radius = f_r
+     sigma1 = filter_radius / 3
+     sigma2 = f_e
+     radius = (patch - 1) / 2
+
+     filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
+     affine_model = torch.zeros((h * w, 12)).cuda()
+     filtered_affine_model = torch.zeros((h * w, 12)).cuda()
+
+     input_ = torch.from_numpy(input_cpu).cuda()
+     output_ = torch.from_numpy(output_cpu).cuda()
+     # 1) Fit a per-pixel local affine model from the guide to the stylized image.
+     _best_local_affine_kernel(
+         grid=(int((h * w) / 256 + 1), 1),
+         block=(256, 1, 1),
+         args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
+               np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s
+     )
+
+     # 2) Smooth the per-pixel affine models with a joint bilateral filter.
+     _bilateral_smooth_kernel(
+         grid=(int((h * w) / 256 + 1), 1),
+         block=(256, 1, 1),
+         args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(),
+               np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s
+     )
+
+     # 3) Re-apply the filtered models to reconstruct the output image.
+     _reconstruction_best_kernel(
+         grid=(int((h * w) / 256 + 1), 1),
+         block=(256, 1, 1),
+         args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(),
+               np.int32(h), np.int32(w)], stream=s
+     )
+     numpy_filtered_best_output = filtered_best_output.cpu().numpy()
+     return numpy_filtered_best_output
+
+
+ def smooth_filter(initImg, contentImg, f_radius=15, f_edge=1e-1):
+     '''
+     :param initImg: intermediate output. Either image path or PIL Image
+     :param contentImg: content image output. Either path or PIL Image
+     :return: stylized output image. PIL Image
+     '''
+     if type(initImg) == str:
+         initImg = Image.open(initImg).convert("RGB")
+     best_image_bgr = np.array(initImg, dtype=np.float32)
+     # np arrays are (height, width, channels); the bW/bH names are swapped,
+     # but the resize((bH, bW)) call below is consistent with PIL's (width, height).
+     bW, bH, bC = best_image_bgr.shape
+     best_image_bgr = best_image_bgr[:, :, ::-1]
+     best_image_bgr = best_image_bgr.transpose((2, 0, 1))
+
+     if type(contentImg) == str:
+         contentImg = Image.open(contentImg).convert("RGB")
+     content_input = contentImg.resize((bH, bW))
+     content_input = np.array(content_input, dtype=np.float32)
+     content_input = content_input[:, :, ::-1]
+     content_input = content_input.transpose((2, 0, 1))
+     input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
+     _, H, W = np.shape(input_)
+     output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.
+     best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
+     best_ = best_.transpose(1, 2, 0)
+     result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
+     return result
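Per pixel, `best_local_affine_kernel` fits a regularized least-squares affine map from the guide image to the stylized output over a local window; summarizing the accumulators above ($M$ holds the window's guide pixels in homogeneous RGB coordinates, $S$ the corresponding stylized pixels, and the $10^{-3}$ diagonal initialization acts as a ridge term $\lambda I$):

$$
A^\top = (M^\top M + \lambda I)^{-1} M^\top S,
$$

after which `bilateral_smooth_kernel` averages the per-pixel $3 \times 4$ models with joint spatial/color weights and `reconstruction_best_kernel` applies the filtered models back to the guide pixels.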