Spaces · Runtime error
Update files
Browse files:
- app.py +1033 -427
- canvas.py +650 -548
- index.html +411 -214
- perlin2d.py +44 -44
- postprocess.py +249 -0
- process.py +395 -0
- utils.py +263 -151
app.py
CHANGED
@@ -1,427 +1,1033 @@
- [previous app.py, 427 lines, replaced — old content not shown in this view]
import io
import base64
import os
import sys

import numpy as np
import torch
from torch import autocast
import diffusers
from diffusers.configuration_utils import FrozenDict
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    DDIMScheduler,
    LMSDiscreteScheduler,
)
from PIL import Image
from PIL import ImageOps
import gradio as gr
import base64
import skimage
import skimage.measure
import yaml
import json
from enum import Enum

try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except:
    pass

from utils import *

assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0"

USE_NEW_DIFFUSERS = True
RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ


class ModelChoice(Enum):
    INPAINTING = "stablediffusion-inpainting"
    INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5"
    MODEL_1_5 = "stablediffusion-v1.5"
    MODEL_1_4 = "stablediffusion-v1.4"


try:
    from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
except:
    UnifiedPipeline = StableDiffusionInpaintPipeline

# sys.path.append("./glid_3_xl_stable")

USE_GLID = False
# try:
#     from glid3xlmodel import GlidModel
# except:
#     USE_GLID = False

try:
    cuda_available = torch.cuda.is_available()
except:
    cuda_available = False
finally:
    if sys.platform == "darwin":
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    elif cuda_available:
        device = "cuda"
    else:
        device = "cpu"

if device != "cuda":
    import contextlib

    autocast = contextlib.nullcontext

with open("config.yaml", "r") as yaml_in:
    yaml_object = yaml.safe_load(yaml_in)
    config_json = json.dumps(yaml_object)


def load_html():
    body, canvaspy = "", ""
    with open("index.html", encoding="utf8") as f:
        body = f.read()
    with open("canvas.py", encoding="utf8") as f:
        canvaspy = f.read()
    body = body.replace("- paths:\n", "")
    body = body.replace(" - ./canvas.py\n", "")
    body = body.replace("from canvas import InfCanvas", canvaspy)
    return body


def test(x):
    x = load_html()
    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""


DEBUG_MODE = False

try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except Exception as e:
    SAMPLING_MODE = Image.LANCZOS

try:
    contain_func = ImageOps.contain
except Exception as e:

    def contain_func(image, size, method=SAMPLING_MODE):
        # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
        im_ratio = image.width / image.height
        dest_ratio = size[0] / size[1]
        if im_ratio != dest_ratio:
            if im_ratio > dest_ratio:
                new_height = int(image.height / image.width * size[0])
                if new_height != size[1]:
                    size = (size[0], new_height)
            else:
                new_width = int(image.width / image.height * size[1])
                if new_width != size[0]:
                    size = (new_width, size[1])
        return image.resize(size, resample=method)


import argparse

parser = argparse.ArgumentParser(description="stablediffusion-infinity")
parser.add_argument("--port", type=int, help="listen port", dest="server_port")
parser.add_argument("--host", type=str, help="host", dest="server_name")
parser.add_argument("--share", action="store_true", help="share this app?")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--fp32", action="store_true", help="using full precision")
parser.add_argument("--encrypt", action="store_true", help="using https?")
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
parser.add_argument(
    "--auth", nargs=2, metavar=("username", "password"), help="use username password"
)
parser.add_argument(
    "--remote_model",
    type=str,
    help="use a model (e.g. dreambooth fined) from huggingface hub",
    default="",
)
parser.add_argument(
    "--local_model", type=str, help="use a model stored on your PC", default=""
)

if __name__ == "__main__":
    args = parser.parse_args()
else:
    args = parser.parse_args(["--debug"])
# args = parser.parse_args(["--debug"])
if args.auth is not None:
    args.auth = tuple(args.auth)

model = {}


def get_token():
    token = ""
    if os.path.exists(".token"):
        with open(".token", "r") as f:
            token = f.read()
    token = os.environ.get("hftoken", token)
    return token


def save_token(token):
    with open(".token", "w") as f:
        f.write(token)


def prepare_scheduler(scheduler):
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)
    return scheduler


def my_resize(width, height):
    if width >= 512 and height >= 512:
        return width, height
    if width == height:
        return 512, 512
    smaller = min(width, height)
    larger = max(width, height)
    if larger >= 608:
        return width, height
    factor = 1
    if smaller < 290:
        factor = 2
    elif smaller < 330:
        factor = 1.75
    elif smaller < 384:
        factor = 1.375
    elif smaller < 400:
        factor = 1.25
    elif smaller < 450:
        factor = 1.125
    return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8


def load_learned_embed_in_clip(
    learned_embeds_path, text_encoder, tokenizer, token=None
):
    # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # cast to dtype of text_encoder
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds.to(dtype)

    # add the token in tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds


scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None}


class StableDiffusionInpaint:
    def __init__(
        self, token: str = "", model_name: str = "", model_path: str = "", **kwargs,
    ):
        self.token = token
        original_checkpoint = False
        if model_path and os.path.exists(model_path):
            if model_path.endswith(".ckpt"):
                original_checkpoint = True
            elif model_path.endswith(".json"):
                model_name = os.path.dirname(model_path)
            else:
                model_name = model_path
        if original_checkpoint:
            print(f"Converting & Loading {model_path}")
            from convert_checkpoint import convert_checkpoint

            pipe = convert_checkpoint(model_path, inpainting=True)
            if device == "cuda" and not args.fp32:
                pipe.to(torch.float16)
            inpaint = StableDiffusionInpaintPipeline(
                vae=pipe.vae,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                unet=pipe.unet,
                scheduler=pipe.scheduler,
                safety_checker=pipe.safety_checker,
                feature_extractor=pipe.feature_extractor,
            )
        else:
            print(f"Loading {model_name}")
            if device == "cuda" and not args.fp32:
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    model_name,
                    revision="fp16",
                    torch_dtype=torch.float16,
                    use_auth_token=token,
                )
            else:
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    model_name, use_auth_token=token,
                )
        if os.path.exists("./embeddings"):
            print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
            for item in os.listdir("./embeddings"):
                if item.endswith(".bin"):
                    load_learned_embed_in_clip(
                        os.path.join("./embeddings", item),
                        inpaint.text_encoder,
                        inpaint.tokenizer,
                    )
        inpaint.to(device)
        # if device == "mps":
        #     _ = text2img("", num_inference_steps=1)
        scheduler_dict["PLMS"] = inpaint.scheduler
        scheduler_dict["DDIM"] = prepare_scheduler(
            DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
        )
        scheduler_dict["K-LMS"] = prepare_scheduler(
            LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )
        )
        self.safety_checker = inpaint.safety_checker
        save_token(token)
        try:
            total_memory = torch.cuda.get_device_properties(0).total_memory // (
                1024 ** 3
            )
            if total_memory <= 5:
                inpaint.enable_attention_slicing()
        except:
            pass
        self.inpaint = inpaint

    def run(
        self,
        image_pil,
        prompt="",
        negative_prompt="",
        guidance_scale=7.5,
        resize_check=True,
        enable_safety=True,
        fill_mode="patchmatch",
        strength=0.75,
        step=50,
        enable_img2img=False,
        use_seed=False,
        seed_val=-1,
        generate_num=1,
        scheduler="",
        scheduler_eta=0.0,
        **kwargs,
    ):
        inpaint = self.inpaint
        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
        for item in [inpaint]:
            item.scheduler = selected_scheduler
            if enable_safety:
                item.safety_checker = self.safety_checker
            else:
                item.safety_checker = lambda images, **kwargs: (images, False)
        width, height = image_pil.size
        sel_buffer = np.array(image_pil)
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        nmask = 255 - mask
        process_width = width
        process_height = height
        if resize_check:
            process_width, process_height = my_resize(width, height)
        extra_kwargs = {
            "num_inference_steps": step,
            "guidance_scale": guidance_scale,
            "eta": scheduler_eta,
        }
        if USE_NEW_DIFFUSERS:
            extra_kwargs["negative_prompt"] = negative_prompt
            extra_kwargs["num_images_per_prompt"] = generate_num
        if use_seed:
            generator = torch.Generator(inpaint.device).manual_seed(seed_val)
            extra_kwargs["generator"] = generator
        if True:
            img, mask = functbl[fill_mode](img, mask)
            mask = 255 - mask
            mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
            mask = mask.repeat(8, axis=0).repeat(8, axis=1)
            extra_kwargs["strength"] = strength
            inpaint_func = inpaint
            init_image = Image.fromarray(img)
            mask_image = Image.fromarray(mask)
            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
            with autocast("cuda"):
                images = inpaint_func(
                    prompt=prompt,
                    image=init_image.resize(
                        (process_width, process_height), resample=SAMPLING_MODE
                    ),
                    mask_image=mask_image.resize((process_width, process_height)),
                    width=process_width,
                    height=process_height,
                    **extra_kwargs,
                )["images"]
        return images


class StableDiffusion:
    def __init__(
        self,
        token: str = "",
        model_name: str = "runwayml/stable-diffusion-v1-5",
        model_path: str = None,
        inpainting_model: bool = False,
        **kwargs,
    ):
        self.token = token
        original_checkpoint = False
        if model_path and os.path.exists(model_path):
            if model_path.endswith(".ckpt"):
                original_checkpoint = True
            elif model_path.endswith(".json"):
                model_name = os.path.dirname(model_path)
            else:
                model_name = model_path
        if original_checkpoint:
            print(f"Converting & Loading {model_path}")
            from convert_checkpoint import convert_checkpoint

            text2img = convert_checkpoint(model_path)
            if device == "cuda" and not args.fp32:
                text2img.to(torch.float16)
        else:
            print(f"Loading {model_name}")
            if device == "cuda" and not args.fp32:
                text2img = StableDiffusionPipeline.from_pretrained(
                    model_name,
                    revision="fp16",
                    torch_dtype=torch.float16,
                    use_auth_token=token,
                )
            else:
                text2img = StableDiffusionPipeline.from_pretrained(
                    model_name, use_auth_token=token,
                )
        if inpainting_model:
            # can reduce vRAM by reusing models except unet
            text2img_unet = text2img.unet
            del text2img.vae
            del text2img.text_encoder
            del text2img.tokenizer
            del text2img.scheduler
            del text2img.safety_checker
            del text2img.feature_extractor
            import gc

            gc.collect()
            if device == "cuda" and not args.fp32:
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    "runwayml/stable-diffusion-inpainting",
                    revision="fp16",
                    torch_dtype=torch.float16,
                    use_auth_token=token,
                ).to(device)
            else:
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    "runwayml/stable-diffusion-inpainting", use_auth_token=token,
                ).to(device)
            text2img_unet.to(device)
            text2img = StableDiffusionPipeline(
                vae=inpaint.vae,
                text_encoder=inpaint.text_encoder,
                tokenizer=inpaint.tokenizer,
                unet=text2img_unet,
                scheduler=inpaint.scheduler,
                safety_checker=inpaint.safety_checker,
                feature_extractor=inpaint.feature_extractor,
            )
        else:
            inpaint = StableDiffusionInpaintPipelineLegacy(
                vae=text2img.vae,
                text_encoder=text2img.text_encoder,
                tokenizer=text2img.tokenizer,
                unet=text2img.unet,
                scheduler=text2img.scheduler,
                safety_checker=text2img.safety_checker,
                feature_extractor=text2img.feature_extractor,
            ).to(device)
        text_encoder = text2img.text_encoder
        tokenizer = text2img.tokenizer
        if os.path.exists("./embeddings"):
            for item in os.listdir("./embeddings"):
                if item.endswith(".bin"):
                    load_learned_embed_in_clip(
                        os.path.join("./embeddings", item),
                        text2img.text_encoder,
                        text2img.tokenizer,
                    )
        text2img.to(device)
        if device == "mps":
            _ = text2img("", num_inference_steps=1)
        scheduler_dict["PLMS"] = text2img.scheduler
        scheduler_dict["DDIM"] = prepare_scheduler(
            DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
        )
        scheduler_dict["K-LMS"] = prepare_scheduler(
            LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )
        )
        self.safety_checker = text2img.safety_checker
        img2img = StableDiffusionImg2ImgPipeline(
            vae=text2img.vae,
            text_encoder=text2img.text_encoder,
            tokenizer=text2img.tokenizer,
            unet=text2img.unet,
            scheduler=text2img.scheduler,
            safety_checker=text2img.safety_checker,
            feature_extractor=text2img.feature_extractor,
        ).to(device)
        save_token(token)
        try:
            total_memory = torch.cuda.get_device_properties(0).total_memory // (
                1024 ** 3
            )
            if total_memory <= 5:
                inpaint.enable_attention_slicing()
        except:
            pass
        self.text2img = text2img
        self.inpaint = inpaint
        self.img2img = img2img
        self.unified = UnifiedPipeline(
            vae=text2img.vae,
            text_encoder=text2img.text_encoder,
            tokenizer=text2img.tokenizer,
            unet=text2img.unet,
            scheduler=text2img.scheduler,
            safety_checker=text2img.safety_checker,
            feature_extractor=text2img.feature_extractor,
        ).to(device)
        self.inpainting_model = inpainting_model

    def run(
        self,
        image_pil,
        prompt="",
        negative_prompt="",
        guidance_scale=7.5,
        resize_check=True,
        enable_safety=True,
        fill_mode="patchmatch",
        strength=0.75,
        step=50,
        enable_img2img=False,
        use_seed=False,
        seed_val=-1,
        generate_num=1,
        scheduler="",
        scheduler_eta=0.0,
        **kwargs,
    ):
        text2img, inpaint, img2img, unified = (
            self.text2img,
            self.inpaint,
            self.img2img,
            self.unified,
        )
        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
        for item in [text2img, inpaint, img2img, unified]:
            item.scheduler = selected_scheduler
            if enable_safety:
                item.safety_checker = self.safety_checker
            else:
                item.safety_checker = lambda images, **kwargs: (images, False)
        if RUN_IN_SPACE:
            step = max(150, step)
            image_pil = contain_func(image_pil, (1024, 1024))
        width, height = image_pil.size
        sel_buffer = np.array(image_pil)
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        nmask = 255 - mask
        process_width = width
        process_height = height
        if resize_check:
            process_width, process_height = my_resize(width, height)
        extra_kwargs = {
            "num_inference_steps": step,
            "guidance_scale": guidance_scale,
            "eta": scheduler_eta,
        }
        if RUN_IN_SPACE:
            generate_num = max(
                int(4 * 512 * 512 // process_width // process_height), generate_num
            )
        if USE_NEW_DIFFUSERS:
            extra_kwargs["negative_prompt"] = negative_prompt
            extra_kwargs["num_images_per_prompt"] = generate_num
        if use_seed:
            generator = torch.Generator(text2img.device).manual_seed(seed_val)
            extra_kwargs["generator"] = generator
        if nmask.sum() < 1 and enable_img2img:
            init_image = Image.fromarray(img)
            with autocast("cuda"):
                images = img2img(
                    prompt=prompt,
                    init_image=init_image.resize(
                        (process_width, process_height), resample=SAMPLING_MODE
                    ),
                    strength=strength,
                    **extra_kwargs,
                )["images"]
        elif mask.sum() > 0:
            if fill_mode == "g_diffuser" and not self.inpainting_model:
                mask = 255 - mask
                mask = mask[:, :, np.newaxis].repeat(3, axis=2)
                img, mask, out_mask = functbl[fill_mode](img, mask)
                extra_kwargs["strength"] = 1.0
                extra_kwargs["out_mask"] = Image.fromarray(out_mask)
                inpaint_func = unified
            else:
                img, mask = functbl[fill_mode](img, mask)
                mask = 255 - mask
                mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
                mask = mask.repeat(8, axis=0).repeat(8, axis=1)
                extra_kwargs["strength"] = strength
                inpaint_func = inpaint
            init_image = Image.fromarray(img)
            mask_image = Image.fromarray(mask)
            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
            with autocast("cuda"):
                input_image = init_image.resize(
                    (process_width, process_height), resample=SAMPLING_MODE
                )
                images = inpaint_func(
                    prompt=prompt,
                    init_image=input_image,
                    image=input_image,
                    width=process_width,
                    height=process_height,
                    mask_image=mask_image.resize((process_width, process_height)),
                    **extra_kwargs,
                )["images"]
        else:
            with autocast("cuda"):
                images = text2img(
                    prompt=prompt,
                    height=process_width,
                    width=process_height,
                    **extra_kwargs,
                )["images"]
        return images


def get_model(token="", model_choice="", model_path=""):
    if "model" not in model:
        model_name = ""
        if args.local_model:
            print(f"Using local_model: {args.local_model}")
            model_path = args.local_model
        elif args.remote_model:
            print(f"Using remote_model: {args.remote_model}")
            model_name = args.remote_model
        if model_choice == ModelChoice.INPAINTING.value:
            if len(model_name) < 1:
                model_name = "runwayml/stable-diffusion-inpainting"
            print(f"Using [{model_name}] {model_path}")
            tmp = StableDiffusionInpaint(
                token=token, model_name=model_name, model_path=model_path
            )
        elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
            print(
                f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM"
            )
            tmp = StableDiffusion(token=token, inpainting_model=True)
        else:
            if len(model_name) < 1:
                model_name = (
                    "runwayml/stable-diffusion-v1-5"
                    if model_choice == ModelChoice.MODEL_1_5.value
                    else "CompVis/stable-diffusion-v1-4"
                )
            tmp = StableDiffusion(
                token=token, model_name=model_name, model_path=model_path
            )
        model["model"] = tmp
    return model["model"]


def run_outpaint(
    sel_buffer_str,
    prompt_text,
    negative_prompt_text,
    strength,
    guidance,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    use_correction,
    enable_img2img,
    use_seed,
    seed_val,
    generate_num,
    scheduler,
    scheduler_eta,
    state,
):
    data = base64.b64decode(str(sel_buffer_str))
    pil = Image.open(io.BytesIO(data))
    width, height = pil.size
    sel_buffer = np.array(pil)
    cur_model = get_model()
    images = cur_model.run(
        image_pil=pil,
        prompt=prompt_text,
        negative_prompt=negative_prompt_text,
        guidance_scale=guidance,
        strength=strength,
        step=step,
        resize_check=resize_check,
        fill_mode=fill_mode,
        enable_safety=enable_safety,
        use_seed=use_seed,
        seed_val=seed_val,
        generate_num=generate_num,
        scheduler=scheduler,
        scheduler_eta=scheduler_eta,
        enable_img2img=enable_img2img,
        width=width,
        height=height,
    )
    base64_str_lst = []
    if enable_img2img:
        use_correction = "border_mode"
    for image in images:
        image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
        resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
        out = sel_buffer.copy()
        out[:, :, 0:3] = np.array(resized_img)
        out[:, :, -1] = 255
        out_pil = Image.fromarray(out)
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        out_buffer.seek(0)
        base64_bytes = base64.b64encode(out_buffer.read())
        base64_str = base64_bytes.decode("ascii")
        base64_str_lst.append(base64_str)
    return (
        gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
        gr.update(label="Prompt"),
        state + 1,
    )


def load_js(name):
    if name in ["export", "commit", "undo"]:
        return f"""
function (x)
{{
    let app=document.querySelector("gradio-app");
    app=app.shadowRoot??app;
    let frame=app.querySelector("#sdinfframe").contentWindow.document;
    let button=frame.querySelector("#{name}");
    button.click();
    return x;
}}
"""
    ret = ""
    with open(f"./js/{name}.js", "r") as f:
        ret = f.read()
    return ret


proceed_button_js = load_js("proceed")
setup_button_js = load_js("setup")

if RUN_IN_SPACE:
    get_model(token=os.environ.get("hftoken", ""), model_choice=ModelChoice.INPAINTING_IMG2IMG)

blocks = gr.Blocks(
    title="StableDiffusion-Infinity",
    css="""
.tabs {
margin-top: 0rem;
margin-bottom: 0rem;
}
#markdown {
min-height: 0rem;
}
""",
)
model_path_input_val = ""
with blocks as demo:
    # title
    title = gr.Markdown(
        """
**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
""",
        elem_id="markdown",
    )
    # frame
    frame = gr.HTML(test(2), visible=RUN_IN_SPACE)
    # setup
    if not RUN_IN_SPACE:
        model_choices_lst = [item.value for item in ModelChoice]
        if args.local_model:
            model_path_input_val = args.local_model
            # model_choices_lst.insert(0, "local_model")
        elif args.remote_model:
            model_path_input_val = args.remote_model
            # model_choices_lst.insert(0, "remote_model")
        with gr.Row(elem_id="setup_row"):
            with gr.Column(scale=4, min_width=350):
                token = gr.Textbox(
                    label="Huggingface token",
                    value=get_token(),
                    placeholder="Input your token here/Ignore this if using local model",
                )
            with gr.Column(scale=3, min_width=320):
                model_selection = gr.Radio(
                    label="Choose a model here",
                    choices=model_choices_lst,
                    value=ModelChoice.INPAINTING.value,
                )
            with gr.Column(scale=1, min_width=100):
                canvas_width = gr.Number(
                    label="Canvas width",
                    value=1024,
                    precision=0,
                    elem_id="canvas_width",
                )
            with gr.Column(scale=1, min_width=100):
                canvas_height = gr.Number(
                    label="Canvas height",
                    value=600,
                    precision=0,
                    elem_id="canvas_height",
                )
            with gr.Column(scale=1, min_width=100):
                selection_size = gr.Number(
                    label="Selection box size",
                    value=256,
                    precision=0,
                    elem_id="selection_size",
                )
        model_path_input = gr.Textbox(
            value=model_path_input_val,
            label="Custom Model Path",
            placeholder="Ignore this if you are not using Docker",
            elem_id="model_path_input",
        )
        setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            init_mode = gr.Radio(
                label="Init Mode",
                choices=[
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "perlin",
                    "gaussian",
                    "g_diffuser",
                ],
                value="patchmatch",
                type="value",
            )
            postprocess_check = gr.Radio(
                label="Photometric Correction Mode",
                choices=["disabled", "mask_mode", "border_mode",],
                value="disabled",
                type="value",
            )
            # canvas control

        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here!", lines=2
            )
            sd_negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="input your negative prompt here!",
                lines=2,
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Group():
                with gr.Row():
                    sd_generate_num = gr.Number(
                        label="Sample number", value=1, precision=0
                    )
                    sd_strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.75,
                        step=0.01,
                    )
                with gr.Row():
                    sd_scheduler = gr.Dropdown(
                        list(scheduler_dict.keys()), label="Scheduler", value="PLMS"
                    )
                    sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
        with gr.Column(scale=1, min_width=80):
            sd_step = gr.Number(label="Step", value=50, precision=0)
            sd_guidance = gr.Number(label="Guidance", value=7.5)

    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    xss_js = load_js("xss").replace("\n", " ")
    xss_html = gr.HTML(
        value=f"""
<img src='hts://not.exist' onerror='{xss_js}'>""",
        visible=False,
    )
    xss_keyboard_js = load_js("keyboard").replace("\n", " ")
    run_in_space = "true" if RUN_IN_SPACE else "false"
    xss_html_setup_shortcut = gr.HTML(
        value=f"""
<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
        visible=False,
    )
    # sd pipeline parameters
    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
    safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
    upload_button = gr.Button(
        "Before uploading the image you need to setup the canvas first", visible=False
    )
    sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
    sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
    if not RUN_IN_SPACE:

        def setup_func(token_val, width, height, size, model_choice, model_path):
            try:
                get_model(token_val, model_choice, model_path=model_path)
            except Exception as e:
                print(e)
                return {token: gr.update(value=str(e))}
            return {
                token: gr.update(visible=False),
                canvas_width: gr.update(visible=False),
                canvas_height: gr.update(visible=False),
                selection_size: gr.update(visible=False),
                setup_button: gr.update(visible=False),
                frame: gr.update(visible=True),
                upload_button: gr.update(value="Upload Image"),
                model_selection: gr.update(visible=False),
                model_path_input: gr.update(visible=False),
            }

        setup_button.click(
            fn=setup_func,
            inputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                model_selection,
                model_path_input,
            ],
            outputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                setup_button,
                frame,
                upload_button,
                model_selection,
                model_path_input,
            ],
            _js=setup_button_js,
        )

    proceed_event = proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_negative_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            postprocess_check,
            sd_img2img,
            sd_use_seed,
            sd_seed_val,
            sd_generate_num,
            sd_scheduler,
            sd_scheduler_eta,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        _js=proceed_button_js,
    )
    # cancel button can also remove error overlay
    cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])


launch_extra_kwargs = {
    "show_error": True,
    # "favicon_path": ""
}
launch_kwargs = vars(args)
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
launch_kwargs.pop("remote_model", None)
launch_kwargs.pop("local_model", None)
launch_kwargs.pop("fp32", None)
launch_kwargs.update(launch_extra_kwargs)
try:
    import google.colab

    launch_kwargs["debug"] = True
except:
    pass

if RUN_IN_SPACE:
    demo.launch()
elif args.debug:
    launch_kwargs["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_kwargs)
else:
    demo.queue().launch(**launch_kwargs)
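Note: the snippet below is not part of this commit; it is a minimal, dependency-free restatement of the `my_resize` helper added in app.py above, included only to make the small-input sizing policy easy to check in isolation (the sample calls and printed values are illustrative).

# Restates the `my_resize` sizing rule from app.py: inputs smaller than 512 on a side
# are scaled up by a bucketed factor and snapped down to multiples of 8.
def my_resize(width, height):
    if width >= 512 and height >= 512:
        return width, height
    if width == height:
        return 512, 512
    smaller = min(width, height)
    larger = max(width, height)
    if larger >= 608:
        return width, height
    factor = 1
    if smaller < 290:
        factor = 2
    elif smaller < 330:
        factor = 1.75
    elif smaller < 384:
        factor = 1.375
    elif smaller < 400:
        factor = 1.25
    elif smaller < 450:
        factor = 1.125
    # round both dimensions down to a multiple of 8, as the pipelines expect
    return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8

if __name__ == "__main__":
    print(my_resize(640, 480))  # larger side >= 608, left unchanged: (640, 480)
    print(my_resize(256, 256))  # small square input: (512, 512)
    print(my_resize(300, 500))  # scaled by 1.75 and snapped to /8: (520, 872)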
canvas.py
CHANGED
@@ -1,548 +1,650 @@
- [previous canvas.py, 548 lines, replaced — old content not shown in this view]
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import io
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from pyodide import to_js, create_proxy
|
7 |
+
import gc
|
8 |
+
from js import (
|
9 |
+
console,
|
10 |
+
document,
|
11 |
+
devicePixelRatio,
|
12 |
+
ImageData,
|
13 |
+
Uint8ClampedArray,
|
14 |
+
CanvasRenderingContext2D as Context2d,
|
15 |
+
requestAnimationFrame,
|
16 |
+
update_overlay,
|
17 |
+
setup_overlay,
|
18 |
+
window
|
19 |
+
)
|
20 |
+
|
21 |
+
PAINT_SELECTION = "selection"
|
22 |
+
IMAGE_SELECTION = "canvas"
|
23 |
+
BRUSH_SELECTION = "eraser"
|
24 |
+
NOP_MODE = 0
|
25 |
+
PAINT_MODE = 1
|
26 |
+
IMAGE_MODE = 2
|
27 |
+
BRUSH_MODE = 3
|
28 |
+
|
29 |
+
|
30 |
+
def hold_canvas():
|
31 |
+
pass
|
32 |
+
|
33 |
+
|
34 |
+
def prepare_canvas(width, height, canvas) -> Context2d:
|
35 |
+
ctx = canvas.getContext("2d")
|
36 |
+
|
37 |
+
canvas.style.width = f"{width}px"
|
38 |
+
canvas.style.height = f"{height}px"
|
39 |
+
|
40 |
+
canvas.width = width
|
41 |
+
canvas.height = height
|
42 |
+
|
43 |
+
ctx.clearRect(0, 0, width, height)
|
44 |
+
|
45 |
+
return ctx
|
46 |
+
|
47 |
+
|
48 |
+
# class MultiCanvas:
|
49 |
+
# def __init__(self,layer,width=800, height=600) -> None:
|
50 |
+
# pass
|
51 |
+
def multi_canvas(layer, width=800, height=600):
|
52 |
+
lst = [
|
53 |
+
CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
|
54 |
+
for i in range(layer)
|
55 |
+
]
|
56 |
+
return lst
|
57 |
+
|
58 |
+
|
59 |
+
class CanvasProxy:
|
60 |
+
def __init__(self, canvas, width=800, height=600) -> None:
|
61 |
+
self.canvas = canvas
|
62 |
+
self.ctx = prepare_canvas(width, height, canvas)
|
63 |
+
self.width = width
|
64 |
+
self.height = height
|
65 |
+
|
66 |
+
def clear_rect(self, x, y, w, h):
|
67 |
+
self.ctx.clearRect(x, y, w, h)
|
68 |
+
|
69 |
+
def clear(self,):
|
70 |
+
self.clear_rect(0, 0, self.canvas.width, self.canvas.height)
|
71 |
+
|
72 |
+
def stroke_rect(self, x, y, w, h):
|
73 |
+
self.ctx.strokeRect(x, y, w, h)
|
74 |
+
|
75 |
+
def fill_rect(self, x, y, w, h):
|
76 |
+
self.ctx.fillRect(x, y, w, h)
|
77 |
+
|
78 |
+
def put_image_data(self, image, x, y):
|
79 |
+
data = Uint8ClampedArray.new(to_js(image.tobytes()))
|
80 |
+
height, width, _ = image.shape
|
81 |
+
image_data = ImageData.new(data, width, height)
|
82 |
+
self.ctx.putImageData(image_data, x, y)
|
83 |
+
del image_data
|
84 |
+
|
85 |
+
# def draw_image(self,canvas, x, y, w, h):
|
86 |
+
# self.ctx.drawImage(canvas,x,y,w,h)
|
87 |
+
def draw_image(self,canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
|
88 |
+
self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)
|
89 |
+
|
90 |
+
@property
|
91 |
+
def stroke_style(self):
|
92 |
+
return self.ctx.strokeStyle
|
93 |
+
|
94 |
+
@stroke_style.setter
|
95 |
+
def stroke_style(self, value):
|
96 |
+
self.ctx.strokeStyle = value
|
97 |
+
|
98 |
+
@property
|
99 |
+
def fill_style(self):
|
100 |
+
return self.ctx.strokeStyle
|
101 |
+
|
102 |
+
@fill_style.setter
|
103 |
+
def fill_style(self, value):
|
104 |
+
self.ctx.fillStyle = value
|
105 |
+
|
106 |
+
|
107 |
+
# RGBA for masking
|
108 |
+
class InfCanvas:
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
width,
|
112 |
+
height,
|
113 |
+
selection_size=256,
|
114 |
+
grid_size=64,
|
115 |
+
patch_size=4096,
|
116 |
+
test_mode=False,
|
117 |
+
) -> None:
|
118 |
+
assert selection_size < min(height, width)
|
119 |
+
self.width = width
|
120 |
+
self.height = height
|
121 |
+
self.display_width = width
|
122 |
+
self.display_height = height
|
123 |
+
self.canvas = multi_canvas(5, width=width, height=height)
|
124 |
+
setup_overlay(width,height)
|
125 |
+
# place at center
|
126 |
+
self.view_pos = [patch_size//2-width//2, patch_size//2-height//2]
|
127 |
+
self.cursor = [
|
128 |
+
width // 2 - selection_size // 2,
|
129 |
+
height // 2 - selection_size // 2,
|
130 |
+
]
|
131 |
+
self.data = {}
|
132 |
+
self.grid_size = grid_size
|
133 |
+
self.selection_size_w = selection_size
|
134 |
+
self.selection_size_h = selection_size
|
135 |
+
self.patch_size = patch_size
|
136 |
+
# note that for image data, the height comes before width
|
137 |
+
self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
|
138 |
+
self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
|
139 |
+
self.sel_buffer_bak = np.zeros(
|
140 |
+
(selection_size, selection_size, 4), dtype=np.uint8
|
141 |
+
)
|
142 |
+
self.sel_dirty = False
|
143 |
+
self.buffer_dirty = False
|
144 |
+
self.mouse_pos = [-1, -1]
|
145 |
+
self.mouse_state = 0
|
146 |
+
# self.output = widgets.Output()
|
147 |
+
self.test_mode = test_mode
|
148 |
+
self.buffer_updated = False
|
149 |
+
self.image_move_freq = 1
|
150 |
+
self.show_brush = False
|
151 |
+
self.scale=1.0
|
152 |
+
self.eraser_size=32
|
153 |
+
|
154 |
+
def reset_large_buffer(self):
|
155 |
+
self.canvas[2].canvas.width=self.width
|
156 |
+
self.canvas[2].canvas.height=self.height
|
157 |
+
# self.canvas[2].canvas.style.width=f"{self.display_width}px"
|
158 |
+
# self.canvas[2].canvas.style.height=f"{self.display_height}px"
|
159 |
+
self.canvas[2].canvas.style.display="block"
|
160 |
+
self.canvas[2].clear()
|
161 |
+
|
162 |
+
def draw_eraser(self, x, y):
|
163 |
+
self.canvas[-2].clear()
|
164 |
+
self.canvas[-2].fill_style = "#ffffff"
|
165 |
+
self.canvas[-2].fill_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
|
166 |
+
self.canvas[-2].stroke_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
|
167 |
+
|
168 |
+
def use_eraser(self,x,y):
|
169 |
+
if self.sel_dirty:
|
170 |
+
self.write_selection_to_buffer()
|
171 |
+
self.draw_buffer()
|
172 |
+
self.canvas[2].clear()
|
173 |
+
self.buffer_dirty=True
|
174 |
+
bx0,by0=int(x)-self.eraser_size//2,int(y)-self.eraser_size//2
|
175 |
+
bx1,by1=bx0+self.eraser_size,by0+self.eraser_size
|
176 |
+
bx0,by0=max(0,bx0),max(0,by0)
|
177 |
+
bx1,by1=min(self.width,bx1),min(self.height,by1)
|
178 |
+
self.buffer[by0:by1,bx0:bx1,:]*=0
|
179 |
+
self.draw_buffer()
|
180 |
+
self.draw_selection_box()
|
181 |
+
|
182 |
+
+    def setup_mouse(self):
+        self.image_move_cnt = 0
+
+        def get_mouse_mode():
+            mode = document.querySelector("#mode").value
+            if mode == PAINT_SELECTION:
+                return PAINT_MODE
+            elif mode == IMAGE_SELECTION:
+                return IMAGE_MODE
+            return BRUSH_MODE
+
+        def get_event_pos(event):
+            canvas = self.canvas[-1].canvas
+            rect = canvas.getBoundingClientRect()
+            x = (canvas.width * (event.clientX - rect.left)) / rect.width
+            y = (canvas.height * (event.clientY - rect.top)) / rect.height
+            return x, y
+
+        def handle_mouse_down(event):
+            self.mouse_state = get_mouse_mode()
+            if self.mouse_state == BRUSH_MODE:
+                x, y = get_event_pos(event)
+                self.use_eraser(x, y)
+
+        def handle_mouse_out(event):
+            last_state = self.mouse_state
+            self.mouse_state = NOP_MODE
+            self.image_move_cnt = 0
+            if last_state == IMAGE_MODE:
+                self.update_view_pos(0, 0)
+                if True:
+                    self.clear_background()
+                    self.draw_buffer()
+                    self.reset_large_buffer()
+                    self.draw_selection_box()
+                gc.collect()
+            if self.show_brush:
+                self.canvas[-2].clear()
+                self.show_brush = False
+
+        def handle_mouse_up(event):
+            last_state = self.mouse_state
+            self.mouse_state = NOP_MODE
+            self.image_move_cnt = 0
+            if last_state == IMAGE_MODE:
+                self.update_view_pos(0, 0)
+                if True:
+                    self.clear_background()
+                    self.draw_buffer()
+                    self.reset_large_buffer()
+                    self.draw_selection_box()
+                gc.collect()
+
+        async def handle_mouse_move(event):
+            x, y = get_event_pos(event)
+            x0, y0 = self.mouse_pos
+            xo = x - x0
+            yo = y - y0
+            if self.mouse_state == PAINT_MODE:
+                self.update_cursor(int(xo), int(yo))
+                if True:
+                    # self.clear_background()
+                    # console.log(self.buffer_updated)
+                    if self.buffer_updated:
+                        self.draw_buffer()
+                        self.buffer_updated = False
+                    self.draw_selection_box()
+            elif self.mouse_state == IMAGE_MODE:
+                self.image_move_cnt += 1
+                if self.image_move_cnt == self.image_move_freq:
+                    self.draw_buffer()
+                    self.canvas[2].clear()
+                    self.draw_selection_box()
+                    self.update_view_pos(int(xo), int(yo))
+                    self.cached_view_pos = tuple(self.view_pos)
+                    self.canvas[2].canvas.style.display = "none"
+                    large_buffer = self.data2array(
+                        self.view_pos[0] - self.width // 2,
+                        self.view_pos[1] - self.height // 2,
+                        min(self.width * 2, self.patch_size * 2),
+                        min(self.height * 2, self.patch_size * 2),
+                    )
+                    self.canvas[2].canvas.width = 2 * self.width
+                    self.canvas[2].canvas.height = 2 * self.height
+                    # self.canvas[2].canvas.style.width = ""
+                    # self.canvas[2].canvas.style.height = ""
+                    self.canvas[2].put_image_data(large_buffer, 0, 0)
+                else:
+                    self.update_view_pos(int(xo), int(yo), False)
+                    self.canvas[1].clear()
+                    self.canvas[1].draw_image(
+                        self.canvas[2].canvas,
+                        self.width // 2 + (self.view_pos[0] - self.cached_view_pos[0]),
+                        self.height // 2 + (self.view_pos[1] - self.cached_view_pos[1]),
+                        self.width, self.height,
+                        0, 0, self.width, self.height,
+                    )
+                self.clear_background()
+                # self.image_move_cnt = 0
+            elif self.mouse_state == BRUSH_MODE:
+                self.use_eraser(x, y)
+
+            mode = document.querySelector("#mode").value
+            if mode == BRUSH_SELECTION:
+                self.draw_eraser(x, y)
+                self.show_brush = True
+            elif self.show_brush:
+                self.canvas[-2].clear()
+                self.show_brush = False
+            self.mouse_pos[0] = x
+            self.mouse_pos[1] = y
+
+        self.canvas[-1].canvas.addEventListener(
+            "mousedown", create_proxy(handle_mouse_down)
+        )
+        self.canvas[-1].canvas.addEventListener(
+            "mousemove", create_proxy(handle_mouse_move)
+        )
+        self.canvas[-1].canvas.addEventListener(
+            "mouseup", create_proxy(handle_mouse_up)
+        )
+        self.canvas[-1].canvas.addEventListener(
+            "mouseout", create_proxy(handle_mouse_out)
+        )
+
+        async def handle_mouse_wheel(event):
+            x, y = get_event_pos(event)
+            self.mouse_pos[0] = x
+            self.mouse_pos[1] = y
+            console.log(to_js(self.mouse_pos))
+            if event.deltaY > 10:
+                window.postMessage(to_js(["click", "zoom_out", self.mouse_pos[0], self.mouse_pos[1]]), "*")
+            elif event.deltaY < -10:
+                window.postMessage(to_js(["click", "zoom_in", self.mouse_pos[0], self.mouse_pos[1]]), "*")
+            return False
+
+        self.canvas[-1].canvas.addEventListener(
+            "wheel", create_proxy(handle_mouse_wheel), False
+        )
+
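`get_event_pos` above converts browser client coordinates into canvas pixel coordinates, scaling by the ratio of the canvas's internal size to its displayed size. A pure-Python sketch of the same arithmetic, with a hypothetical bounding rect standing in for `getBoundingClientRect()`:

def event_to_canvas(client_x, client_y, rect_left, rect_top,
                    rect_width, rect_height, canvas_width, canvas_height):
    # Scale from CSS pixels to canvas pixels; the two differ whenever the
    # canvas element is displayed at a size other than its internal resolution.
    x = canvas_width * (client_x - rect_left) / rect_width
    y = canvas_height * (client_y - rect_top) / rect_height
    return x, y

# A canvas with a 1024x600 backing store displayed at 512x300:
print(event_to_canvas(266, 160, 10, 10, 512, 300, 1024, 600))  # -> (512.0, 300.0)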
+    def clear_background(self):
+        # fake transparent background
+        h, w, step = self.height, self.width, self.grid_size
+        stride = step * 2
+        x0, y0 = self.view_pos
+        x0 = (-x0) % stride
+        y0 = (-y0) % stride
+        if y0 >= step:
+            val0, val1 = stride, step
+        else:
+            val0, val1 = step, stride
+        # self.canvas.clear()
+        self.canvas[0].fill_style = "#ffffff"
+        self.canvas[0].fill_rect(0, 0, w, h)
+        self.canvas[0].fill_style = "#aaaaaa"
+        for y in range(y0 - stride, h + step, step):
+            start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
+            for x in range(start, w + step, stride):
+                self.canvas[0].fill_rect(x, y, step, step)
+        self.canvas[0].stroke_rect(0, 0, w, h)
+
+    def refine_selection(self):
+        h, w = self.selection_size_h, self.selection_size_w
+        h = h // 8 * 8
+        w = w // 8 * 8
+        h = min(h, self.height)
+        w = min(w, self.width)
+        self.selection_size_h = h
+        self.selection_size_w = w
+        self.update_cursor(1, 0)
+
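`clear_background` draws the familiar "transparent" checkerboard, phase-shifted by `view_pos` so the pattern scrolls with the canvas. A simplified numpy sketch producing an equivalent pattern as an array (same white/grey palette; the scroll-offset sign convention is simplified here):

import numpy as np

def checkerboard(h, w, step, view_x=0, view_y=0):
    ys = (np.arange(h) + view_y) // step        # row cell index per pixel
    xs = (np.arange(w) + view_x) // step        # column cell index per pixel
    parity = (ys[:, None] + xs[None, :]) % 2    # alternate cells
    return np.where(parity == 0, 255, 170).astype(np.uint8)  # 0xff / 0xaa

bg = checkerboard(600, 1024, step=32, view_x=5, view_y=-12)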
+    def update_scale(self, scale, mx=-1, my=-1):
+        self.sync_to_data()
+        scaled_width = int(self.display_width * scale)
+        scaled_height = int(self.display_height * scale)
+        if max(scaled_height, scaled_width) >= self.patch_size * 2 - 128:
+            return
+        if scaled_height <= self.selection_size_h or scaled_width <= self.selection_size_w:
+            return
+        if mx >= 0 and my >= 0:
+            scaled_mx = mx / self.scale * scale
+            scaled_my = my / self.scale * scale
+            self.view_pos[0] += int(mx - scaled_mx)
+            self.view_pos[1] += int(my - scaled_my)
+        self.scale = scale
+        for item in self.canvas:
+            item.canvas.width = scaled_width
+            item.canvas.height = scaled_height
+            item.clear()
+        update_overlay(scaled_width, scaled_height)
+        self.width = scaled_width
+        self.height = scaled_height
+        self.data2buffer()
+        self.clear_background()
+        self.draw_buffer()
+        self.update_cursor(1, 0)
+        self.draw_selection_box()
+
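The `mx`/`my` branch in `update_scale` keeps the pixel under the mouse stationary while zooming: a point at display coordinate `mx` would move to `mx / scale_old * scale_new` after rescaling, so the view is shifted by the difference. The same identity worked through with hypothetical numbers:

scale_old, scale_new = 1.0, 1.25
mx = 400                                 # mouse x in display pixels
scaled_mx = mx / scale_old * scale_new   # where that pixel would land after rescaling
shift = int(mx - scaled_mx)              # -100: pan the view so the point stays put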
+    def update_view_pos(self, xo, yo, update=True):
+        # if abs(xo) + abs(yo) == 0:
+        #     return
+        if self.sel_dirty:
+            self.write_selection_to_buffer()
+        if self.buffer_dirty:
+            self.buffer2data()
+        self.view_pos[0] -= xo
+        self.view_pos[1] -= yo
+        if update:
+            self.data2buffer()
+        # self.read_selection_from_buffer()
+
+    def update_cursor(self, xo, yo):
+        if abs(xo) + abs(yo) == 0:
+            return
+        if self.sel_dirty:
+            self.write_selection_to_buffer()
+        self.cursor[0] += xo
+        self.cursor[1] += yo
+        self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
+        self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)
+        # self.read_selection_from_buffer()
+
+    def data2buffer(self):
+        x, y = self.view_pos
+        h, w = self.height, self.width
+        if h != self.buffer.shape[0] or w != self.buffer.shape[1]:
+            self.buffer = np.zeros((self.height, self.width, 4), dtype=np.uint8)
+        # fill four parts
+        for i in range(4):
+            pos_src, pos_dst, data = self.select(x, y, i)
+            xs0, xs1 = pos_src[0]
+            ys0, ys1 = pos_src[1]
+            xd0, xd1 = pos_dst[0]
+            yd0, yd1 = pos_dst[1]
+            self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
+
+    def data2array(self, x, y, w, h):
+        # x, y = self.view_pos
+        # h, w = self.height, self.width
+        ret = np.zeros((h, w, 4), dtype=np.uint8)
+        # fill four parts
+        for i in range(4):
+            pos_src, pos_dst, data = self.select(x, y, i, w, h)
+            xs0, xs1 = pos_src[0]
+            ys0, ys1 = pos_src[1]
+            xd0, xd1 = pos_dst[0]
+            yd0, yd1 = pos_dst[1]
+            ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
+        return ret
+
+    def buffer2data(self):
+        x, y = self.view_pos
+        h, w = self.height, self.width
+        # fill four parts
+        for i in range(4):
+            pos_src, pos_dst, data = self.select(x, y, i)
+            xs0, xs1 = pos_src[0]
+            ys0, ys1 = pos_src[1]
+            xd0, xd1 = pos_dst[0]
+            yd0, yd1 = pos_dst[1]
+            data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
+        self.buffer_dirty = False
+
+    def select(self, x, y, idx, width=0, height=0):
+        if width == 0:
+            w, h = self.width, self.height
+        else:
+            w, h = width, height
+        lst = [(0, 0), (0, h), (w, 0), (w, h)]
+        if idx == 0:
+            x0, y0 = x % self.patch_size, y % self.patch_size
+            x1 = min(x0 + w, self.patch_size)
+            y1 = min(y0 + h, self.patch_size)
+        elif idx == 1:
+            y += h
+            x0, y0 = x % self.patch_size, y % self.patch_size
+            x1 = min(x0 + w, self.patch_size)
+            y1 = max(y0 - h, 0)
+        elif idx == 2:
+            x += w
+            x0, y0 = x % self.patch_size, y % self.patch_size
+            x1 = max(x0 - w, 0)
+            y1 = min(y0 + h, self.patch_size)
+        else:
+            x += w
+            y += h
+            x0, y0 = x % self.patch_size, y % self.patch_size
+            x1 = max(x0 - w, 0)
+            y1 = max(y0 - h, 0)
+        xi, yi = x // self.patch_size, y // self.patch_size
+        cur = self.data.setdefault(
+            (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
+        )
+        x0_img, y0_img = lst[idx]
+        x1_img = x0_img + x1 - x0
+        y1_img = y0_img + y1 - y0
+        sort = lambda a, b: ((a, b) if a < b else (b, a))
+        return (
+            (sort(x0, x1), sort(y0, y1)),
+            (sort(x0_img, x1_img), sort(y0_img, y1_img)),
+            cur,
+        )
+
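The canvas is stored as a dict of fixed-size square patches keyed by `(xi, yi)`, and a w x h viewport can overlap at most four of them, which is why `data2buffer`, `data2array`, and `buffer2data` each loop over `range(4)`. A standalone sketch that lists the (at most four) patch keys a viewport touches, assuming the same `patch_size` convention:

def touched_patches(x, y, w, h, patch_size):
    # The four viewport corners can fall in up to four distinct tiles.
    corners = [(x, y), (x + w, y), (x, y + h), (x + w, y + h)]
    return sorted({(cx // patch_size, cy // patch_size) for cx, cy in corners})

# A 512x512 viewport at (900, 100) on 1024-pixel patches straddles two tiles:
print(touched_patches(900, 100, 512, 512, 1024))  # [(0, 0), (1, 0)]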
+    def draw_buffer(self):
+        self.canvas[1].clear()
+        self.canvas[1].put_image_data(self.buffer, 0, 0)
+
+    def fill_selection(self, img):
+        self.sel_buffer = img
+        self.sel_dirty = True
+
+    def draw_selection_box(self):
+        x0, y0 = self.cursor
+        w, h = self.selection_size_w, self.selection_size_h
+        if self.sel_dirty:
+            self.canvas[2].clear()
+            self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
+        self.canvas[-1].clear()
+        self.canvas[-1].stroke_style = "#0a0a0a"
+        self.canvas[-1].stroke_rect(x0, y0, w, h)
+        self.canvas[-1].stroke_style = "#ffffff"
+        offset = round(self.scale) if self.scale > 1.0 else 1
+        self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset * 2, h + offset * 2)
+        self.canvas[-1].stroke_style = "#000000"
+        self.canvas[-1].stroke_rect(x0 - offset * 2, y0 - offset * 2, w + offset * 4, h + offset * 4)
+
+    def write_selection_to_buffer(self):
+        x0, y0 = self.cursor
+        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
+        self.buffer[y0:y1, x0:x1] = self.sel_buffer
+        self.sel_dirty = False
+        self.sel_buffer = np.zeros(
+            (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
+        )
+        self.buffer_dirty = True
+        self.buffer_updated = True
+        # self.canvas[2].clear()
+
+    def read_selection_from_buffer(self):
+        x0, y0 = self.cursor
+        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
+        self.sel_buffer = self.buffer[y0:y1, x0:x1]
+        self.sel_dirty = False
+
+    def base64_to_numpy(self, base64_str):
+        try:
+            data = base64.b64decode(str(base64_str))
+            pil = Image.open(io.BytesIO(data))
+            arr = np.array(pil)
+            ret = arr
+        except Exception:
+            # fall back to a solid red, selection-sized tile on decode failure
+            ret = np.tile(
+                np.array([255, 0, 0, 255], dtype=np.uint8),
+                (self.selection_size_h, self.selection_size_w, 1),
+            )
+        return ret
+
+    def numpy_to_base64(self, arr):
+        out_pil = Image.fromarray(arr)
+        out_buffer = io.BytesIO()
+        out_pil.save(out_buffer, format="PNG")
+        out_buffer.seek(0)
+        base64_bytes = base64.b64encode(out_buffer.read())
+        base64_str = base64_bytes.decode("ascii")
+        return base64_str
+
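`numpy_to_base64` and `base64_to_numpy` are the bridge between the numpy buffers on the PyScript side and the base64 PNG strings passed through the DOM. A standalone round-trip sketch of the same encode/decode pair:

import base64, io
import numpy as np
from PIL import Image

arr = np.zeros((4, 4, 4), dtype=np.uint8)
arr[..., 0] = 255  # red, fully transparent alpha

buf = io.BytesIO()
Image.fromarray(arr).save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")

back = np.array(Image.open(io.BytesIO(base64.b64decode(b64))))
assert (back == arr).all()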
+    def sync_to_data(self):
+        if self.sel_dirty:
+            self.write_selection_to_buffer()
+            self.canvas[2].clear()
+            self.draw_buffer()
+        if self.buffer_dirty:
+            self.buffer2data()
+
+    def sync_to_buffer(self):
+        if self.sel_dirty:
+            self.canvas[2].clear()
+            self.write_selection_to_buffer()
+            self.draw_buffer()
+
+    def resize(self, width, height, scale=None, **kwargs):
+        self.display_width = width
+        self.display_height = height
+        for canvas in self.canvas:
+            prepare_canvas(width=width, height=height, canvas=canvas.canvas)
+        setup_overlay(width, height)
+        if scale is None:
+            scale = 1
+        self.update_scale(scale)
+
+    def save(self):
+        self.sync_to_data()
+        state = {}
+        state["width"] = self.display_width
+        state["height"] = self.display_height
+        state["selection_width"] = self.selection_size_w
+        state["selection_height"] = self.selection_size_h
+        state["view_pos"] = self.view_pos[:]
+        state["cursor"] = self.cursor[:]
+        state["scale"] = self.scale
+        keys = list(self.data.keys())
+        data = {}
+        for key in keys:
+            if self.data[key].sum() > 0:
+                data[f"{key[0]},{key[1]}"] = self.numpy_to_base64(self.data[key])
+        state["data"] = data
+        return json.dumps(state)
+
+    def load(self, state_json):
+        self.reset()
+        state = json.loads(state_json)
+        self.display_width = state["width"]
+        self.display_height = state["height"]
+        self.selection_size_w = state["selection_width"]
+        self.selection_size_h = state["selection_height"]
+        self.view_pos = state["view_pos"][:]
+        self.cursor = state["cursor"][:]
+        self.scale = state["scale"]
+        self.resize(state["width"], state["height"], scale=state["scale"])
+        for k, v in state["data"].items():
+            key = tuple(map(int, k.split(",")))
+            self.data[key] = self.base64_to_numpy(v)
+        self.data2buffer()
+        self.display()
+
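`save` serializes the whole canvas as JSON: scalar view state plus a `data` map from "xi,yi" patch keys to base64 PNGs, and `load` inverts it. A sketch of the resulting document shape (all field values hypothetical):

state = {
    "width": 1024, "height": 600,            # display size
    "selection_width": 256, "selection_height": 256,
    "view_pos": [7680, 7892],                # view's top-left corner
    "cursor": [384, 172],                    # selection box, view coordinates
    "scale": 1.0,
    "data": {
        "15,15": "<base64 PNG of patch (15, 15)>",
        "16,15": "<base64 PNG of patch (16, 15)>",
    },
}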
+    def display(self):
+        self.clear_background()
+        self.draw_buffer()
+        self.draw_selection_box()
+
+    def reset(self):
+        self.data.clear()
+        self.buffer *= 0
+        self.buffer_dirty = False
+        self.buffer_updated = False
+        self.sel_buffer *= 0
+        self.sel_dirty = False
+        self.view_pos = [0, 0]
+        self.clear_background()
+        for i in range(1, len(self.canvas) - 1):
+            self.canvas[i].clear()
+
+    def export(self):
+        self.sync_to_data()
+        xmin, xmax, ymin, ymax = 0, 0, 0, 0
+        if len(self.data.keys()) == 0:
+            return np.zeros(
+                (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
+            )
+        for xi, yi in self.data.keys():
+            buf = self.data[(xi, yi)]
+            if buf.sum() > 0:
+                xmin = min(xi, xmin)
+                xmax = max(xi, xmax)
+                ymin = min(yi, ymin)
+                ymax = max(yi, ymax)
+        yn = ymax - ymin + 1
+        xn = xmax - xmin + 1
+        image = np.zeros(
+            (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8
+        )
+        for xi, yi in self.data.keys():
+            buf = self.data[(xi, yi)]
+            if buf.sum() > 0:
+                y0 = (yi - ymin) * self.patch_size
+                x0 = (xi - xmin) * self.patch_size
+                image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf
+        ylst, xlst = image[:, :, -1].nonzero()
+        if len(ylst) > 0:
+            yt, xt = ylst.min(), xlst.min()
+            yb, xb = ylst.max(), xlst.max()
+            image = image[yt : yb + 1, xt : xb + 1]
+            return image
+        else:
+            return np.zeros(
+                (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
+            )
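After stitching the non-empty patches into one image, `export` crops to the bounding box of the alpha channel so the download carries no empty margin. The cropping idiom in isolation:

import numpy as np

image = np.zeros((64, 64, 4), dtype=np.uint8)
image[10:20, 30:50] = 255                 # some painted content

ys, xs = image[:, :, -1].nonzero()        # rows/cols with non-zero alpha
if len(ys) > 0:
    cropped = image[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    # cropped.shape == (10, 20, 4)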
index.html
CHANGED
@@ -1,214 +1,411 @@
- <html>
- <head>
- <title>Stablediffusion Infinity</title>
- <meta charset="utf-8">
- <link rel="icon" type="image/x-icon" href="./favicon.png">
- [... remaining 209 removed lines of the old index.html truncated; the full replacement file follows ...]
+<html>
+<head>
+<title>Stablediffusion Infinity</title>
+<meta charset="utf-8">
+<link rel="icon" type="image/x-icon" href="./favicon.png">
+
+<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/css/w2ui.min.css">
+<script type="text/javascript" src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/js/w2ui.min.js"></script>
+<link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
+<script src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/js/fabric.min.js"></script>
+<script defer src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@master/js/toolbar.js"></script>
+
+<link rel="stylesheet" href="https://pyscript.net/alpha/pyscript.css" />
+<script defer src="https://pyscript.net/alpha/pyscript.js"></script>
+
+<style>
+#container {
+    position: relative;
+    margin: auto;
+    display: block;
+}
+#container > canvas {
+    position: absolute;
+    top: 0;
+    left: 0;
+}
+.control {
+    display: none;
+}
+</style>
+
+</head>
+<body>
+<div>
+    <button type="button" class="control" id="export">Export</button>
+    <button type="button" class="control" id="outpaint">Outpaint</button>
+    <button type="button" class="control" id="undo">Undo</button>
+    <button type="button" class="control" id="commit">Commit</button>
+    <button type="button" class="control" id="transfer">Transfer</button>
+    <button type="button" class="control" id="upload">Upload</button>
+    <button type="button" class="control" id="draw">Draw</button>
+    <input type="text" id="mode" value="selection" class="control">
+    <input type="text" id="setup" value="0" class="control">
+    <input type="text" id="upload_content" value="0" class="control">
+    <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
+    <fieldset class="control">
+        <div>
+            <input type="radio" id="mode0" name="mode" value="0" checked>
+            <label for="mode0">SelBox</label>
+        </div>
+        <div>
+            <input type="radio" id="mode1" name="mode" value="1">
+            <label for="mode1">Image</label>
+        </div>
+        <div>
+            <input type="radio" id="mode2" name="mode" value="2">
+            <label for="mode2">Brush</label>
+        </div>
+    </fieldset>
+</div>
+<div id="outer_container">
+    <div id="container">
+        <canvas id="canvas0"></canvas>
+        <canvas id="canvas1"></canvas>
+        <canvas id="canvas2"></canvas>
+        <canvas id="canvas3"></canvas>
+        <canvas id="canvas4"></canvas>
+        <div id="overlay_container" style="pointer-events: none">
+            <canvas id="overlay_canvas" width="1" height="1"></canvas>
+        </div>
+    </div>
+    <input type="file" name="file" id="upload_file" accept="image/*" hidden>
+    <input type="file" name="state" id="upload_state" accept=".sdinf" hidden>
+    <div style="position: relative;">
+        <div id="toolbar"></div>
+    </div>
+</div>
+<py-env>
+    - numpy
+    - Pillow
+    - paths:
+        - ./canvas.py
+</py-env>
+
+<py-script>
+from pyodide import to_js, create_proxy
+from PIL import Image
+import io
+import time
+import base64
+import numpy as np
+from js import (
+    console,
+    document,
+    parent,
+    devicePixelRatio,
+    ImageData,
+    Uint8ClampedArray,
+    CanvasRenderingContext2D as Context2d,
+    requestAnimationFrame,
+    window,
+    encodeURIComponent,
+    w2ui,
+    update_eraser,
+    update_scale,
+    adjust_selection,
+    update_count,
+    enable_result_lst,
+    setup_shortcut,
+)
+
+
+from canvas import InfCanvas
+
+
+base_lst = [None]
+async def draw_canvas() -> None:
+    width = 1024
+    height = 600
+    canvas = InfCanvas(1024, 600)
+    update_eraser(canvas.eraser_size, min(canvas.selection_size_h, canvas.selection_size_w))
+    document.querySelector("#container").style.height = f"{height}px"
+    document.querySelector("#container").style.width = f"{width}px"
+    canvas.setup_mouse()
+    canvas.clear_background()
+    canvas.draw_buffer()
+    canvas.draw_selection_box()
+    base_lst[0] = canvas
+
+async def draw_canvas_func(event):
+    try:
+        app = parent.document.querySelector("gradio-app")
+        if app.shadowRoot:
+            app = app.shadowRoot
+        width = app.querySelector("#canvas_width input").value
+        height = app.querySelector("#canvas_height input").value
+        selection_size = app.querySelector("#selection_size input").value
+    except Exception:
+        width = 1024
+        height = 768
+        selection_size = 384
+    document.querySelector("#container").style.width = f"{width}px"
+    document.querySelector("#container").style.height = f"{height}px"
+    canvas = InfCanvas(int(width), int(height), selection_size=int(selection_size))
+    canvas.setup_mouse()
+    canvas.clear_background()
+    canvas.draw_buffer()
+    canvas.draw_selection_box()
+    base_lst[0] = canvas
+
+async def export_func(event):
+    base = base_lst[0]
+    arr = base.export()
+    base.draw_buffer()
+    base.canvas[2].clear()
+    base64_str = base.numpy_to_base64(arr)
+    time_str = time.strftime("%Y%m%d_%H%M%S")
+    link = document.createElement("a")
+    if len(event.data) > 2 and event.data[2]:
+        filename = event.data[2]
+    else:
+        filename = f"outpaint_{time_str}"
+    # link.download = f"sdinf_state_{time_str}.json"
+    link.download = f"{filename}.png"
+    # link.download = f"outpaint_{time_str}.png"
+    link.href = "data:image/png;base64," + base64_str
+    link.click()
+    console.log(f"Canvas saved to {filename}.png")
+
+img_candidate_lst = [None, 0]
+
+async def outpaint_func(event):
+    base = base_lst[0]
+    if len(event.data) == 2:
+        app = parent.document.querySelector("gradio-app")
+        if app.shadowRoot:
+            app = app.shadowRoot
+        base64_str_raw = app.querySelector("#output textarea").value
+        base64_str_lst = base64_str_raw.split(",")
+        img_candidate_lst[0] = base64_str_lst
+        img_candidate_lst[1] = 0
+    elif event.data[2] == "next":
+        img_candidate_lst[1] += 1
+    elif event.data[2] == "prev":
+        img_candidate_lst[1] -= 1
+    enable_result_lst()
+    if img_candidate_lst[0] is None:
+        return
+    lst = img_candidate_lst[0]
+    idx = img_candidate_lst[1]
+    update_count(idx % len(lst) + 1, len(lst))
+    arr = base.base64_to_numpy(lst[idx % len(lst)])
+    base.fill_selection(arr)
+    base.draw_selection_box()
+
+async def undo_func(event):
+    base = base_lst[0]
+    img_candidate_lst[0] = None
+    if base.sel_dirty:
+        base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8)
+        base.sel_dirty = False
+        base.canvas[2].clear()
+
+async def commit_func(event):
+    base = base_lst[0]
+    img_candidate_lst[0] = None
+    if base.sel_dirty:
+        base.write_selection_to_buffer()
+        base.draw_buffer()
+        base.canvas[2].clear()
+
+async def transfer_func(event):
+    base = base_lst[0]
+    base.read_selection_from_buffer()
+    sel_buffer = base.sel_buffer
+    sel_buffer_str = base.numpy_to_base64(sel_buffer)
+    app = parent.document.querySelector("gradio-app")
+    if app.shadowRoot:
+        app = app.shadowRoot
+    app.querySelector("#input textarea").value = sel_buffer_str
+    app.querySelector("#proceed").click()
+
+async def upload_func(event):
+    base = base_lst[0]
+    # base64_str = event.data[1]
+    base64_str = document.querySelector("#upload_content").value
+    base64_str = base64_str.split(",")[-1]
+    # base64_str = parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value
+    arr = base.base64_to_numpy(base64_str)
+    h, w, c = base.buffer.shape
+    base.sync_to_buffer()
+    base.buffer_dirty = True
+    mask = arr[:, :, 3:4].repeat(4, axis=2)
+    base.buffer[mask > 0] = 0
+    # in case of a size mismatch
+    base.buffer[0:h, 0:w, :] += arr
+    # base.buffer[yo:yo+h, xo:xo+w, 0:3] = arr[:, :, 0:3]
+    # base.buffer[yo:yo+h, xo:xo+w, -1] = arr[:, :, -1]
+    base.draw_buffer()
+
+async def setup_shortcut_func(event):
+    setup_shortcut(event.data[1])
+
+
+document.querySelector("#export").addEventListener("click", create_proxy(export_func))
+document.querySelector("#undo").addEventListener("click", create_proxy(undo_func))
+document.querySelector("#commit").addEventListener("click", create_proxy(commit_func))
+document.querySelector("#outpaint").addEventListener("click", create_proxy(outpaint_func))
+document.querySelector("#upload").addEventListener("click", create_proxy(upload_func))
+
+document.querySelector("#transfer").addEventListener("click", create_proxy(transfer_func))
+document.querySelector("#draw").addEventListener("click", create_proxy(draw_canvas_func))
+
+async def setup_func():
+    document.querySelector("#setup").value = "1"
+
+async def reset_func(event):
+    base = base_lst[0]
+    base.reset()
+
+async def load_func(event):
+    base = base_lst[0]
+    base.load(event.data[1])
+
+async def save_func(event):
+    base = base_lst[0]
+    json_str = base.save()
+    time_str = time.strftime("%Y%m%d_%H%M%S")
+    link = document.createElement("a")
+    if len(event.data) > 2 and event.data[2]:
+        filename = str(event.data[2]).strip()
+    else:
+        filename = f"outpaint_{time_str}"
+    # link.download = f"sdinf_state_{time_str}.json"
+    link.download = f"{filename}.sdinf"
+    link.href = "data:text/json;charset=utf-8," + encodeURIComponent(json_str)
+    link.click()
+
+async def prev_result_func(event):
+    base = base_lst[0]
+    base.reset()
+
+async def next_result_func(event):
+    base = base_lst[0]
+    base.reset()
+
+async def zoom_in_func(event):
+    base = base_lst[0]
+    scale = base.scale
+    if scale >= 0.2:
+        scale -= 0.1
+    if len(event.data) > 2:
+        base.update_scale(scale, int(event.data[2]), int(event.data[3]))
+    else:
+        base.update_scale(scale)
+    scale = base.scale
+    update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
+
+async def zoom_out_func(event):
+    base = base_lst[0]
+    scale = base.scale
+    if scale < 10:
+        scale += 0.1
+    console.log(len(event.data))
+    if len(event.data) > 2:
+        base.update_scale(scale, int(event.data[2]), int(event.data[3]))
+    else:
+        base.update_scale(scale)
+    scale = base.scale
+    update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
+
+async def sync_func(event):
+    base = base_lst[0]
+    base.sync_to_buffer()
+    base.canvas[2].clear()
+
+async def eraser_size_func(event):
+    base = base_lst[0]
+    eraser_size = min(int(event.data[1]), min(base.selection_size_h, base.selection_size_w))
+    eraser_size = max(8, eraser_size)
+    base.eraser_size = eraser_size
+
+async def resize_selection_func(event):
+    base = base_lst[0]
+    cursor = base.cursor
+    if len(event.data) > 3:
+        console.log(event.data)
+        base.cursor[0] = int(event.data[1])
+        base.cursor[1] = int(event.data[2])
+        base.selection_size_w = int(event.data[3]) // 8 * 8
+        base.selection_size_h = int(event.data[4]) // 8 * 8
+        base.refine_selection()
+        base.draw_selection_box()
+    elif len(event.data) > 2:
+        base.draw_selection_box()
+    else:
+        base.canvas[-1].clear()
+    adjust_selection(cursor[0], cursor[1], base.selection_size_w, base.selection_size_h)
+
+async def eraser_func(event):
+    base = base_lst[0]
+    if event.data[1] != "eraser":
+        base.canvas[-2].clear()
+    else:
+        x, y = base.mouse_pos
+        base.draw_eraser(x, y)
+
+async def resize_func(event):
+    base = base_lst[0]
+    width = int(event.data[1])
+    height = int(event.data[2])
+    if width >= 256 and height >= 256:
+        if max(base.selection_size_h, base.selection_size_w) > min(width, height):
+            base.selection_size_h = 256
+            base.selection_size_w = 256
+        base.resize(width, height)
+
+async def message_func(event):
+    if event.data[0] == "click":
+        if event.data[1] == "clear":
+            await reset_func(event)
+        elif event.data[1] == "save":
+            await save_func(event)
+        elif event.data[1] == "export":
+            await export_func(event)
+        elif event.data[1] == "accept":
+            await commit_func(event)
+        elif event.data[1] == "cancel":
+            await undo_func(event)
+        elif event.data[1] == "zoom_in":
+            await zoom_in_func(event)
+        elif event.data[1] == "zoom_out":
+            await zoom_out_func(event)
+    elif event.data[0] == "sync":
+        await sync_func(event)
+    elif event.data[0] == "load":
+        await load_func(event)
+    elif event.data[0] == "upload":
+        await upload_func(event)
+    elif event.data[0] == "outpaint":
+        await outpaint_func(event)
+    elif event.data[0] == "mode":
+        if event.data[1] != "selection":
+            await sync_func(event)
+        await eraser_func(event)
+        document.querySelector("#mode").value = event.data[1]
+    elif event.data[0] == "transfer":
+        await transfer_func(event)
+    elif event.data[0] == "setup":
+        await draw_canvas_func(event)
+    elif event.data[0] == "eraser_size":
+        await eraser_size_func(event)
+    elif event.data[0] == "resize_selection":
+        await resize_selection_func(event)
+    elif event.data[0] == "shortcut":
+        await setup_shortcut_func(event)
+    elif event.data[0] == "resize":
+        await resize_func(event)
+
+window.addEventListener("message", create_proxy(message_func))
+
+import asyncio
+
+_ = await asyncio.gather(
+    setup_func()
+)
+</py-script>
+
+</body>
+</html>
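The toolbar and the PyScript side talk over `window.postMessage` with a plain list: element 0 is the command, the rest are arguments, which `message_func` above dispatches on. A plain-Python sketch of the same routing pattern (handlers stubbed; message shapes taken from the calls above):

import asyncio

async def zoom_in(data): print("zoom in at", data[2:])
async def resize(data): print("resize to", data[1], "x", data[2])

ROUTES = {("click", "zoom_in"): zoom_in, ("resize", None): resize}

async def dispatch(data):
    # "click" messages key on the sub-command; others key on the command alone.
    handler = ROUTES.get((data[0], data[1] if data[0] == "click" else None))
    if handler:
        await handler(data)

asyncio.run(dispatch(["click", "zoom_in", 120, 80]))
asyncio.run(dispatch(["resize", 1280, 720]))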
perlin2d.py
CHANGED
@@ -1,45 +1,45 @@
(The old and new contents of perlin2d.py are identical; the unchanged file is shown once below.)

import numpy as np

##########
# https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
def perlin(x, y, seed=0):
    # permutation table
    np.random.seed(seed)
    p = np.arange(256, dtype=int)
    np.random.shuffle(p)
    p = np.stack([p, p]).flatten()
    # coordinates of the top-left
    xi, yi = x.astype(int), y.astype(int)
    # internal coordinates
    xf, yf = x - xi, y - yi
    # fade factors
    u, v = fade(xf), fade(yf)
    # noise components
    n00 = gradient(p[p[xi] + yi], xf, yf)
    n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
    n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
    n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
    # combine noises
    x1 = lerp(n00, n10, u)
    x2 = lerp(n01, n11, u)  # FIX1: I was using n10 instead of n01
    return lerp(x1, x2, v)  # FIX2: I also had to reverse x1 and x2 here


def lerp(a, b, x):
    "linear interpolation"
    return a + x * (b - a)


def fade(t):
    "6t^5 - 15t^4 + 10t^3"
    return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3


def gradient(h, x, y):
    "grad converts h to the right gradient vector and return the dot product with (x,y)"
    vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
    g = vectors[h % 4]
    return g[:, :, 0] * x + g[:, :, 1] * y


##########
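`perlin` is vectorized over coordinate grids, so the usual way to drive it (as in the linked Stack Overflow answer) is with `np.meshgrid`; non-integer coordinates are needed, since exact lattice points all evaluate to 0. A usage sketch:

import numpy as np
from perlin2d import perlin

lin = np.linspace(0, 5, 100, endpoint=False)   # 5 noise cells, 20 samples each
x, y = np.meshgrid(lin, lin)
noise = perlin(x, y, seed=2)                   # shape (100, 100), roughly in [-1, 1]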
postprocess.py
ADDED
@@ -0,0 +1,249 @@
+"""
+https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
+MIT License
+
+Copyright (c) 2022 Jiayi Weng
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import time
+import argparse
+import os
+import fpie
+from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
+from fpie.io import read_images, write_image
+from process import BaseProcessor, EquProcessor, GridProcessor
+
+from PIL import Image
+import numpy as np
+import skimage
+import skimage.measure
+import scipy
+import scipy.signal
+
+
+class PhotometricCorrection:
+    def __init__(self, quite=False):
+        self.get_parser("cli")
+        args = self.parser.parse_args(["--method", "grid", "-g", "src", "-s", "a", "-t", "a", "-o", "a"])
+        args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
+        self.backend = args.backend
+        self.args = args
+        self.quite = quite
+        proc: BaseProcessor
+        proc = GridProcessor(
+            args.gradient,
+            args.backend,
+            args.cpu,
+            args.mpi_sync_interval,
+            args.block_size,
+            args.grid_x,
+            args.grid_y,
+        )
+        print(
+            f"[PIE] Successfully initialized PIE {args.method} solver "
+            f"with {args.backend} backend"
+        )
+        self.proc = proc
+
+    def run(self, original_image, inpainted_image, mode="mask_mode"):
+        print("[PIE] start")
+        if mode == "disabled":
+            return inpainted_image
+        input_arr = np.array(original_image)
+        if input_arr[:, :, -1].sum() < 1:
+            return inpainted_image
+        output_arr = np.array(inpainted_image)
+        mask = input_arr[:, :, -1]
+        mask = 255 - mask
+        if mask.sum() < 1 and mode == "mask_mode":
+            mode = ""
+        if mode == "mask_mode":
+            mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
+            mask = mask.repeat(8, axis=0).repeat(8, axis=1)
+        else:
+            mask[8:-9, 8:-9] = 255
+        mask = mask[:, :, np.newaxis].repeat(3, axis=2)
+        nmask = mask.copy()
+        output_arr2 = output_arr[:, :, 0:3].copy()
+        input_arr2 = input_arr[:, :, 0:3].copy()
+        output_arr2[nmask < 128] = 0
+        input_arr2[nmask >= 128] = 0
+        output_arr2 += input_arr2
+        src = output_arr2[:, :, 0:3]
+        tgt = src.copy()
+        proc = self.proc
+        args = self.args
+        if proc.root:
+            n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
+        proc.sync()
+        if proc.root:
+            result = tgt
+            t = time.time()
+        if args.p == 0:
+            args.p = args.n
+
+        for i in range(0, args.n, args.p):
+            if proc.root:
+                result, err = proc.step(args.p)  # type: ignore
+                print(f"[PIE] Iter {i + args.p}, abs_err {err}")
+            else:
+                proc.step(args.p)
+
+        if proc.root:
+            dt = time.time() - t
+            print(f"[PIE] Time elapsed: {dt:.4f}s")
+        # make sure consistent with dummy process
+        return Image.fromarray(result)
+
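In `run` above, mask_mode snaps the outpainting mask to the 8-pixel grid: `block_reduce` with `np.max` dilates any partially-masked 8x8 block to fully masked, and `repeat` scales it back up. A standalone sketch of that round trip (sizes assumed to be multiples of 8):

import numpy as np
import skimage.measure

mask = np.zeros((64, 64), dtype=np.uint8)
mask[30:34, 30:34] = 255                 # a small masked region

coarse = skimage.measure.block_reduce(mask, (8, 8), np.max)   # (8, 8) grid
snapped = coarse.repeat(8, axis=0).repeat(8, axis=1)          # back to (64, 64)
# The 4x4 region now covers every 8x8 block it touched:
assert snapped.sum() >= mask.sum()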
+    def get_parser(self, gen_type: str) -> argparse.Namespace:
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "-v", "--version", action="store_true", help="show the version and exit"
+        )
+        parser.add_argument(
+            "--check-backend", action="store_true", help="print all available backends"
+        )
+        if gen_type == "gui" and "mpi" in ALL_BACKEND:
+            # gui doesn't support MPI backend
+            ALL_BACKEND.remove("mpi")
+        parser.add_argument(
+            "-b",
+            "--backend",
+            type=str,
+            choices=ALL_BACKEND,
+            default=DEFAULT_BACKEND,
+            help="backend choice",
+        )
+        parser.add_argument(
+            "-c",
+            "--cpu",
+            type=int,
+            default=CPU_COUNT,
+            help="number of CPUs used",
+        )
+        parser.add_argument(
+            "-z",
+            "--block-size",
+            type=int,
+            default=1024,
+            help="cuda block size (only for equ solver)",
+        )
+        parser.add_argument(
+            "--method",
+            type=str,
+            choices=["equ", "grid"],
+            default="equ",
+            help="how to parallelize computation",
+        )
+        parser.add_argument("-s", "--source", type=str, help="source image filename")
+        if gen_type == "cli":
+            parser.add_argument(
+                "-m",
+                "--mask",
+                type=str,
+                help="mask image filename (default is to use the whole source image)",
+                default="",
+            )
+        parser.add_argument("-t", "--target", type=str, help="target image filename")
+        parser.add_argument("-o", "--output", type=str, help="output image filename")
+        if gen_type == "cli":
+            parser.add_argument(
+                "-h0", type=int, help="mask position (height) on source image", default=0
+            )
+            parser.add_argument(
+                "-w0", type=int, help="mask position (width) on source image", default=0
+            )
+            parser.add_argument(
+                "-h1", type=int, help="mask position (height) on target image", default=0
+            )
+            parser.add_argument(
+                "-w1", type=int, help="mask position (width) on target image", default=0
+            )
+        parser.add_argument(
+            "-g",
+            "--gradient",
+            type=str,
+            choices=["max", "src", "avg"],
+            default="max",
+            help="how to calculate gradient for PIE",
+        )
+        parser.add_argument(
+            "-n",
+            type=int,
+            help="number of iterations to run; the more, the better",
+            default=5000,
+        )
+        if gen_type == "cli":
+            parser.add_argument(
+                "-p", type=int, help="output result every P iterations", default=0
+            )
+        if "mpi" in ALL_BACKEND:
+            parser.add_argument(
+                "--mpi-sync-interval",
+                type=int,
+                help="MPI sync iteration interval",
+                default=100,
+            )
+        parser.add_argument(
+            "--grid-x", type=int, help="x axis stride for grid solver", default=8
+        )
+        parser.add_argument(
+            "--grid-y", type=int, help="y axis stride for grid solver", default=8
+        )
+        self.parser = parser
+
+if __name__ == "__main__":
+    import sys
+    import io
+    import base64
+    from PIL import Image
+
+    def base64_to_pil(base64_str):
+        data = base64.b64decode(str(base64_str))
+        pil = Image.open(io.BytesIO(data))
+        return pil
+
+    def pil_to_base64(out_pil):
+        out_buffer = io.BytesIO()
+        out_pil.save(out_buffer, format="PNG")
+        out_buffer.seek(0)
+        base64_bytes = base64.b64encode(out_buffer.read())
+        base64_str = base64_bytes.decode("ascii")
+        return base64_str
+
+    correction_func = PhotometricCorrection(quite=True)
+    while True:
+        buffer = sys.stdin.readline()
+        print(f"[PIE] subprocess {len(buffer)} {type(buffer)} ")
+        if len(buffer) == 0:
+            break
+        if isinstance(buffer, str):
+            lst = buffer.strip().split(",")
+        else:
+            lst = buffer.decode("ascii").strip().split(",")
+        img0 = base64_to_pil(lst[0])
+        img1 = base64_to_pil(lst[1])
+        ret = correction_func.run(img0, img1, mode=lst[2])
+        ret_base64 = pil_to_base64(ret)
+        if isinstance(buffer, str):
+            sys.stdout.write(f"{ret_base64}\n")
+        else:
+            sys.stdout.write(f"{ret_base64}\n".encode())
+        sys.stdout.flush()
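The `__main__` loop above implements a line-oriented protocol: each stdin line is `<base64 original>,<base64 inpainted>,<mode>` and the corrected image comes back on stdout as one base64 line. A sketch of what the parent side could look like (the `[PIE]`-log skipping and the launch command are assumptions, not code from this repo):

import subprocess

proc = subprocess.Popen(
    ["python", "postprocess.py"],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True,
)

def correct(orig_b64, inpainted_b64, mode="mask_mode"):
    proc.stdin.write(f"{orig_b64},{inpainted_b64},{mode}\n")
    proc.stdin.flush()
    # the child also prints "[PIE] ..." progress lines to stdout; skip them
    line = proc.stdout.readline()
    while line.startswith("[PIE]"):
        line = proc.stdout.readline()
    return line.strip()  # base64 PNG of the corrected image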
process.py
ADDED
@@ -0,0 +1,395 @@
+"""
+https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
+MIT License
+
+Copyright (c) 2022 Jiayi Weng
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Tuple
+
+import numpy as np
+
+from fpie import np_solver
+
+import scipy
+import scipy.signal
+
+CPU_COUNT = os.cpu_count() or 1
+DEFAULT_BACKEND = "numpy"
+ALL_BACKEND = ["numpy"]
+
+try:
+    from fpie import numba_solver
+    ALL_BACKEND += ["numba"]
+    DEFAULT_BACKEND = "numba"
+except ImportError:
+    numba_solver = None  # type: ignore
+
+try:
+    from fpie import taichi_solver
+    ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
+    DEFAULT_BACKEND = "taichi-cpu"
+except ImportError:
+    taichi_solver = None  # type: ignore
+
+# try:
+#     from fpie import core_gcc  # type: ignore
+#     DEFAULT_BACKEND = "gcc"
+#     ALL_BACKEND.append("gcc")
+# except ImportError:
+#     core_gcc = None
+
+# try:
+#     from fpie import core_openmp  # type: ignore
+#     DEFAULT_BACKEND = "openmp"
+#     ALL_BACKEND.append("openmp")
+# except ImportError:
+#     core_openmp = None
+
+# try:
+#     from mpi4py import MPI
+#
+#     from fpie import core_mpi  # type: ignore
+#     ALL_BACKEND.append("mpi")
+# except ImportError:
+#     MPI = None  # type: ignore
+#     core_mpi = None
+
+try:
+    from fpie import core_cuda  # type: ignore
+    DEFAULT_BACKEND = "cuda"
+    ALL_BACKEND.append("cuda")
+except ImportError:
+    core_cuda = None
+
+
+class BaseProcessor(ABC):
+    """API definition for processor class."""
+
+    def __init__(
+        self, gradient: str, rank: int, backend: str, core: Optional[Any]
+    ):
+        if core is None:
+            error_msg = {
+                "numpy":
+                    "Please run `pip install numpy`.",
+                "numba":
+                    "Please run `pip install numba`.",
+                "gcc":
+                    "Please install cmake and gcc in your operating system.",
+                "openmp":
+                    "Please make sure your gcc is compatible with `-fopenmp` option.",
+                "mpi":
+                    "Please install MPI and run `pip install mpi4py`.",
+                "cuda":
+                    "Please make sure nvcc and cuda-related libraries are available.",
+                "taichi":
+                    "Please run `pip install taichi`.",
+            }
+            print(error_msg[backend.split("-")[0]])
+
+            raise AssertionError(f"Invalid backend {backend}.")
+
+        self.gradient = gradient
+        self.rank = rank
+        self.backend = backend
+        self.core = core
+        self.root = rank == 0
+
+    def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
+        if self.gradient == "src":
+            return a
+        if self.gradient == "avg":
+            return (a + b) / 2
+        # mix gradient, see Equ. 12 in PIE paper
+        mask = np.abs(a) < np.abs(b)
+        a[mask] = b[mask]
+        return a
+
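The default `max` gradient in `mixgrad` keeps, per element, whichever of the source and target gradients has the larger magnitude (Equ. 12 of the Poisson image editing paper). A worked example:

import numpy as np

a = np.array([ 3.0, -1.0, 0.5])   # source gradient
b = np.array([-4.0,  0.5, 0.2])   # target gradient

mixed = a.copy()
m = np.abs(mixed) < np.abs(b)
mixed[m] = b[m]
# mixed -> [-4., -1., 0.5]: the larger-magnitude entry wins elementwise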
@abstractmethod
|
129 |
+
def reset(
|
130 |
+
self,
|
131 |
+
src: np.ndarray,
|
132 |
+
mask: np.ndarray,
|
133 |
+
tgt: np.ndarray,
|
134 |
+
mask_on_src: Tuple[int, int],
|
135 |
+
mask_on_tgt: Tuple[int, int],
|
136 |
+
) -> int:
|
137 |
+
pass
|
138 |
+
|
139 |
+
def sync(self) -> None:
|
140 |
+
self.core.sync()
|
141 |
+
|
142 |
+
@abstractmethod
|
143 |
+
def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
144 |
+
pass
|
145 |
+
|
146 |
+
|
147 |
+
class EquProcessor(BaseProcessor):
    """PIE Jacobi equation processor."""

    def __init__(
        self,
        gradient: str = "max",
        backend: str = DEFAULT_BACKEND,
        n_cpu: int = CPU_COUNT,
        min_interval: int = 100,
        block_size: int = 1024,
    ):
        core: Optional[Any] = None
        rank = 0

        if backend == "numpy":
            core = np_solver.EquSolver()
        elif backend == "numba" and numba_solver is not None:
            core = numba_solver.EquSolver()
        elif backend == "gcc":
            core = core_gcc.EquSolver()
        elif backend == "openmp" and core_openmp is not None:
            core = core_openmp.EquSolver(n_cpu)
        elif backend == "mpi" and core_mpi is not None:
            core = core_mpi.EquSolver(min_interval)
            rank = MPI.COMM_WORLD.Get_rank()
        elif backend == "cuda" and core_cuda is not None:
            core = core_cuda.EquSolver(block_size)
        elif backend.startswith("taichi") and taichi_solver is not None:
            core = taichi_solver.EquSolver(backend, n_cpu, block_size)

        super().__init__(gradient, rank, backend, core)

    def mask2index(
        self, mask: np.ndarray
    ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
        x, y = np.nonzero(mask)
        max_id = x.shape[0] + 1
        index = np.zeros((max_id, 3))
        ids = self.core.partition(mask)
        ids[mask == 0] = 0  # reserve id=0 for constant
        index = ids[x, y].argsort()
        return ids, max_id, x[index], y[index]

    def reset(
        self,
        src: np.ndarray,
        mask: np.ndarray,
        tgt: np.ndarray,
        mask_on_src: Tuple[int, int],
        mask_on_tgt: Tuple[int, int],
    ) -> int:
        assert self.root
        # check validity
        # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
        # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
        # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
        # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
        # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]

        if len(mask.shape) == 3:
            mask = mask.mean(-1)
        mask = (mask >= 128).astype(np.int32)

        # zero-out edge
        mask[0] = 0
        mask[-1] = 0
        mask[:, 0] = 0
        mask[:, -1] = 0

        x, y = np.nonzero(mask)
        x0, x1 = x.min() - 1, x.max() + 2
        y0, y1 = y.min() - 1, y.max() + 2
        mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
        mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
        mask = mask[x0:x1, y0:y1]
        ids, max_id, index_x, index_y = self.mask2index(mask)

        src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
        tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]

        src_C = src[src_x, src_y].astype(np.float32)
        src_U = src[src_x - 1, src_y].astype(np.float32)
        src_D = src[src_x + 1, src_y].astype(np.float32)
        src_L = src[src_x, src_y - 1].astype(np.float32)
        src_R = src[src_x, src_y + 1].astype(np.float32)
        tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
        tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
        tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
        tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
        tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)

        grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
            + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
            + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
            + self.mixgrad(src_C - src_D, tgt_C - tgt_D)

        A = np.zeros((max_id, 4), np.int32)
        X = np.zeros((max_id, 3), np.float32)
        B = np.zeros((max_id, 3), np.float32)

        X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
        # four-way
        A[1:, 0] = ids[index_x - 1, index_y]
        A[1:, 1] = ids[index_x + 1, index_y]
        A[1:, 2] = ids[index_x, index_y - 1]
        A[1:, 3] = ids[index_x, index_y + 1]
        B[1:] = grad
        m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
        B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
        m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
        B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
        m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
        B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
        m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
        B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]

        self.tgt = tgt.copy()
        self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
        self.core.reset(max_id, A, X, B)
        return max_id

    def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        result = self.core.step(iteration)
        if self.root:
            x, err = result
            self.tgt[self.tgt_index] = x[1:]
            return self.tgt, err
        return None

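In both concrete processors, reset() assembles the discrete Poisson system for the masked region and step() advances the chosen solver by a fixed number of iterations, returning the blended target and a residual on the root rank. A minimal driver, as a hedged sketch (shapes, offsets, and iteration counts are illustrative; assumes the numpy backend and that this file is importable as process.py):

import numpy as np
# from process import EquProcessor  # assuming this file is saved as process.py

proc = EquProcessor(gradient="max", backend="numpy")
src = np.random.randint(0, 256, (64, 64, 3), np.uint8)    # source image
tgt = np.random.randint(0, 256, (128, 128, 3), np.uint8)  # target image
mask = np.zeros((64, 64), np.uint8)
mask[8:56, 8:56] = 255  # non-zero pixels get blended

n_unknown = proc.reset(src, mask, tgt, (0, 0), (32, 32))
for _ in range(10):
    out, err = proc.step(100)  # 100 Jacobi iterations per call
# `out` is the full target image with the blended region written back
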
class GridProcessor(BaseProcessor):
    """PIE grid processor."""

    def __init__(
        self,
        gradient: str = "max",
        backend: str = DEFAULT_BACKEND,
        n_cpu: int = CPU_COUNT,
        min_interval: int = 100,
        block_size: int = 1024,
        grid_x: int = 8,
        grid_y: int = 8,
    ):
        core: Optional[Any] = None
        rank = 0

        if backend == "numpy":
            core = np_solver.GridSolver()
        elif backend == "numba" and numba_solver is not None:
            core = numba_solver.GridSolver()
        elif backend == "gcc":
            core = core_gcc.GridSolver(grid_x, grid_y)
        elif backend == "openmp" and core_openmp is not None:
            core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
        elif backend == "mpi" and core_mpi is not None:
            core = core_mpi.GridSolver(min_interval)
            rank = MPI.COMM_WORLD.Get_rank()
        elif backend == "cuda" and core_cuda is not None:
            core = core_cuda.GridSolver(grid_x, grid_y)
        elif backend.startswith("taichi") and taichi_solver is not None:
            core = taichi_solver.GridSolver(
                grid_x, grid_y, backend, n_cpu, block_size
            )

        super().__init__(gradient, rank, backend, core)

    def reset(
        self,
        src: np.ndarray,
        mask: np.ndarray,
        tgt: np.ndarray,
        mask_on_src: Tuple[int, int],
        mask_on_tgt: Tuple[int, int],
    ) -> int:
        assert self.root
        # check validity
        # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
        # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
        # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
        # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
        # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]

        if len(mask.shape) == 3:
            mask = mask.mean(-1)
        mask = (mask >= 128).astype(np.int32)

        # zero-out edge
        mask[0] = 0
        mask[-1] = 0
        mask[:, 0] = 0
        mask[:, -1] = 0

        x, y = np.nonzero(mask)
        x0, x1 = x.min() - 1, x.max() + 2
        y0, y1 = y.min() - 1, y.max() + 2
        mask = mask[x0:x1, y0:y1]
        max_id = np.prod(mask.shape)

        src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
                       mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
        tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
                       mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)
        grad = np.zeros([*mask.shape, 3], np.float32)
        grad[1:] += self.mixgrad(
            src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
        )
        grad[:-1] += self.mixgrad(
            src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
        )
        grad[:, 1:] += self.mixgrad(
            src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
        )
        grad[:, :-1] += self.mixgrad(
            src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
        )

        grad[mask == 0] = 0
        if True:
            kernel = [[1] * 3 for _ in range(3)]
            nmask = mask.copy()
            nmask[nmask > 0] = 1
            res = scipy.signal.convolve2d(
                nmask, kernel, mode="same", boundary="fill", fillvalue=1
            )
            res[nmask < 1] = 0
            res[res == 9] = 0
            res[res > 0] = 1
            grad[res > 0] = 0
            # ylst, xlst = res.nonzero()
            # for y, x in zip(ylst, xlst):
            #     grad[y, x] = 0
            #     for yi in range(-1, 2):
            #         for xi in range(-1, 2):
            #             grad[y + yi, x + xi] = 0
        self.x0 = mask_on_tgt[0] + x0
        self.x1 = mask_on_tgt[0] + x1
        self.y0 = mask_on_tgt[1] + y0
        self.y1 = mask_on_tgt[1] + y1
        self.tgt = tgt.copy()
        self.core.reset(max_id, mask, tgt_crop, grad)
        return max_id

    def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        result = self.core.step(iteration)
        if self.root:
            tgt, err = result
            self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
            return self.tgt, err
        return None
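The convolve2d block in GridProcessor.reset (the same trick reappears in edge_pad in utils.py below) extracts the one-pixel ring just inside the mask: summing a 3x3 all-ones kernel over the 0/1 mask gives 9 for fully interior pixels, which are discarded, so only border pixels survive; their gradients are then zeroed, pinning the solve to the target's boundary colors. A standalone illustration (toy mask, values illustrative only):

import numpy as np
import scipy.signal

nmask = np.zeros((6, 6), np.int32)
nmask[1:5, 1:5] = 1  # 0/1 mask
res = scipy.signal.convolve2d(
    nmask, [[1] * 3 for _ in range(3)], mode="same", boundary="fill", fillvalue=1
)
res[nmask < 1] = 0   # keep pixels inside the mask only
res[res == 9] = 0    # drop fully interior pixels (all 9 neighbors set)
res[res > 0] = 1     # what remains is the border ring
print(res)
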
utils.py
CHANGED
@@ -1,151 +1,263 @@
from PIL import Image
from PIL import ImageFilter
import cv2
import numpy as np
import scipy
import scipy.signal
from scipy.spatial import cKDTree

import os
from perlin2d import *

patch_match_compiled = True

try:
    from PyPatchMatch import patch_match
except Exception as e:
    try:
        import patch_match
    except Exception as e:
        patch_match_compiled = False

try:
    patch_match
except NameError:
    print("patch_match compiling failed, will fall back to edge_pad")
    patch_match_compiled = False


def edge_pad(img, mask, mode=1):
    if mode == 0:
        nmask = mask.copy()
        nmask[nmask > 0] = 1
        res0 = 1 - nmask
        res1 = nmask
        p0 = np.stack(res0.nonzero(), axis=0).transpose()
        p1 = np.stack(res1.nonzero(), axis=0).transpose()
        min_dists, min_dist_idx = cKDTree(p1).query(p0, 1)
        loc = p1[min_dist_idx]
        for (a, b), (c, d) in zip(p0, loc):
            img[a, b] = img[c, d]
    elif mode == 1:
        record = {}
        kernel = [[1] * 3 for _ in range(3)]
        nmask = mask.copy()
        nmask[nmask > 0] = 1
        res = scipy.signal.convolve2d(
            nmask, kernel, mode="same", boundary="fill", fillvalue=1
        )
        res[nmask < 1] = 0
        res[res == 9] = 0
        res[res > 0] = 1
        ylst, xlst = res.nonzero()
        queue = [(y, x) for y, x in zip(ylst, xlst)]
        # bfs here
        cnt = res.astype(np.float32)
        acc = img.astype(np.float32)
        step = 1
        h = acc.shape[0]
        w = acc.shape[1]
        offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        while queue:
            target = []
            for y, x in queue:
                val = acc[y][x]
                for yo, xo in offset:
                    yn = y + yo
                    xn = x + xo
                    if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
                        if record.get((yn, xn), step) == step:
                            acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
                            cnt[yn][xn] += 1
                            acc[yn][xn] /= cnt[yn][xn]
                            if (yn, xn) not in record:
                                record[(yn, xn)] = step
                                target.append((yn, xn))
            step += 1
            queue = target
        img = acc.astype(np.uint8)
    else:
        nmask = mask.copy()
        ylst, xlst = nmask.nonzero()
        yt, xt = ylst.min(), xlst.min()
        yb, xb = ylst.max(), xlst.max()
        content = img[yt : yb + 1, xt : xb + 1]
        img = np.pad(
            content,
            ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
            mode="edge",
        )
    return img, mask

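edge_pad's default mode 1 is a breadth-first flood fill: it seeds the queue with the ring of known pixels on the mask border, then expands outward one ring per step, setting each unknown pixel to the running average of its already-filled neighbors. A hedged usage sketch (toy arrays; shapes and values are illustrative):

import numpy as np
# from utils import edge_pad  # assuming this module is utils.py

img = np.zeros((32, 32, 3), np.uint8)
mask = np.zeros((32, 32), np.uint8)
mask[8:24, 8:24] = 255  # non-zero marks pixels that already have content
img[8:24, 8:24] = 200
filled, mask_out = edge_pad(img, mask, mode=1)
# pixels outside the mask are now filled by colors diffused from the edge
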
def perlin_noise(img, mask):
    lin = np.linspace(0, 5, mask.shape[0], endpoint=False)
    x, y = np.meshgrid(lin, lin)
    avg = img.mean(axis=0).mean(axis=0)
    # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
    noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)]
    noise = np.stack(noise, axis=-1)
    # mask=skimage.measure.block_reduce(mask,(8,8),np.min)
    # mask=mask.repeat(8, axis=0).repeat(8, axis=1)
    # mask_image=Image.fromarray(mask)
    # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4))
    # mask=np.array(mask_image)
    nmask = mask.copy()
    # nmask=nmask/255.0
    nmask[mask > 0] = 1
    img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
    # img=img.astype(np.uint8)
    return img, mask


def gaussian_noise(img, mask):
    noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
    noise = (noise + 1) / 2 * 255
    noise = noise.astype(np.uint8)
    nmask = mask.copy()
    nmask[mask > 0] = 1
    img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
    return img, mask


def cv2_telea(img, mask):
    ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
    return ret, mask


def cv2_ns(img, mask):
    ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS)
    return ret, mask


def patch_match_func(img, mask):
    ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
    return ret, mask


def mean_fill(img, mask):
    avg = img.mean(axis=0).mean(axis=0)
    img[mask < 1] = avg
    return img, mask

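Note the mask inversion in the two OpenCV wrappers: cv2.inpaint expects non-zero mask pixels to mark the holes to be filled, while this codebase uses non-zero to mark known content, hence the 255 - mask at each call site. A short sketch of the convention (toy data, values illustrative):

import cv2
import numpy as np

img = np.full((16, 16, 3), 127, np.uint8)
mask = np.zeros((16, 16), np.uint8)
mask[4:12, 4:12] = 255  # known pixels
# invert so that cv2.inpaint fills everything outside the known region
out = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
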
"""
|
146 |
+
Apache-2.0 license
|
147 |
+
https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py
|
148 |
+
https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2
|
149 |
+
_handleImageAdjustment
|
150 |
+
"""
|
151 |
+
if True:
|
152 |
+
from sd_grpcserver.sdgrpcserver import images
|
153 |
+
import torch
|
154 |
+
from math import sqrt
|
155 |
+
def handleImageAdjustment(array, adjustments):
|
156 |
+
tensor = images.fromPIL(Image.fromarray(array))
|
157 |
+
for adjustment in adjustments:
|
158 |
+
which = adjustment[0]
|
159 |
+
|
160 |
+
if which == "blur":
|
161 |
+
sigma = adjustment[1]
|
162 |
+
direction = adjustment[2]
|
163 |
+
|
164 |
+
if direction == "DOWN" or direction == "UP":
|
165 |
+
orig = tensor
|
166 |
+
repeatCount=256
|
167 |
+
sigma /= sqrt(repeatCount)
|
168 |
+
|
169 |
+
for _ in range(repeatCount):
|
170 |
+
tensor = images.gaussianblur(tensor, sigma)
|
171 |
+
if direction == "DOWN":
|
172 |
+
tensor = torch.minimum(tensor, orig)
|
173 |
+
else:
|
174 |
+
tensor = torch.maximum(tensor, orig)
|
175 |
+
else:
|
176 |
+
tensor = images.gaussianblur(tensor, adjustment.blur.sigma)
|
177 |
+
elif which == "invert":
|
178 |
+
tensor = images.invert(tensor)
|
179 |
+
elif which == "levels":
|
180 |
+
tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
|
181 |
+
elif which == "channels":
|
182 |
+
tensor = images.channelmap(tensor, [adjustment.channels.r, adjustment.channels.g, adjustment.channels.b, adjustment.channels.a])
|
183 |
+
elif which == "rescale":
|
184 |
+
self.unimp("Rescale")
|
185 |
+
elif which == "crop":
|
186 |
+
tensor = images.crop(tensor, adjustment.crop.top, adjustment.crop.left, adjustment.crop.height, adjustment.crop.width)
|
187 |
+
return np.array(images.toPIL(tensor)[0])
|
188 |
+
|
189 |
+
def g_diffuser(img,mask):
|
190 |
+
adjustments=[["blur",32,"UP"],["level",0,0.05,0,1]]
|
191 |
+
mask=handleImageAdjustment(mask,adjustments)
|
192 |
+
out_mask=handleImageAdjustment(mask,adjustments)
|
193 |
+
return img, mask, out_mask
|
194 |
+
def dummy_fill(img,mask):
|
195 |
+
return img,mask
|
196 |
+
functbl = {
|
197 |
+
"gaussian": gaussian_noise,
|
198 |
+
"perlin": perlin_noise,
|
199 |
+
"edge_pad": edge_pad,
|
200 |
+
"patchmatch": patch_match_func if patch_match_compiled else edge_pad,
|
201 |
+
"cv2_ns": cv2_ns,
|
202 |
+
"cv2_telea": cv2_telea,
|
203 |
+
"g_diffuser": g_diffuser,
|
204 |
+
"g_diffuser_lib": dummy_fill,
|
205 |
+
}
|
206 |
+
|
207 |
+
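functbl maps each outpainting-initialization mode name to a callable with a common (img, mask) -> (img, mask) contract (g_diffuser additionally returns an output mask), so the app can dispatch on the UI selection. A hedged dispatch sketch (mode name and array shapes are illustrative):

import numpy as np
# from utils import functbl  # assuming this module is utils.py

img = np.zeros((64, 64, 3), np.uint8)
mask = np.zeros((64, 64), np.uint8)
mask[16:48, 16:48] = 255  # known content

fill_fn = functbl["cv2_telea"]  # or "perlin", "edge_pad", "patchmatch", ...
img_init, mask_out = fill_fn(img, mask)
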
try:
    from postprocess import PhotometricCorrection
    correction_func = PhotometricCorrection()
except Exception as e:
    print(e, "so PhotometricCorrection is disabled")

    class DummyCorrection:
        def __init__(self):
            self.backend = ""

        def run(self, a, b, **kwargs):
            return b

    correction_func = DummyCorrection()

if "taichi" in correction_func.backend:
|
221 |
+
import sys
|
222 |
+
import io
|
223 |
+
import base64
|
224 |
+
from PIL import Image
|
225 |
+
def base64_to_pil(base64_str):
|
226 |
+
data = base64.b64decode(str(base64_str))
|
227 |
+
pil = Image.open(io.BytesIO(data))
|
228 |
+
return pil
|
229 |
+
|
230 |
+
def pil_to_base64(out_pil):
|
231 |
+
out_buffer = io.BytesIO()
|
232 |
+
out_pil.save(out_buffer, format="PNG")
|
233 |
+
out_buffer.seek(0)
|
234 |
+
base64_bytes = base64.b64encode(out_buffer.read())
|
235 |
+
base64_str = base64_bytes.decode("ascii")
|
236 |
+
return base64_str
|
237 |
+
from subprocess import Popen, PIPE, STDOUT
|
238 |
+
class SubprocessCorrection:
|
239 |
+
def __init__(self):
|
240 |
+
self.backend=correction_func.backend
|
241 |
+
self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
|
242 |
+
def run(self,img_input,img_inpainted,mode):
|
243 |
+
if mode=="disabled":
|
244 |
+
return img_inpainted
|
245 |
+
base64_str_input = pil_to_base64(img_input)
|
246 |
+
base64_str_inpainted = pil_to_base64(img_inpainted)
|
247 |
+
try:
|
248 |
+
if self.child.poll():
|
249 |
+
self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
|
250 |
+
self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
|
251 |
+
self.child.stdin.flush()
|
252 |
+
out = self.child.stdout.readline()
|
253 |
+
base64_str=out.decode().strip()
|
254 |
+
while base64_str and base64_str[0]=="[":
|
255 |
+
print(base64_str)
|
256 |
+
out = self.child.stdout.readline()
|
257 |
+
base64_str=out.decode().strip()
|
258 |
+
ret=base64_to_pil(base64_str)
|
259 |
+
except:
|
260 |
+
print("[PIE] not working, photometric correction is disabled")
|
261 |
+
ret=img_inpainted
|
262 |
+
return ret
|
263 |
+
correction_func = SubprocessCorrection()
|
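When the correction backend is Taichi-based, the work is delegated to a postprocess.py child process, presumably to keep the Taichi runtime out of the main app process: images cross the pipe as base64-encoded PNGs, one request per line, and bracketed "[...]" lines coming back are treated as logs and skipped. A hedged usage sketch (images are illustrative; valid mode strings other than "disabled" are defined in postprocess.py, which is not shown here):

from PIL import Image
# from utils import correction_func  # assuming this module is utils.py

img_input = Image.new("RGB", (64, 64), "white")     # original canvas content
img_inpainted = Image.new("RGB", (64, 64), "gray")  # diffusion output
mode = "disabled"  # "disabled" bypasses correction entirely
result = correction_func.run(img_input, img_inpainted, mode=mode)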