ironjr committed on
Commit
6fedad3
1 Parent(s): 03a14d9

added first version

.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ __pycache__/
2
+ .*.sw*
3
+ .ipynb_checkpoints/
README.md CHANGED
@@ -1,13 +1,16 @@
1
  ---
2
  title: SemanticPalette
3
- emoji: 🐢
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.21.0
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: SemanticPalette
3
+ emoji: 🧠🎨
4
+ colorFrom: red
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.21.0
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
+ suggested_hardware: t4-small
12
+ suggested_storage: small
13
+ models: ironjr/BlazingDriveV11m
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,819 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import argparse
22
+ import random
23
+ import time
24
+ import json
25
+ import os
26
+ import glob
27
+ import pathlib
28
+ from functools import partial
29
+ from pprint import pprint
30
+
31
+ import numpy as np
32
+ from PIL import Image
33
+ import torch
34
+
35
+ import spaces
36
+ import gradio as gr
37
+ from huggingface_hub import snapshot_download
38
+
39
+ from model import StableMultiDiffusionPipeline
40
+ from util import seed_everything
41
+
42
+
43
+ ### Utils
44
+
45
+
46
+
47
+
48
+ def log_state(state):
49
+ pprint(vars(opt))
50
+ if isinstance(state, gr.State):
51
+ state = state.value
52
+ pprint(vars(state))
53
+
54
+
55
+ def is_empty_image(im: Image.Image) -> bool:
56
+ if im is None:
57
+ return True
58
+ im = np.array(im)
59
+ has_alpha = (im.shape[2] == 4)
60
+ if not has_alpha:
61
+ return False
62
+ elif im.sum() == 0:
63
+ return True
64
+ else:
65
+ return False
66
+
67
+
68
+ ### Argument passing
69
+
70
+ parser = argparse.ArgumentParser(description='Semantic drawing demo powered by StreamMultiDiffusion.')
71
+ parser.add_argument('-H', '--height', type=int, default=768)
72
+ parser.add_argument('-W', '--width', type=int, default=1920)
73
+ parser.add_argument('--model', type=str, default=None)
74
+ parser.add_argument('--bootstrap_steps', type=int, default=1)
75
+ parser.add_argument('--seed', type=int, default=-1)
76
+ parser.add_argument('--device', type=int, default=0)
77
+ parser.add_argument('--port', type=int, default=8000)
78
+ opt = parser.parse_args()
79
+
80
+
81
+ ### Global variables and data structures
82
+
83
+ device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
84
+
85
+
86
+ model_dict = {
87
+ 'Blazing Drive V11m': 'ironjr/BlazingDriveV11m',
88
+ 'Real Cartoon Pixar V5': 'ironjr/RealCartoon-PixarV5',
89
+ 'Kohaku V2.1': 'KBlueLeaf/kohaku-v2.1',
90
+ 'Realistic Vision V5.1': 'ironjr/RealisticVisionV5-1',
91
+ 'Stable Diffusion V1.5': 'runwayml/stable-diffusion-v1-5',
92
+ }
93
+
94
+ models = {
95
+ k: StableMultiDiffusionPipeline(device, sd_version='1.5', hf_key=v)
96
+ for k, v in model_dict.items()
97
+ }
98
+
99
+
100
+ prompt_suggestions = [
101
+ '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
102
+ '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
103
+ '1girl, arima kana, oshi no ko, solo, upper body, from behind',
104
+ ]
105
+
106
+ opt.max_palettes = 5
107
+ opt.default_prompt_strength = 1.0
108
+ opt.default_mask_strength = 1.0
109
+ opt.default_mask_std = 0.0
110
+ opt.default_negative_prompt = (
111
+ 'nsfw, worst quality, bad quality, normal quality, cropped, framed'
112
+ )
113
+ opt.verbose = True
114
+ opt.colors = [
115
+ '#000000',
116
+ '#2692F3',
117
+ '#F89E12',
118
+ '#16C232',
119
+ '#F92F6C',
120
+ '#AC6AEB',
121
+ # '#92C62C',
122
+ # '#92C6EC',
123
+ # '#FECAC0',
124
+ ]
125
+
126
+
127
+ ### Event handlers
128
+
129
+ def add_palette(state):
130
+ old_actives = state.active_palettes
131
+ state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)
132
+
133
+ if opt.verbose:
134
+ log_state(state)
135
+
136
+ if state.active_palettes != old_actives:
137
+ return [state] + [
138
+ gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
139
+ ] + [
140
+ gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
141
+ for i in range(opt.max_palettes)
142
+ ]
143
+ else:
144
+ return [state] + [gr.update() for i in range(opt.max_palettes + 1)]
145
+
146
+
147
+ def select_palette(state, button, idx):
148
+ if idx < 0 or idx > opt.max_palettes:
149
+ idx = 0
150
+ old_idx = state.current_palette
151
+ if old_idx == idx:
152
+ return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]
153
+
154
+ state.current_palette = idx
155
+
156
+ if opt.verbose:
157
+ log_state(state)
158
+
159
+ updates = [state] + [
160
+ gr.update() if i not in (idx, old_idx) else
161
+ gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
162
+ for i in range(opt.max_palettes + 1)
163
+ ]
164
+ label = 'Background' if idx == 0 else f'Palette {idx}'
165
+ updates.extend([
166
+ gr.update(value=button, interactive=(idx > 0)),
167
+ gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
168
+ gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
169
+ (
170
+ gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
171
+ gr.update(value=opt.default_mask_strength, interactive=False)
172
+ ),
173
+ (
174
+ gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
175
+ gr.update(value=opt.default_prompt_strength, interactive=False)
176
+ ),
177
+ (
178
+ gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
179
+ gr.update(value=opt.default_mask_std, interactive=False)
180
+ ),
181
+ ])
182
+ return updates
183
+
184
+
185
+ def change_prompt_strength(state, strength):
186
+ if state.current_palette == 0:
187
+ return state
188
+
189
+ state.prompt_strengths[state.current_palette - 1] = strength
190
+ if opt.verbose:
191
+ log_state(state)
192
+
193
+ return state
194
+
195
+
196
+ def change_std(state, std):
197
+ if state.current_palette == 0:
198
+ return state
199
+
200
+ state.mask_stds[state.current_palette - 1] = std
201
+ if opt.verbose:
202
+ log_state(state)
203
+
204
+ return state
205
+
206
+
207
+ def change_mask_strength(state, strength):
208
+ if state.current_palette == 0:
209
+ return state
210
+
211
+ state.mask_strengths[state.current_palette - 1] = strength
212
+ if opt.verbose:
213
+ log_state(state)
214
+
215
+ return state
216
+
217
+
218
+ def reset_seed(state, seed):
219
+ state.seed = seed
220
+ if opt.verbose:
221
+ log_state(state)
222
+
223
+ return state
224
+
225
+ def rename_prompt(state, name):
226
+ state.prompt_names[state.current_palette] = name
227
+ if opt.verbose:
228
+ log_state(state)
229
+
230
+ return [state] + [
231
+ gr.update() if i != state.current_palette else gr.update(value=name)
232
+ for i in range(opt.max_palettes + 1)
233
+ ]
234
+
235
+
236
+ def change_prompt(state, prompt):
237
+ state.prompts[state.current_palette] = prompt
238
+ if opt.verbose:
239
+ log_state(state)
240
+
241
+ return state
242
+
243
+
244
+ def change_neg_prompt(state, neg_prompt):
245
+ state.neg_prompts[state.current_palette] = neg_prompt
246
+ if opt.verbose:
247
+ log_state(state)
248
+
249
+ return state
250
+
251
+
252
+ def select_model(state, model_id):
253
+ state.model_id = model_id
254
+ if opt.verbose:
255
+ log_state(state)
256
+
257
+ return state
258
+
259
+
260
+ def import_state(state, json_text):
261
+ current_palette = state.current_palette
262
+ # active_palettes = state.active_palettes
263
+ state = argparse.Namespace(**json.loads(json_text))
264
+ state.active_palettes = opt.max_palettes
265
+ return [state] + [
266
+ gr.update(value=v, visible=True) for v in state.prompt_names
267
+ ] + [
268
+ state.model_id,
269
+ state.prompts[current_palette],
270
+ state.prompt_names[current_palette],
271
+ state.neg_prompts[current_palette],
272
+ state.prompt_strengths[current_palette - 1],
273
+ state.mask_strengths[current_palette - 1],
274
+ state.mask_stds[current_palette - 1],
275
+ state.seed,
276
+ ]
277
+
278
+
279
+ ### Main worker
280
+
281
+ @spaces.GPU
282
+ def generate(state, *args, **kwargs):
283
+ return models[state.model_id](*args, **kwargs)
284
+
285
+
286
+
287
+ def run(state, drawpad):
288
+ seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
289
+ print('Generate!')
290
+
291
+ background = drawpad['background'].convert('RGBA')
292
+ inpainting_mode = np.asarray(background).sum() != 0
293
+ print('Inpainting mode: ', inpainting_mode)
294
+
295
+ user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
296
+ foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
297
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
298
+
299
+ palette = torch.tensor([
300
+ tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
301
+ for s in opt.colors[1:]
302
+ ]) # (N, 3)
303
+ masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
304
+ has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
305
+ print('Has mask: ', has_masks)
306
+ masks = masks * foreground_mask
307
+ masks = masks[has_masks]
308
+
309
+ # if inpainting_mode:
310
+ # prompts = state.prompts[1:len(masks)+1]
311
+ # negative_prompts = state.neg_prompts[1:len(masks)+1]
312
+ # mask_strengths = state.mask_strengths[:len(masks)]
313
+ # mask_stds = state.mask_stds[:len(masks)]
314
+ # prompt_strengths = state.prompt_strengths[:len(masks)]
315
+ # else:
316
+ # masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
317
+ # prompts = state.prompts[:len(masks)+1]
318
+ # negative_prompts = state.neg_prompts[:len(masks)+1]
319
+ # mask_strengths = [1] + state.mask_strengths[:len(masks)]
320
+ # mask_stds = [0] + [state.mask_stds[:len(masks)]
321
+ # prompt_strengths = [1] + state.prompt_strengths[:len(masks)]
322
+
323
+ if inpainting_mode:
324
+ prompts = [state.prompts[v + 1] for v in has_masks]
325
+ negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
326
+ mask_strengths = [state.mask_strengths[v] for v in has_masks]
327
+ mask_stds = [state.mask_stds[v] for v in has_masks]
328
+ prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
329
+ else:
330
+ masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
331
+ prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
332
+ negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
333
+ mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
334
+ mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
335
+ prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]
336
+
337
+ return generate(
338
+ state,
339
+ prompts,
340
+ negative_prompts,
341
+ masks=masks,
342
+ mask_strengths=mask_strengths,
343
+ mask_stds=mask_stds,
344
+ prompt_strengths=prompt_strengths,
345
+ background=background.convert('RGB'),
346
+ background_prompt=state.prompts[0],
347
+ background_negative_prompt=state.neg_prompts[0],
348
+ height=opt.height,
349
+ width=opt.width,
350
+ bootstrap_steps=2,
351
+ )
352
+
353
+
354
+
355
+ ### Load examples
356
+
357
+
358
+ root = pathlib.Path(__file__).parent
359
+ example_root = os.path.join(root, 'examples')
360
+ example_images = glob.glob(os.path.join(example_root, '*.png'))
361
+ example_images = [Image.open(i) for i in example_images]
362
+
363
+ with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
364
+ prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']
365
+
366
+ with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
367
+ prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']
368
+
369
+ with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
370
+ prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']
371
+
372
+ with open(os.path.join(example_root, 'prompt_props.txt')) as f:
373
+ prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
374
+ prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}
375
+
376
+ prompt_background = lambda: random.choice(prompts_background)
377
+ prompt_girl = lambda: random.choice(prompts_girl)
378
+ prompt_boy = lambda: random.choice(prompts_boy)
379
+ prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()
380
+
381
+
382
+ ### Main application
383
+
384
+ css = f"""
385
+ #run-button {{
386
+ font-size: 30pt;
387
+ background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
388
+ margin: 0;
389
+ padding: 15px 45px;
390
+ text-align: center;
391
+ text-transform: uppercase;
392
+ transition: 0.5s;
393
+ background-size: 200% auto;
394
+ color: white;
395
+ box-shadow: 0 0 20px #eee;
396
+ border-radius: 10px;
397
+ display: block;
398
+ background-position: right center;
399
+ }}
400
+
401
+ #run-button:hover {{
402
+ background-position: left center;
403
+ color: #fff;
404
+ text-decoration: none;
405
+ }}
406
+
407
+ #semantic-palette {{
408
+ border-style: solid;
409
+ border-width: 0.2em;
410
+ border-color: #eee;
411
+ }}
412
+
413
+ #semantic-palette:hover {{
414
+ box-shadow: 0 0 20px #eee;
415
+ }}
416
+
417
+ #output-screen {{
418
+ width: 100%;
419
+ aspect-ratio: {opt.width} / {opt.height};
420
+ }}
421
+
422
+ .layer-wrap {{
423
+ display: none;
424
+ }}
425
+ """
426
+
427
+ for i in range(opt.max_palettes + 1):
428
+ css = css + f"""
429
+ .secondary#semantic-palette-{i} {{
430
+ background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
431
+ }}
432
+
433
+ .primary#semantic-palette-{i} {{
434
+ background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
435
+ }}
436
+ """
437
+
438
+
439
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
440
+
441
+ iface = argparse.Namespace()
442
+
443
+ def _define_state():
444
+ state = argparse.Namespace()
445
+
446
+ # Cursor.
447
+ state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
448
+ state.model_id = list(model_dict.keys())[0]
449
+
450
+ # State variables (one-hot).
451
+ state.active_palettes = 1
452
+
453
+ # Front-end initialized to the default values.
454
+ prompt_props_ = prompt_props()
455
+ # state.prompt_names = [
456
+ # '🌄 Background',
457
+ # '👧 Girl',
458
+ # '👦 Boy',
459
+ # '🐶 Dog',
460
+ # '🚗 Car',
461
+ # '💐 Garden',
462
+ # ] + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
463
+ # state.prompts = [
464
+ # 'Maximalism, best quality, high quality, city lights, times square',
465
+ # '1girl, looking at viewer, pink hair, leather jacket',
466
+ # '1boy, looking at viewer, brown hair, casual shirt',
467
+ # 'Doggy body part',
468
+ # 'Car',
469
+ # 'Flower garden',
470
+ # ] + ['' for _ in range(opt.max_palettes - 5)]
471
+ state.prompt_names = [
472
+ '🌄 Background',
473
+ '👧 Girl',
474
+ '👦 Boy',
475
+ ] + prompt_props_ + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
476
+ state.prompts = [
477
+ prompt_background(),
478
+ prompt_girl(),
479
+ prompt_boy(),
480
+ ] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
481
+ state.neg_prompts = [
482
+ opt.default_negative_prompt
483
+ + (', humans, humans, humans' if i == 0 else '')
484
+ for i in range(opt.max_palettes + 1)
485
+ ]
486
+ state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
487
+ state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
488
+ state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
489
+ state.seed = opt.seed
490
+ return state
491
+
492
+ state = gr.State(value=_define_state)
493
+
494
+
495
+ ### Demo user interface
496
+
497
+ gr.HTML(
498
+ """
499
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
500
+ <div>
501
+ <h1>🧠 Semantic Paint 🎨</h1>
502
+ <h5 style="margin: 0;">powered by</h5>
503
+ <h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
504
+ <h5 style="margin: 0;">If you ❤️ our project, please visit our Github and give us a 🌟!</h5>
505
+ </br>
506
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
507
+ <a href='https://arxiv.org/abs/2403.09055'>
508
+ <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
509
+ </a>
510
+ &nbsp;
511
+ <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
512
+ <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
513
+ </a>
514
+ &nbsp;
515
+ <a href='https://github.com/ironjr/StreamMultiDiffusion'>
516
+ <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
517
+ </a>
518
+ &nbsp;
519
+ <a href='https://twitter.com/_ironjr_'>
520
+ <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
521
+ </a>
522
+ &nbsp;
523
+ <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
524
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
525
+ </a>
526
+ &nbsp;
527
+ <a href='https://huggingface.co/papers/2403.09055'>
528
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Paper-yellow'>
529
+ </a>
530
+ </div>
531
+ </div>
532
+ </div>
533
+ <div>
534
+ </br>
535
+ </div>
536
+ """
537
+ )
538
+
539
+ with gr.Row():
540
+
541
+ iface.image_slot = gr.Image(
542
+ interactive=False,
543
+ show_label=False,
544
+ show_download_button=True,
545
+ type='pil',
546
+ label='Generated Result',
547
+ elem_id='output-screen',
548
+ show_share_button=True,
549
+ value=lambda: random.choice(example_images),
550
+ )
551
+
552
+ with gr.Row():
553
+
554
+ with gr.Column(scale=1):
555
+
556
+ with gr.Group(elem_id='semantic-palette'):
557
+
558
+ gr.HTML(
559
+ """
560
+ <div style="justify-content: center; align-items: center;">
561
+ <br/>
562
+ <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
563
+ <br/>
564
+ </div>
565
+ """
566
+ )
567
+
568
+ iface.btn_semantics = [gr.Button(
569
+ value=state.value.prompt_names[0],
570
+ variant='primary',
571
+ elem_id='semantic-palette-0',
572
+ )]
573
+ for i in range(opt.max_palettes):
574
+ iface.btn_semantics.append(gr.Button(
575
+ value=state.value.prompt_names[i + 1],
576
+ variant='secondary',
577
+ visible=(i < state.value.active_palettes),
578
+ elem_id=f'semantic-palette-{i + 1}'
579
+ ))
580
+
581
+ iface.btn_add_palette = gr.Button(
582
+ value='Create New Semantic Brush',
583
+ variant='primary',
584
+ )
585
+
586
+ with gr.Accordion(label='Import/Export Semantic Palette', open=False):
587
+ iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
588
+ iface.json_state_export = gr.JSON(label='Exported Palette')
589
+ iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
590
+ iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')
591
+
592
+ gr.HTML(
593
+ """
594
+ <div>
595
+ </br>
596
+ </div>
597
+ <div style="justify-content: center; align-items: center;">
598
+ <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
599
+ </br>
600
+ <div style="justify-content: center; align-items: left; text-align: left;">
601
+ <p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
602
+ <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
603
+ <p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
604
+ <p>2-1. If you want to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
605
+ <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
606
+ <p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
607
+ </div>
608
+ </div>
609
+ """
610
+ )
611
+
612
+ gr.HTML(
613
+ """
614
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
615
+ <h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
616
+ </div>
617
+ """
618
+ )
619
+
620
+ gr.DuplicateButton()
621
+
622
+ with gr.Column(scale=4):
623
+
624
+ with gr.Row():
625
+
626
+ with gr.Column(scale=3):
627
+
628
+ iface.ctrl_semantic = gr.ImageEditor(
629
+ image_mode='RGBA',
630
+ sources=['upload', 'clipboard', 'webcam'],
631
+ transforms=['crop'],
632
+ crop_size=(opt.width, opt.height),
633
+ brush=gr.Brush(
634
+ colors=opt.colors[1:],
635
+ color_mode="fixed",
636
+ ),
637
+ type='pil',
638
+ label='Semantic Drawpad',
639
+ elem_id='drawpad',
640
+ show_share_button=True,
641
+ )
642
+
643
+ with gr.Column(scale=1):
644
+
645
+ iface.btn_generate = gr.Button(
646
+ value='Generate!',
647
+ variant='primary',
648
+ # scale=1,
649
+ elem_id='run-button'
650
+ )
651
+
652
+ iface.model_select = gr.Radio(
653
+ list(model_dict.keys()),
654
+ label='Stable Diffusion Checkpoint',
655
+ info='Choose your favorite style.',
656
+ value=state.value.model_id,
657
+ )
658
+
659
+ with gr.Group(elem_id='control-panel'):
660
+
661
+ with gr.Row():
662
+ iface.tbox_prompt = gr.Textbox(
663
+ label='Edit Prompt for Background',
664
+ info='What do you want to draw?',
665
+ value=state.value.prompts[0],
666
+ placeholder=lambda: random.choice(prompt_suggestions),
667
+ scale=2,
668
+ )
669
+
670
+ iface.tbox_name = gr.Textbox(
671
+ label='Edit Brush Name',
672
+ info='Just for your convenience.',
673
+ value=state.value.prompt_names[0],
674
+ placeholder='🌄 Background',
675
+ scale=1,
676
+ )
677
+
678
+ with gr.Row():
679
+ iface.tbox_neg_prompt = gr.Textbox(
680
+ label='Edit Negative Prompt for Background',
681
+ info='Add unwanted objects for this semantic brush.',
682
+ value=opt.default_negative_prompt,
683
+ scale=2,
684
+ )
685
+
686
+ iface.slider_strength = gr.Slider(
687
+ label='Prompt Strength',
688
+ info='Blends fg & bg in the prompt level, >0.8 Preferred.',
689
+ minimum=0.5,
690
+ maximum=1.0,
691
+ value=opt.default_prompt_strength,
692
+ scale=1,
693
+ )
694
+
695
+ with gr.Row():
696
+ iface.slider_alpha = gr.Slider(
697
+ label='Mask Alpha',
698
+ info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
699
+ minimum=0.5,
700
+ maximum=1.0,
701
+ value=opt.default_mask_strength,
702
+ )
703
+
704
+ iface.slider_std = gr.Slider(
705
+ label='Mask Blur STD',
706
+ info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
707
+ minimum=0.0001,
708
+ maximum=100.0,
709
+ value=opt.default_mask_std,
710
+ )
711
+
712
+ iface.slider_seed = gr.Slider(
713
+ label='Seed',
714
+ info='The global seed.',
715
+ minimum=-1,
716
+ maximum=2147483647,
717
+ step=1,
718
+ value=opt.seed,
719
+ )
720
+
721
+ ### Attach event handlers
722
+
723
+ for idx, btn in enumerate(iface.btn_semantics):
724
+ btn.click(
725
+ fn=partial(select_palette, idx=idx),
726
+ inputs=[state, btn],
727
+ outputs=[state] + iface.btn_semantics + [
728
+ iface.tbox_name,
729
+ iface.tbox_prompt,
730
+ iface.tbox_neg_prompt,
731
+ iface.slider_alpha,
732
+ iface.slider_strength,
733
+ iface.slider_std,
734
+ ],
735
+ api_name=f'select_palette_{idx}',
736
+ )
737
+
738
+ iface.btn_add_palette.click(
739
+ fn=add_palette,
740
+ inputs=state,
741
+ outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
742
+ api_name='create_new',
743
+ )
744
+
745
+ iface.btn_generate.click(
746
+ fn=run,
747
+ inputs=[state, iface.ctrl_semantic],
748
+ outputs=iface.image_slot,
749
+ api_name='run',
750
+ )
751
+
752
+ iface.slider_alpha.input(
753
+ fn=change_mask_strength,
754
+ inputs=[state, iface.slider_alpha],
755
+ outputs=state,
756
+ api_name='change_alpha',
757
+ )
758
+ iface.slider_std.input(
759
+ fn=change_std,
760
+ inputs=[state, iface.slider_std],
761
+ outputs=state,
762
+ api_name='change_std',
763
+ )
764
+ iface.slider_strength.input(
765
+ fn=change_prompt_strength,
766
+ inputs=[state, iface.slider_strength],
767
+ outputs=state,
768
+ api_name='change_strength',
769
+ )
770
+ iface.slider_seed.input(
771
+ fn=reset_seed,
772
+ inputs=[state, iface.slider_seed],
773
+ outputs=state,
774
+ api_name='reset_seed',
775
+ )
776
+
777
+ iface.tbox_name.input(
778
+ fn=rename_prompt,
779
+ inputs=[state, iface.tbox_name],
780
+ outputs=[state] + iface.btn_semantics,
781
+ api_name='prompt_rename',
782
+ )
783
+ iface.tbox_prompt.input(
784
+ fn=change_prompt,
785
+ inputs=[state, iface.tbox_prompt],
786
+ outputs=state,
787
+ api_name='prompt_edit',
788
+ )
789
+ iface.tbox_neg_prompt.input(
790
+ fn=change_neg_prompt,
791
+ inputs=[state, iface.tbox_neg_prompt],
792
+ outputs=state,
793
+ api_name='neg_prompt_edit',
794
+ )
795
+
796
+ iface.model_select.change(
797
+ fn=select_model,
798
+ inputs=[state, iface.model_select],
799
+ outputs=state,
800
+ api_name='model_select',
801
+ )
802
+
803
+ iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
804
+ iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
805
+ state,
806
+ *iface.btn_semantics,
807
+ iface.model_select,
808
+ iface.tbox_prompt,
809
+ iface.tbox_name,
810
+ iface.tbox_neg_prompt,
811
+ iface.slider_strength,
812
+ iface.slider_alpha,
813
+ iface.slider_std,
814
+ iface.slider_seed,
815
+ ])
816
+
817
+
818
+ if __name__ == '__main__':
819
+ demo.queue(max_size=20).launch()
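For reference, the palette exported by the "Export Palette ➡️ JSON" button above is simply `vars()` of the state Namespace built in `_define_state`, and the import textbox expects the same shape back. The sketch below is illustrative only: field names mirror `_define_state()`, all values are made-up placeholders, and list lengths follow `opt.max_palettes = 5`.

```python
# Illustrative palette payload for "Import JSON ➡️ Palette" (not part of the app code).
# Field names mirror _define_state(); the values here are placeholders.
import json

palette = {
    'current_palette': 0,
    'model_id': 'Blazing Drive V11m',
    'active_palettes': 3,
    'prompt_names': ['🌄 Background', '👧 Girl', '👦 Boy',
                     '🎨 New Palette', '🎨 New Palette', '🎨 New Palette'],
    'prompts': ['clear blue sky, grasslands', '1girl, looking at viewer, pink hair',
                '1boy, looking at viewer, brown hair', '', '', ''],
    'neg_prompts': ['nsfw, worst quality, bad quality, normal quality, cropped, framed'] * 6,
    'prompt_strengths': [1.0] * 5,
    'mask_strengths': [1.0] * 5,
    'mask_stds': [0.0] * 5,
    'seed': -1,
}

print(json.dumps(palette, ensure_ascii=False))  # paste the printed JSON into the import textbox
```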
examples/prompt_background.txt ADDED
@@ -0,0 +1,8 @@
1
+ Maximalism, best quality, high quality, no humans, background, clear sky, black sky, starry universe, planets
2
+ Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
3
+ Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
4
+ Maximalism, best quality, high quality, no humans, background, galaxy
5
+ Maximalism, best quality, high quality, no humans, background, sky, daylight
6
+ Maximalism, best quality, high quality, no humans, background, skyscrapers, rooftop, city of light, helicopters, bright night, sky
7
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
8
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt ADDED
The diff for this file is too large to render. See raw diff
 
examples/prompt_boy.txt ADDED
@@ -0,0 +1,15 @@
1
+ 1boy, looking at viewer, brown hair, blue shirt
2
+ 1boy, looking at viewer, brown hair, red shirt
3
+ 1boy, looking at viewer, brown hair, purple shirt
4
+ 1boy, looking at viewer, brown hair, orange shirt
5
+ 1boy, looking at viewer, brown hair, yellow shirt
6
+ 1boy, looking at viewer, brown hair, green shirt
7
+ 1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
8
+ 1boy, looking back, short hair, renaissance cloths, noble boy
9
+ 1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
10
+ 1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
11
+ 1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
12
+ 1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
13
+ 1boy, looking at viewer, black haired, old eastern cloth
14
+ 1boy, looking back, messy hair, suit, short beard, noir
15
+ 1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt ADDED
@@ -0,0 +1,16 @@
1
+ 1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
2
+ 1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
3
+ 1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
4
+ 1girl, looking at viewer, fantasy adventurer, backpack
5
+ 1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
6
+ 1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
7
+ 1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
8
+ 1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
9
+ 1girl, looking at viewer, evil smile, very short hair, suit, evil genius
10
+ 1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
11
+ 1girl, looking at viewer, purple hair, happy face, black leather jacket
12
+ 1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
13
+ 1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
14
+ 1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
15
+ 1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
16
+ 1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt ADDED
@@ -0,0 +1,43 @@
1
+ 🏯 Palace, Gyeongbokgung palace
2
+ 🌳 Garden, Chinese garden
3
+ 🏛️ Rome, Ancient city of Rome
4
+ 🧱 Wall, Castle wall
5
+ 🔴 Mars, Martian desert, Red rocky desert
6
+ 🌻 Grassland, Grasslands
7
+ 🏡 Village, A fantasy village
8
+ 🐉 Dragon, a flying chinese dragon
9
+ 🌏 Earth, Earth seen from ISS
10
+ 🚀 Space Station, the international space station
11
+ 🪻 Grassland, Rusty grassland with flowers
12
+ 🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
13
+ 🏙️ City Ruin, city, ruins, ruins, ruins, deserted
14
+ 🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
15
+ 🌷 Flowers, Flower garden
16
+ 🌼 Flowers, Flower garden, spring garden
17
+ 🌹 Flowers, Flowers flowers, flowers
18
+ ⛰️ Dolomites Mountains, Dolomites
19
+ ⛰️ Himalayas Mountains, Himalayas
20
+ ⛰️ Alps Mountains, Alps
21
+ ⛰️ Mountains, Mountains
22
+ ❄️⛰️ Mountains, Winter mountains
23
+ 🌷⛰️ Mountains, Spring mountains
24
+ 🌞⛰️ Mountains, Summer mountains
25
+ 🌵 Desert, A sandy desert, dunes
26
+ 🪨🌵 Desert, A rocky desert
27
+ 💦 Waterfall, A giant waterfall
28
+ 🌊 Ocean, Ocean
29
+ ⛱️ Seashore, Seashore
30
+ 🌅 Sea Horizon, Sea horizon
31
+ 🌊 Lake, Clear blue lake
32
+ 💻 Computer, A giant supercomputer
33
+ 🌳 Tree, A giant tree
34
+ 🌳 Forest, A forest
35
+ 🌳🌳 Forest, A dense forest
36
+ 🌲 Forest, Winter forest
37
+ 🌴 Forest, Summer forest, tropical forest
38
+ 👒 Hat, A hat
39
+ 🐶 Dog, Doggy body parts
40
+ 😻 Cat, A cat
41
+ 🦉 Owl, A small sitting owl
42
+ 🦅 Eagle, A small sitting eagle
43
+ 🚀 Rocket, A flying rocket
model.py ADDED
@@ -0,0 +1,1106 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
22
+ from diffusers import DiffusionPipeline, LCMScheduler, DDIMScheduler, AutoencoderTiny
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import torchvision.transforms as T
28
+ from einops import rearrange
29
+
30
+ from typing import Tuple, List, Literal, Optional, Union
31
+ from tqdm import tqdm
32
+ from PIL import Image
33
+
34
+ from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
35
+
36
+
37
+ class StableMultiDiffusionPipeline(nn.Module):
38
+ def __init__(
39
+ self,
40
+ device: torch.device,
41
+ dtype: torch.dtype = torch.float16,
42
+ sd_version: Literal['1.5', '2.0', '2.1', 'xl'] = '1.5',
43
+ hf_key: Optional[str] = None,
44
+ lora_key: Optional[str] = None,
45
+ load_from_local: bool = False, # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
46
+ default_mask_std: float = 1.0, # 8.0
47
+ default_mask_strength: float = 1.0,
48
+ default_prompt_strength: float = 1.0, # 8.0
49
+ default_bootstrap_steps: int = 1,
50
+ default_boostrap_mix_steps: float = 1.0,
51
+ default_bootstrap_leak_sensitivity: float = 0.2,
52
+ default_preprocess_mask_cover_alpha: float = 0.3,
53
+ t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # [0, 12, 25, 37], # Magic number.
54
+ mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
55
+ ) -> None:
56
+ r"""Stabilized MultiDiffusion for fast sampling.
57
+
58
+ Accelerated region-based text-to-image synthesis with Latent Consistency
59
+ Model while preserving mask fidelity and quality.
60
+
61
+ Args:
62
+ device (torch.device): Specify CUDA device.
63
+ dtype (torch.dtype): Default precision used in the sampling
64
+ process. By default, it is FP16.
65
+ sd_version (Literal['1.5', '2.0', '2.1', 'xl']): StableDiffusion
66
+ version. Currently, only 1.5 is supported.
67
+ hf_key (Optional[str]): Custom StableDiffusion checkpoint for
68
+ stylized generation.
69
+ lora_key (Optional[str]): Custom LCM LoRA for acceleration.
70
+ load_from_local (bool): Turn on if you have already downloaded LoRA
71
+ & Hugging Face hub is down.
72
+ default_mask_std (float): Preprocess mask with Gaussian blur with
73
+ specified standard deviation.
74
+ default_mask_strength (float): Preprocess mask by multiplying it
75
+ globally with the specified variable. Caution: extremely
76
+ sensitive. Recommended range: 0.98-1.
77
+ default_prompt_strength (float): Preprocess foreground prompts
78
+ globally by linearly interpolating its embedding with the
79
+ background prompt embedding with the specified mix ratio. Useful
80
+ control handle for foreground blending. Recommended range:
81
+ 0.5-1.
82
+ default_bootstrap_steps (int): Bootstrapping stage steps to
83
+ encourage region separation. Recommended range: 1-3.
84
+ default_boostrap_mix_steps (float): Bootstrapping background is a
85
+ linear interpolation between background latent and the white
86
+ image latent. This handle controls the mix ratio. Available
87
+ range: 0-(number of bootstrapping inference steps). For
88
+ example, 2.3 means that for the first two steps, white image
89
+ is used as a bootstrapping background and in the third step,
90
+ mixture of white (0.3) and registered background (0.7) is used
91
+ as a bootstrapping background.
92
+ default_bootstrap_leak_sensitivity (float): Postprocessing at each
93
+ inference step by masking away the remaining bootstrap
94
+ backgrounds. Recommended range: 0-1.
95
+ default_preprocess_mask_cover_alpha (float): Optional preprocessing
96
+ where each mask covered by other masks is reduced in its alpha
97
+ value by this specified factor.
98
+ t_index_list (List[int]): The default scheduling for LCM scheduler.
99
+ mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
100
+ defines the mask quantization modes. Details in the codes of
101
+ `self.process_mask`. Basically, this (subtly) controls the
102
+ smoothness of foreground-background blending. More continuous
103
+ means more blending, but smaller generated patch depending on
104
+ the mask standard deviation.
105
+ """
106
+ super().__init__()
107
+
108
+ self.device = device
109
+ self.dtype = dtype
110
+ self.sd_version = sd_version
111
+
112
+ self.default_mask_std = default_mask_std
113
+ self.default_mask_strength = default_mask_strength
114
+ self.default_prompt_strength = default_prompt_strength
115
+ self.default_t_list = t_index_list
116
+ self.default_bootstrap_steps = default_bootstrap_steps
117
+ self.default_boostrap_mix_steps = default_boostrap_mix_steps
118
+ self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
119
+ self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
120
+ self.mask_type = mask_type
121
+
122
+ print(f'[INFO] Loading Stable Diffusion...')
123
+ variant = None
124
+ lora_weight_name = None
125
+ if self.sd_version == '1.5':
126
+ if hf_key is not None:
127
+ print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
128
+ model_key = hf_key
129
+ else:
130
+ model_key = 'runwayml/stable-diffusion-v1-5'
131
+ variant = 'fp16'
132
+ lora_key = 'latent-consistency/lcm-lora-sdv1-5'
133
+ lora_weight_name = 'pytorch_lora_weights.safetensors'
134
+ # elif self.sd_version == 'xl':
135
+ # model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
136
+ # lora_key = 'latent-consistency/lcm-lora-sdxl'
137
+ # variant = 'fp16'
138
+ # lora_weight_name = 'pytorch_lora_weights.safetensors'
139
+ else:
140
+ raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
141
+
142
+ # Create model
143
+ self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
144
+ self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
145
+
146
+ self.pipe = DiffusionPipeline.from_pretrained(model_key, variant=variant, torch_dtype=dtype).to(self.device)
147
+ if lora_key is None:
148
+ print(f'[INFO] LCM LoRA is not available for SD version {sd_version}. Using DDIM Scheduler instead...')
149
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
150
+ self.scheduler = self.pipe.scheduler
151
+ self.default_num_inference_steps = 50
152
+ self.default_guidance_scale = 7.5
153
+ else:
154
+ self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
155
+ self.scheduler = self.pipe.scheduler
156
+ self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
157
+ self.default_num_inference_steps = 4
158
+ self.default_guidance_scale = 1.0
159
+
160
+ self.prepare_lcm_schedule(t_index_list, 50)
161
+
162
+ self.vae = self.pipe.vae
163
+ self.tokenizer = self.pipe.tokenizer
164
+ self.text_encoder = self.pipe.text_encoder
165
+ self.unet = self.pipe.unet
166
+ self.vae_scale_factor = self.pipe.vae_scale_factor
167
+
168
+ # Prepare white background for bootstrapping.
169
+ self.get_white_background(768, 768)
170
+
171
+ print(f'[INFO] Model is loaded!')
172
+
173
+ def prepare_lcm_schedule(
174
+ self,
175
+ t_index_list: Optional[List[int]] = None,
176
+ num_inference_steps: Optional[int] = None,
177
+ ) -> None:
178
+ r"""Set up different inference schedule for the diffusion model.
179
+
180
+ You do not have to run this explicitly if you want to use the default
181
+ setting, but if you want other time schedules, run this function
182
+ between the module initialization and the main call.
183
+
184
+ Note:
185
+ - Recommended t_index_lists for LCMs:
186
+ - [0, 12, 25, 37]: Default schedule for 4 steps. Best for
187
+ panorama. Not recommended if you want to use bootstrapping.
188
+ Because bootstrapping stage affects the initial structuring
189
+ of the generated image & in this four step LCM, this is done
190
+ with only at the first step, the structure may be distorted.
191
+ - [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
192
+ strapping. Default initialization in this implementation.
193
+ - [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
194
+ bootstrapping.
195
+ - Due to the characteristic of SD1.5 LCM LoRA, setting
196
+ `num_inference_steps` larger than 20 may result in overly blurry
197
+ and unrealistic images. Beware!
198
+
199
+ Args:
200
+ t_index_list (Optional[List[int]]): The specified scheduling step
201
+ regarding the maximum timestep as `num_inference_steps`, which
202
+ is by default, 50. That means that
203
+ `t_index_list=[0, 12, 25, 37]` are relative time indices based
204
+ on the full scale of 50. If None, reinitialize the module with
205
+ the default value.
206
+ num_inference_steps (Optional[int]): The maximum timestep of the
207
+ sampler. Defines relative scale of the `t_index_list`. Rarely
208
+ used in practice. If None, reinitialize the module with the
209
+ default value.
210
+ """
211
+ if t_index_list is None:
212
+ t_index_list = self.default_t_list
213
+ if num_inference_steps is None:
214
+ num_inference_steps = self.default_num_inference_steps
215
+
216
+ self.scheduler.set_timesteps(num_inference_steps)
217
+ self.timesteps = torch.as_tensor([
218
+ self.scheduler.timesteps[t] for t in t_index_list
219
+ ], dtype=torch.long)
220
+
221
+ shape = (len(t_index_list), 1, 1, 1)
222
+
223
+ c_skip_list = []
224
+ c_out_list = []
225
+ for timestep in self.timesteps:
226
+ c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
227
+ c_skip_list.append(c_skip)
228
+ c_out_list.append(c_out)
229
+ self.c_skip = torch.stack(c_skip_list).view(*shape).to(dtype=self.dtype, device=self.device)
230
+ self.c_out = torch.stack(c_out_list).view(*shape).to(dtype=self.dtype, device=self.device)
231
+
232
+ alpha_prod_t_sqrt_list = []
233
+ beta_prod_t_sqrt_list = []
234
+ for timestep in self.timesteps:
235
+ alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
236
+ beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
237
+ alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
238
+ beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
239
+ alpha_prod_t_sqrt = (torch.stack(alpha_prod_t_sqrt_list).view(*shape)
240
+ .to(dtype=self.dtype, device=self.device))
241
+ beta_prod_t_sqrt = (torch.stack(beta_prod_t_sqrt_list).view(*shape)
242
+ .to(dtype=self.dtype, device=self.device))
243
+ self.alpha_prod_t_sqrt = alpha_prod_t_sqrt
244
+ self.beta_prod_t_sqrt = beta_prod_t_sqrt
245
+
246
+ noise_lvs = (1 - self.scheduler.alphas_cumprod[self.timesteps].to(self.device)) ** 0.5
247
+ self.noise_lvs = noise_lvs[None, :, None, None, None]
248
+ self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
249
+
250
+ @torch.no_grad()
251
+ def get_text_embeds(self, prompt: str, negative_prompt: str) -> Tuple[torch.Tensor]:
252
+ r"""Text embeddings from string text prompts.
253
+
254
+ Args:
255
+ prompt (str): A text prompt string.
256
+ negative_prompt: An optional negative text prompt string. Good for
257
+ high-quality generation.
258
+
259
+ Returns:
260
+ A tuple of (negative, positive) prompt embeddings of (1, 77, 768).
261
+ """
262
+ kwargs = dict(padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
263
+
264
+ # Tokenize text and get embeddings.
265
+ text_input = self.tokenizer(prompt, truncation=True, **kwargs)
266
+ text_embeds = self.text_encoder(text_input.input_ids.to(self.device))[0]
267
+ uncond_input = self.tokenizer(negative_prompt, **kwargs)
268
+ uncond_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
269
+ return uncond_embeds, text_embeds
270
+
271
+ @torch.no_grad()
272
+ def get_text_prompts(self, image: Image.Image) -> str:
273
+ r"""A convenient method to extract text prompt from an image.
274
+
275
+ This is called if the user does not provide background prompt but only
276
+ the background image. We use BLIP-2 to automatically generate prompts.
277
+
278
+ Args:
279
+ image (Image.Image): A PIL image.
280
+
281
+ Returns:
282
+ A single string of text prompt.
283
+ """
284
+ question = 'Question: What are in the image? Answer:'
285
+ inputs = self.i2t_processor(image, question, return_tensors='pt')
286
+ out = self.i2t_model.generate(**inputs, max_new_tokens=77)
287
+ prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
288
+ return prompt
289
+
290
+ @torch.no_grad()
291
+ def encode_imgs(
292
+ self,
293
+ imgs: torch.Tensor,
294
+ generator: Optional[torch.Generator] = None,
295
+ vae: Optional[nn.Module] = None,
296
+ ) -> torch.Tensor:
297
+ r"""A wrapper function for VAE encoder of the latent diffusion model.
298
+
299
+ Args:
300
+ imgs (torch.Tensor): An image to get StableDiffusion latents.
301
+ Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
302
+ generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
303
+ vae (Optional[nn.Module]): Explicitly specify VAE (used for
304
+ the demo application with TinyVAE).
305
+
306
+ Returns:
307
+ An image latent embedding with 1/8 size (depending on the auto-
308
+ encoder). Shape: (B, 4, H//8, W//8).
309
+ """
310
+ def _retrieve_latents(
311
+ encoder_output: torch.Tensor,
312
+ generator: Optional[torch.Generator] = None,
313
+ sample_mode: str = 'sample',
314
+ ):
315
+ if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
316
+ return encoder_output.latent_dist.sample(generator)
317
+ elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
318
+ return encoder_output.latent_dist.mode()
319
+ elif hasattr(encoder_output, 'latents'):
320
+ return encoder_output.latents
321
+ else:
322
+ raise AttributeError('Could not access latents of provided encoder_output')
323
+
324
+ vae = self.vae if vae is None else vae
325
+ imgs = 2 * imgs - 1
326
+ latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
327
+ return latents
328
+
329
+ @torch.no_grad()
330
+ def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
331
+ r"""A wrapper function for VAE decoder of the latent diffusion model.
332
+
333
+ Args:
334
+ latents (torch.Tensor): An image latent to get associated images.
335
+ Expected shape: (B, 4, H//8, W//8).
336
+ vae (Optional[nn.Module]): Explicitly specify VAE (used for
337
+ the demo application with TinyVAE).
338
+
339
+ Returns:
340
+ The decoded images in pixel space, scale [0, 1].
342
+ Shape: (B, 3, H, W).
342
+ """
343
+ vae = self.vae if vae is None else vae
344
+ latents = 1 / vae.config.scaling_factor * latents
345
+ imgs = vae.decode(latents).sample
346
+ imgs = (imgs / 2 + 0.5).clip_(0, 1)
347
+ return imgs
348
+
349
+ @torch.no_grad()
350
+ def get_white_background(self, height: int, width: int) -> torch.Tensor:
351
+ r"""White background image latent for bootstrapping or in case of
352
+ absent background.
353
+
354
+ Additionally stores the maximally-sized white latent for fast retrieval
355
+ in the future. By default, we initially call this with 768x768 sized
356
+ white image, so the function is rarely visited twice.
357
+
358
+ Args:
359
+ height (int): The height of the white *image*, not its latent.
360
+ width (int): The width of the white *image*, not its latent.
361
+
362
+ Returns:
363
+ A white image latent of size (1, 4, height//8, width//8). A cropped
364
+ version of the stored white latent is returned if the requested
365
+ size is smaller than what we already have created.
366
+ """
367
+ if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
368
+ white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
369
+ self.white = self.encode_imgs(white)
370
+ return self.white
371
+ return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
372
+
373
+ @torch.no_grad()
374
+ def process_mask(
375
+ self,
376
+ masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
377
+ strength: Optional[Union[torch.Tensor, float]] = None,
378
+ std: Optional[Union[torch.Tensor, float]] = None,
379
+ height: int = 512,
380
+ width: int = 512,
381
+ use_boolean_mask: bool = True,
382
+ timesteps: Optional[torch.Tensor] = None,
383
+ preprocess_mask_cover_alpha: Optional[float] = None,
384
+ ) -> Tuple[torch.Tensor]:
385
+ r"""Fast preprocess of masks for region-based generation with fine-
386
+ grained controls.
387
+
388
+ Mask preprocessing is done in four steps:
389
+ 1. Resizing: Resize the masks into the specified width and height by
390
+ nearest neighbor interpolation.
391
+ 2. (Optional) Ordering: Masks with higher indices are considered to
392
+ cover the masks with smaller indices. Covered masks are decayed
393
+ in its alpha value by the specified factor of
394
+ `preprocess_mask_cover_alpha`.
395
+ 3. Blurring: Gaussian blur is applied to the mask with the specified
396
+ standard deviation (isotropic). This results in gradual increase of
397
+ masked region as the timesteps evolve, naturally blending fore-
398
+ ground and the predesignated background. Not strictly required if
399
+ you want to produce images from scratch without a background.
400
+ 4. Quantization: Split the real-numbered masks of value between [0, 1]
401
+ into predefined noise levels for each quantized scheduling step of
402
+ the diffusion sampler. For example, if the diffusion model sampler
403
+ has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
404
+ is the default noise level of this module with schedule [0, 4, 12,
405
+ 25, 37], the masks are split into binary masks whose values are
406
+ greater than these levels. This results in gradual increase of mask
407
+ region as the timesteps increase. Details are described in our
408
+ paper at https://arxiv.org/pdf/2403.09055.pdf.
409
+
410
+ On the Three Modes of `mask_type`:
411
+ `self.mask_type` is predefined at the initialization stage of this
412
+ pipeline. Three possible modes are available: 'discrete', 'semi-
413
+ continuous', and 'continuous'. These define the mask quantization
414
+ modes we use. Basically, this (subtly) controls the smoothness of
415
+ foreground-background blending. Continuous modes produce nonbinary
416
+ masks to further blend foreground and background latents by linear-
417
+ ly interpolating between them. Semi-continuous masks only applies
418
+ continuous mask at the last step of the LCM sampler. Due to the
419
+ large step size of the LCM scheduler, we find that our continuous
420
+ blending helps generating seamless inpainting and editing results.
421
+
422
+ Args:
423
+ masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
424
+ strength (Optional[Union[torch.Tensor, float]]): Mask strength that
425
+ overrides the default value. A globally multiplied factor to
426
+ the mask at the initial stage of processing. Can be applied
427
+ separately for each mask.
428
+ std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
429
+ kernel's standard deviation. Overrides the default value. Can
430
+ be applied separately for each mask.
431
+ height (int): The height of the expected generation. Mask is
432
+ resized to (height//8, width//8) with nearest neighbor inter-
433
+ polation.
434
+ width (int): The width of the expected generation. Mask is resized
435
+ to (height//8, width//8) with nearest neighbor interpolation.
436
+ use_boolean_mask (bool): Specify this to treat the mask image as
437
+ a boolean tensor. The region darker than 0.5 of
438
+ the maximal pixel value (that is, 127.5) is considered as the
439
+ designated mask.
440
+ timesteps (Optional[torch.Tensor]): Defines the scheduler noise
441
+ levels that act as bins for mask quantization.
442
+ preprocess_mask_cover_alpha (Optional[float]): Optional pre-
443
+ processing where each mask covered by other masks is reduced in
444
+ its alpha value by this specified factor. Overrides the default
445
+ value.
446
+
447
+ Returns: A tuple of tensors.
448
+ - masks: Preprocessed (ordered, blurred, and quantized) binary/non-
449
+ binary masks (see the explanation on `mask_type` above) for
450
+ region-based image synthesis.
451
+ - masks_blurred: Gaussian blurred masks. Used for optionally
452
+ specified foreground-background blending after image
453
+ generation.
454
+ - std: Mask blur standard deviation. Used for optionally specified
455
+ foreground-background blending after image generation.
456
+ """
457
+ if isinstance(masks, Image.Image):
458
+ masks = [masks]
459
+ if isinstance(masks, (tuple, list)):
460
+ # Assumes white background for Image.Image;
461
+ # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
462
+ if use_boolean_mask:
463
+ proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
464
+ else:
465
+ proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
466
+ masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
467
+ masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
468
+ masks = masks.to(self.device)
469
+
470
+ # Background mask alpha is decayed by the specified factor where foreground masks covers it.
471
+ if preprocess_mask_cover_alpha is None:
472
+ preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
473
+ if preprocess_mask_cover_alpha > 0:
474
+ masks = torch.stack([
475
+ torch.where(
476
+ masks[i + 1:].sum(dim=0) > 0,
477
+ mask * preprocess_mask_cover_alpha,
478
+ mask,
479
+ ) if i < len(masks) - 1 else mask
480
+ for i, mask in enumerate(masks)
481
+ ], dim=0)
482
+
483
+ # Scheduler noise levels for mask quantization.
484
+ if timesteps is None:
485
+ noise_lvs = self.noise_lvs
486
+ next_noise_lvs = self.next_noise_lvs
487
+ else:
488
+ noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
489
+ noise_lvs = noise_lvs_[None, :, None, None, None]
490
+ next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
491
+
492
+ # Mask preprocessing parameters are fetched from the default settings.
493
+ if std is None:
494
+ std = self.default_mask_std
495
+ if isinstance(std, (int, float)):
496
+ std = [std] * len(masks)
497
+ if isinstance(std, (list, tuple)):
498
+ std = torch.as_tensor(std, dtype=torch.float, device=self.device)
499
+
500
+ if strength is None:
501
+ strength = self.default_mask_strength
502
+ if isinstance(strength, (int, float)):
503
+ strength = [strength] * len(masks)
504
+ if isinstance(strength, (list, tuple)):
505
+ strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
506
+
507
+ if (std > 0).any():
508
+ std = torch.where(std > 0, std, 1e-5)
509
+ masks = gaussian_lowpass(masks, std)
510
+ masks_blurred = masks
511
+
512
+ # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
513
+ # gives unpleasant results.
514
+ masks = masks * strength[:, None, None, None]
515
+ masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
516
+
517
+ # Mask is quantized according to the current noise levels specified by the scheduler.
518
+ if self.mask_type == 'discrete':
519
+ # Discrete mode.
520
+ masks = masks > noise_lvs
521
+ elif self.mask_type == 'semi-continuous':
522
+ # Semi-continuous mode (continuous at the last step only).
523
+ masks = torch.cat((
524
+ masks[:, :-1] > noise_lvs[:, :-1],
525
+ (
526
+ (masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
527
+ ).clip_(0, 1),
528
+ ), dim=1)
529
+ elif self.mask_type == 'continuous':
530
+ # Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
531
+ # decreases continuously after the discrete mode boundary to become `0` at the
532
+ # next lower threshold.
533
+ masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
534
+
535
+ # NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
536
+ # fine-grained mask alpha channel tuning is available with this form.
537
+ # masks = masks * strength[None, :, None, None, None]
538
+
539
+ h = height // self.vae_scale_factor
540
+ w = width // self.vae_scale_factor
541
+ masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
542
+ masks = F.interpolate(masks, size=(h, w), mode='nearest')
543
+ masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
544
+ return masks, masks_blurred, std
545
+
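The quantization described in the docstring above reduces to a thresholding of the blurred mask against the scheduler noise levels. A minimal, illustrative sketch (the noise levels and the toy mask below are assumptions, not values read from a live pipeline):

import torch

# Assumed 5-step LCM noise levels, as quoted in the docstring above.
noise_lvs = torch.tensor([0.9977, 0.9912, 0.9735, 0.8499, 0.5840])

# A toy blurred mask with values in [0, 1] (1D here only for readability).
mask = torch.tensor([0.999, 0.98, 0.90, 0.70, 0.30])

# 'discrete' mode: one boolean mask per inference step; coverage grows as the
# threshold (noise level) drops over the timesteps.
quantized = mask[None, :] > noise_lvs[:, None]   # (num_steps, num_pixels)
print(quantized.int())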
546
+ def scheduler_step(
547
+ self,
548
+ noise_pred: torch.Tensor,
549
+ idx: int,
550
+ latent: torch.Tensor,
551
+ ) -> torch.Tensor:
552
+ r"""Denoise-only step for reverse diffusion scheduler.
553
+
554
+ Designed to match the interface of the original `pipe.scheduler.step`,
555
+ which is a combination of this method and the following
556
+ `scheduler_add_noise`.
557
+
558
+ Args:
559
+ noise_pred (torch.Tensor): Noise prediction results from the U-Net.
560
+ idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
561
+ for the timesteps tensor (ranged in [0, len(timesteps)-1]).
562
+ latent (torch.Tensor): Noisy latent.
563
+
564
+ Returns:
565
+ A denoised tensor with the same size as latent.
566
+ """
567
+ F_theta = (latent - self.beta_prod_t_sqrt[idx] * noise_pred) / self.alpha_prod_t_sqrt[idx]
568
+ return self.c_out[idx] * F_theta + self.c_skip[idx] * latent
569
+
570
+ def scheduler_add_noise(
571
+ self,
572
+ latent: torch.Tensor,
573
+ noise: Optional[torch.Tensor],
574
+ idx: int,
575
+ ) -> torch.Tensor:
576
+ r"""Separated noise-add step for the reverse diffusion scheduler.
577
+
578
+ Designed to match the interface of the original
579
+ `pipe.scheduler.add_noise`.
580
+
581
+ Args:
582
+ latent (torch.Tensor): Denoised latent.
583
+ noise (torch.Tensor): Added noise. Can be None. If None, a random
584
+ noise is newly sampled for addition.
585
+ idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
586
+ for the timesteps tensor (ranged in [0, len(timesteps)-1]).
587
+
588
+ Returns:
589
+ A noisy tensor with the same size as latent.
590
+ """
591
+ if idx >= len(self.alpha_prod_t_sqrt) or idx < 0:
592
+ # The last step does not require noise addition.
593
+ return latent
594
+ noise = torch.randn_like(latent) if noise is None else noise
595
+ return self.alpha_prod_t_sqrt[idx] * latent + self.beta_prod_t_sqrt[idx] * noise
596
+
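As the two docstrings above note, these methods are a split version of the scheduler's combined step. A sketch of the intended composition inside the sampling loop; `pipe`, `noise_pred`, `latent`, and the step index `i` are assumed names used only for illustration:

# One conventional reverse-diffusion step, split into its two halves:
denoised = pipe.scheduler_step(noise_pred, i, latent)     # noisy latent -> LCM denoised estimate
latent = pipe.scheduler_add_noise(denoised, None, i + 1)  # re-noise to the next level;
                                                          # returns `denoised` unchanged past the last step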
597
+ @torch.no_grad()
598
+ def sample(
599
+ self,
600
+ prompts: Union[str, List[str]],
601
+ negative_prompts: Union[str, List[str]] = '',
602
+ height: int = 512,
603
+ width: int = 512,
604
+ num_inference_steps: Optional[int] = None,
605
+ guidance_scale: Optional[float] = None,
606
+ batch_size: int = 1,
607
+ ) -> Image.Image:
608
+ r"""StableDiffusionPipeline for single-prompt single-tile generation.
609
+
610
+ Minimal Example:
611
+ >>> device = torch.device('cuda:0')
612
+ >>> smd = StableMultiDiffusionPipeline(device)
613
+ >>> image = smd.sample('A photo of the dolomites')[0]
614
+ >>> image.save('my_creation.png')
615
+
616
+ Args:
617
+ prompts (Union[str, List[str]]): A text prompt.
618
+ negative_prompts (Union[str, List[str]]): A negative text prompt.
619
+ height (int): Height of a generated image.
620
+ width (int): Width of a generated image.
621
+ num_inference_steps (Optional[int]): Number of inference steps.
622
+ Default inference scheduling is used if none is specified.
623
+ guidance_scale (Optional[float]): Classifier guidance scale.
624
+ Default value is used if none is specified.
625
+ batch_size (int): Number of images to generate.
626
+
627
+ Returns: A list of PIL.Image images.
628
+ """
629
+ if num_inference_steps is None:
630
+ num_inference_steps = self.default_num_inference_steps
631
+ if guidance_scale is None:
632
+ guidance_scale = self.default_guidance_scale
633
+ self.scheduler.set_timesteps(num_inference_steps)
634
+
635
+ if isinstance(prompts, str):
636
+ prompts = [prompts]
637
+ if isinstance(negative_prompts, str):
638
+ negative_prompts = [negative_prompts]
639
+
640
+ # Calculate text embeddings.
641
+ uncond_embeds, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
642
+ text_embeds = torch.cat([uncond_embeds.mean(dim=0, keepdim=True), text_embeds.mean(dim=0, keepdim=True)])
643
+ h = height // self.vae_scale_factor
644
+ w = width // self.vae_scale_factor
645
+ latent = torch.randn((batch_size, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
646
+
647
+ with torch.autocast('cuda'):
648
+ for i, t in enumerate(tqdm(self.scheduler.timesteps)):
649
+ # Expand the latents if we are doing classifier-free guidance.
650
+ latent_model_input = torch.cat([latent] * 2)
651
+
652
+ # Perform one step of the reverse diffusion.
653
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
654
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
655
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
656
+ latent = self.scheduler.step(noise_pred, t, latent)['prev_sample']
657
+
658
+ # Return PIL Image.
659
+ latent = latent.to(dtype=self.dtype)
660
+ imgs = [T.ToPILImage()(self.decode_latents(l[None])[0]) for l in latent]
661
+ return imgs
662
+
663
+ @torch.no_grad()
664
+ def sample_panorama(
665
+ self,
666
+ prompts: Union[str, List[str]],
667
+ negative_prompts: Union[str, List[str]] = '',
668
+ height: int = 512,
669
+ width: int = 2048,
670
+ num_inference_steps: Optional[int] = None,
671
+ guidance_scale: Optional[float] = None,
672
+ tile_size: Optional[int] = None,
673
+ ) -> Image.Image:
674
+ r"""Large size image generation from a single set of prompts.
675
+
676
+ Minimal Example:
677
+ >>> device = torch.device('cuda:0')
678
+ >>> smd = StableMultiDiffusionPipeline(device)
679
+ >>> image = smd.sample_panorama(
680
+ >>> 'A photo of Alps', height=512, width=3072)
681
+ >>> image.save('my_panorama_creation.png')
682
+
683
+ Args:
684
+ prompts (Union[str, List[str]]): A text prompt.
685
+ negative_prompts (Union[str, List[str]]): A negative text prompt.
686
+ height (int): Height of a generated image. It is tiled if larger
687
+ than `tile_size`.
688
+ width (int): Width of a generated image. It is tiled if larger
689
+ than `tile_size`.
690
+ num_inference_steps (Optional[int]): Number of inference steps.
691
+ Default inference scheduling is used if none is specified.
692
+ guidance_scale (Optional[float]): Classifier guidance scale.
693
+ Default value is used if none is specified.
694
+ tile_size (Optional[int]): Tile size of the panorama generation.
695
+ Works best with the default training size of the Stable-
696
+ Diffusion model, i.e., 512 or 768 for SD1.5 and 1024 for SDXL.
697
+
698
+ Returns: A PIL.Image image of a panorama (large-size) image.
699
+ """
700
+ if num_inference_steps is None:
701
+ num_inference_steps = self.default_num_inference_steps
702
+ self.scheduler.set_timesteps(num_inference_steps)
703
+ timesteps = self.timesteps
704
+ use_custom_timesteps = False
705
+ else:
706
+ self.scheduler.set_timesteps(num_inference_steps)
707
+ timesteps = self.scheduler.timesteps
708
+ use_custom_timesteps = True
709
+ if guidance_scale is None:
710
+ guidance_scale = self.default_guidance_scale
711
+
712
+ if isinstance(prompts, str):
713
+ prompts = [prompts]
714
+ if isinstance(negative_prompts, str):
715
+ negative_prompts = [negative_prompts]
716
+
717
+ # Calculate text embeddings.
718
+ uncond_embeds, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
719
+ text_embeds = torch.cat([uncond_embeds.mean(dim=0, keepdim=True), text_embeds.mean(dim=0, keepdim=True)])
720
+
721
+ # Define panorama grid and get views
722
+ h = height // self.vae_scale_factor
723
+ w = width // self.vae_scale_factor
724
+ latent = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
725
+
726
+ if tile_size is None:
727
+ tile_size = min(min(height, width), 512)
728
+ views, masks = get_panorama_views(h, w, tile_size // self.vae_scale_factor)
729
+ masks = masks.to(dtype=self.dtype, device=self.device)
730
+ value = torch.zeros_like(latent)
731
+ with torch.autocast('cuda'):
732
+ for i, t in enumerate(tqdm(timesteps)):
733
+ value.zero_()
734
+
735
+ for j, (h_start, h_end, w_start, w_end) in enumerate(views):
736
+ # TODO we can support batches, and pass multiple views at once to the unet
737
+ latent_view = latent[:, :, h_start:h_end, w_start:w_end]
738
+
739
+ # Expand the latents if we are doing classifier-free guidance.
740
+ latent_model_input = torch.cat([latent_view] * 2)
741
+
742
+ # Perform one step of the reverse diffusion.
743
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
744
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
745
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
746
+
747
+ # Compute the denoising step.
748
+ latents_view_denoised = self.scheduler_step(noise_pred, i, latent_view) # (1, 4, h, w)
749
+ mask = masks[..., j:j + 1, h_start:h_end, w_start:w_end] # (1, 1, h, w)
750
+ value[..., h_start:h_end, w_start:w_end] += mask * latents_view_denoised # (1, 1, h, w)
751
+
752
+ # Update denoised latent.
753
+ latent = value.clone()
754
+
755
+ if i < len(timesteps) - 1:
756
+ latent = self.scheduler_add_noise(latent, None, i + 1)
757
+
758
+ # Return PIL Image.
759
+ imgs = self.decode_latents(latent)
760
+ img = T.ToPILImage()(imgs[0].cpu())
761
+ return img
762
+
763
+ @torch.no_grad()
764
+ def __call__(
765
+ self,
766
+ prompts: Optional[Union[str, List[str]]] = None,
767
+ negative_prompts: Union[str, List[str]] = '',
768
+ suffix: Optional[str] = None, #', background is ',
769
+ background: Optional[Union[torch.Tensor, Image.Image]] = None,
770
+ background_prompt: Optional[str] = None,
771
+ background_negative_prompt: str = '',
772
+ height: int = 512,
773
+ width: int = 512,
774
+ num_inference_steps: Optional[int] = None,
775
+ guidance_scale: Optional[float] = None,
776
+ prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
777
+ masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
778
+ mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
779
+ mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
780
+ use_boolean_mask: bool = True,
781
+ do_blend: bool = True,
782
+ tile_size: int = 768,
783
+ bootstrap_steps: Optional[int] = None,
784
+ boostrap_mix_steps: Optional[float] = None,
785
+ bootstrap_leak_sensitivity: Optional[float] = None,
786
+ preprocess_mask_cover_alpha: Optional[float] = None,
787
+ ) -> Image.Image:
788
+ r"""Arbitrary-size image generation from multiple pairs of (regional)
789
+ text prompt-mask pairs.
790
+
791
+ This is a main routine for this pipeline.
792
+
793
+ Example:
794
+ >>> device = torch.device('cuda:0')
795
+ >>> smd = StableMultiDiffusionPipeline(device)
796
+ >>> prompts = {... specify prompts}
797
+ >>> masks = {... specify mask tensors}
798
+ >>> height, width = masks.shape[-2:]
799
+ >>> image = smd(
800
+ >>> prompts, masks=masks.float(), height=height, width=width)
801
+ >>> image.save('my_beautiful_creation.png')
802
+
803
+ Args:
804
+ prompts (Union[str, List[str]]): A text prompt.
805
+ negative_prompts (Union[str, List[str]]): A negative text prompt.
806
+ suffix (Optional[str]): One option for blending foreground prompts
807
+ with background prompts by simply appending background prompt
808
+ to the end of each foreground prompt with this `middle word` in
809
+ between. For example, if you set this as `, background is`,
810
+ then the foreground prompt will be changed into
811
+ `(fg), background is (bg)` before conditional generation.
812
+ background (Optional[Union[torch.Tensor, Image.Image]]): a
813
+ background image, if the user wants to draw in front of the
814
+ specified image. The background prompt will be automatically generated
815
+ with a BLIP-2 model.
816
+ background_prompt (Optional[str]): The background prompt is used
817
+ for preprocessing foreground prompt embeddings to blend
818
+ foreground and background.
819
+ background_negative_prompt (Optional[str]): The negative background
820
+ prompt.
821
+ height (int): Height of a generated image. It is tiled if larger
822
+ than `tile_size`.
823
+ width (int): Width of a generated image. It is tiled if larger
824
+ than `tile_size`.
825
+ num_inference_steps (Optional[int]): Number of inference steps.
826
+ Default inference scheduling is used if none is specified.
827
+ guidance_scale (Optional[float]): Classifier guidance scale.
828
+ Default value is used if none is specified.
829
+ prompt_strengths (float): Overrides the default value. Preprocess
830
+ foreground prompts globally by linearly interpolating their
831
+ embeddings with the background prompt embedding with the specified
832
+ mix ratio. Useful control handle for foreground blending.
833
+ Recommended range: 0.5-1.
834
+ masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
835
+ mask images. Each mask is associated with one of the text prompts
836
+ and each of the negative prompts. If specified as an image, it
837
+ regards the image as a boolean mask. Also accepts torch.Tensor
838
+ masks, which can have nonbinary values for fine-grained
839
+ controls in mixing regional generations.
840
+ mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
841
+ Overrides the default value. Can be assigned for each mask
842
+ separately. Preprocess mask by multiplying it globally with the
843
+ specified variable. Caution: extremely sensitive. Recommended
844
+ range: 0.98-1.
845
+ mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
846
+ Overrides the default value. Can be assigned for each mask
847
+ separately. Preprocess mask with Gaussian blur with specified
848
+ standard deviation. Recommended range: 0-64.
849
+ use_boolean_mask (bool): Turn this off if you want to treat the
850
+ mask image as a nonbinary one. The module will use the last
851
+ channel of the given image in `masks` as the mask value.
852
+ do_blend (bool): Blend the generated foreground and the optionally
853
+ predefined background by smooth boundary obtained from Gaussian
854
+ blurs of the foreground `masks` with the given `mask_stds`.
855
+ tile_size (Optional[int]): Tile size of the panorama generation.
856
+ Works best with the default training size of the Stable-
857
+ Diffusion model, i.e., 512 or 768 for SD1.5 and 1024 for SDXL.
858
+ bootstrap_steps (int): Overrides the default value. Bootstrapping
859
+ stage steps to encourage region separation. Recommended range:
860
+ 1-3.
861
+ boostrap_mix_steps (float): Overrides the default value.
862
+ Bootstrapping background is a linear interpolation between
863
+ background latent and the white image latent. This handle
864
+ controls the mix ratio. Available range: 0-(number of
865
+ bootstrapping inference steps). For example, 2.3 means that for
866
+ the first two steps, white image is used as a bootstrapping
867
+ background and in the third step, mixture of white (0.3) and
868
+ registered background (0.7) is used as a bootstrapping
869
+ background.
870
+ bootstrap_leak_sensitivity (float): Overrides the default value.
871
+ Postprocessing at each inference step by masking away the
872
+ remaining bootstrap backgrounds. Recommended range: 0-1.
873
+ preprocess_mask_cover_alpha (float): Overrides the default value.
874
+ Optional preprocessing where each mask covered by other masks
875
+ is reduced in its alpha value by this specified factor.
876
+
877
+ Returns: A PIL.Image of the generated image.
878
+ """
879
+
880
+ ### Simplest cases
881
+
882
+ # prompts is None: return background.
883
+ # masks is None but prompts is not None: return prompts
884
+ # masks is not None and prompts is not None: Do StableMultiDiffusion.
885
+
886
+ if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
887
+ if background is None and background_prompt is not None:
888
+ return self.sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
889
+ return background
890
+ elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
891
+ return self.sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
892
+
893
+
894
+ ### Prepare generation
895
+
896
+ if num_inference_steps is not None:
897
+ self.prepare_lcm_schedule(list(range(num_inference_steps)), num_inference_steps)
898
+
899
+ if guidance_scale is None:
900
+ guidance_scale = self.default_guidance_scale
901
+
902
+
903
+ ### Prompts & Masks
904
+
905
+ # asserts #m > 0 and #p > 0.
906
+ # #m == #p == #n > 0: We happily generate according to the prompts & masks.
907
+ # #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
908
+ # #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
909
+
910
+ if isinstance(masks, Image.Image):
911
+ masks = [masks]
912
+ if isinstance(prompts, str):
913
+ prompts = [prompts]
914
+ if isinstance(negative_prompts, str):
915
+ negative_prompts = [negative_prompts]
916
+ num_masks = len(masks)
917
+ num_prompts = len(prompts)
918
+ num_nprompts = len(negative_prompts)
919
+ assert num_prompts in (num_masks, 1), \
920
+ f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
921
+ assert num_nprompts in (num_prompts, 1), \
922
+ f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'
923
+
924
+ fg_masks, masks_g, std = self.process_mask(
925
+ masks,
926
+ mask_strengths,
927
+ mask_stds,
928
+ height=height,
929
+ width=width,
930
+ use_boolean_mask=use_boolean_mask,
931
+ timesteps=self.timesteps,
932
+ preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
933
+ ) # (p, t, 1, H, W)
934
+ bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1) # (T, 1, h, w)
935
+ has_background = bg_masks.sum() > 0
936
+
937
+ h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
938
+ w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor
939
+
940
+
941
+ ### Background
942
+
943
+ # background == None && background_prompt == None: Initialize with white background.
944
+ # background == None && background_prompt != None: Generate background *along with other prompts*.
945
+ # background != None && background_prompt == None: Retrieve text prompt using BLIP.
946
+ # background != None && background_prompt != None: Use the given arguments.
947
+
948
+ # not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
949
+ # has_background && prompt_strength != 1: mix only for this case.
950
+
951
+ bg_latent = None
952
+ if has_background:
953
+ if background is None and background_prompt is not None:
954
+ fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
955
+ if suffix is not None:
956
+ prompts = [p + suffix + background_prompt for p in prompts]
957
+ prompts = [background_prompt] + prompts
958
+ negative_prompts = [background_negative_prompt] + negative_prompts
959
+ has_background = False # Regard that background does not exist.
960
+ else:
961
+ if background is None and background_prompt is None:
962
+ background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
963
+ background_prompt = 'simple white background image'
964
+ elif background is not None and background_prompt is None:
965
+ background_prompt = self.get_text_prompts(background)
966
+ if suffix is not None:
967
+ prompts = [p + suffix + background_prompt for p in prompts]
968
+ prompts = [background_prompt] + prompts
969
+ negative_prompts = [background_negative_prompt] + negative_prompts
970
+ if isinstance(background, Image.Image):
971
+ background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
972
+ background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
973
+ bg_latent = self.encode_imgs(background)
974
+
975
+ # Bootstrapping stage preparation.
976
+
977
+ if bootstrap_steps is None:
978
+ bootstrap_steps = self.default_bootstrap_steps
979
+ if boostrap_mix_steps is None:
980
+ boostrap_mix_steps = self.default_boostrap_mix_steps
981
+ if bootstrap_leak_sensitivity is None:
982
+ bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
983
+ if bootstrap_steps > 0:
984
+ height_ = min(height, tile_size)
985
+ width_ = min(width, tile_size)
986
+ white = self.get_white_background(height, width) # (1, 4, h, w)
987
+
988
+
989
+ ### Prepare text embeddings (optimized for the minimal encoder batch size)
990
+
991
+ uncond_embeds, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2 * len(prompts), 77, 768]
992
+ if has_background:
993
+ # First channel is background prompt text embeds. Background prompt itself is not used for generation.
994
+ s = prompt_strengths
995
+ if prompt_strengths is None:
996
+ s = self.default_prompt_strength
997
+ if isinstance(s, (int, float)):
998
+ s = [s] * num_prompts
999
+ if isinstance(s, (list, tuple)):
1000
+ assert len(s) == num_prompts, \
1001
+ f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
1002
+ s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
1003
+ s = s[:, None, None]
1004
+
1005
+ be = text_embeds[:1]
1006
+ bu = uncond_embeds[:1]
1007
+ fe = text_embeds[1:]
1008
+ fu = uncond_embeds[1:]
1009
+ if num_prompts > num_nprompts:
1010
+ # # negative prompts = 1; # prompts > 1.
1011
+ assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
1012
+ fu = fu.repeat(num_prompts, 1, 1)
1013
+ text_embeds = torch.lerp(be, fe, s) # (p, 77, 768)
1014
+ uncond_embeds = torch.lerp(bu, fu, s) # (n, 77, 768)
1015
+ elif num_prompts > num_nprompts:
1016
+ # # negative prompts = 1; # prompts > 1.
1017
+ assert uncond_embeds.shape[0] == 1 and text_embeds.shape[0] == num_prompts
1018
+ uncond_embeds = uncond_embeds.repeat(num_prompts, 1, 1)
1019
+ assert uncond_embeds.shape[0] == text_embeds.shape[0] == num_prompts
1020
+ if num_masks > num_prompts:
1021
+ assert masks.shape[0] == num_masks and num_prompts == 1
1022
+ text_embeds = text_embeds.repeat(num_masks, 1, 1)
1023
+ uncond_embeds = uncond_embeds.repeat(num_masks, 1, 1)
1024
+ text_embeds = torch.cat([uncond_embeds, text_embeds])
1025
+
1026
+
1027
+ ### Run
1028
+
1029
+ # Latent initialization.
1030
+ if self.timesteps[0] < 999 and has_background:
1031
+ latent = self.scheduler_add_noise(bg_latent, None, 0)
1032
+ else:
1033
+ latent = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
1034
+
1035
+ # Tiling (if needed).
1036
+ if height > tile_size or width > tile_size:
1037
+ t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
1038
+ views, tile_masks = get_panorama_views(h, w, t)
1039
+ tile_masks = tile_masks.to(self.device)
1040
+ else:
1041
+ views = [(0, h, 0, w)]
1042
+ tile_masks = latent.new_ones((1, 1, h, w))
1043
+ value = torch.zeros_like(latent)
1044
+ count_all = torch.zeros_like(latent)
1045
+
1046
+ with torch.autocast('cuda'):
1047
+ for i, t in enumerate(tqdm(self.timesteps)):
1048
+ fg_mask = fg_masks[:, i]
1049
+ bg_mask = bg_masks[i:i + 1]
1050
+
1051
+ value.zero_()
1052
+ count_all.zero_()
1053
+ for j, (h_start, h_end, w_start, w_end) in enumerate(views):
1054
+ fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
1055
+ latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
1056
+
1057
+ # Bootstrap for tight background.
1058
+ if i < bootstrap_steps:
1059
+ mix_ratio = min(1, max(0, boostrap_mix_steps - i))
1060
+ # Treat the first foreground latent as the background latent if one does not exist.
1061
+ bg_latent_ = bg_latent[..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
1062
+ white_ = white[..., h_start:h_end, w_start:w_end]
1063
+ bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
1064
+ bg_latent_ = self.scheduler_add_noise(bg_latent_, None, i)
1065
+ latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
1066
+
1067
+ # Centering.
1068
+ latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
1069
+
1070
+ # Perform one step of the reverse diffusion.
1071
+ noise_pred = self.unet(torch.cat([latent_] * 2), t, encoder_hidden_states=text_embeds)['sample']
1072
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
1073
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
1074
+ latent_ = self.scheduler_step(noise_pred, i, latent_)
1075
+
1076
+ if i < bootstrap_steps:
1077
+ # Uncentering.
1078
+ latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
1079
+
1080
+ # Remove leakage (optional).
1081
+ leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
1082
+ leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
1083
+ fg_mask_ = fg_mask_ * leak_sigmoid
1084
+
1085
+ # Mix the latents.
1086
+ fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
1087
+ value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
1088
+ count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
1089
+
1090
+ latent = torch.where(count_all > 0, value / count_all, value)
1091
+ bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
1092
+ if has_background:
1093
+ latent = (1 - bg_mask) * latent + bg_mask * bg_latent
1094
+
1095
+ # Noise is added after mixing.
1096
+ if i < len(self.timesteps) - 1:
1097
+ latent = self.scheduler_add_noise(latent, None, i + 1)
1098
+
1099
+ # Return PIL Image.
1100
+ image = self.decode_latents(latent.to(dtype=self.dtype))[0]
1101
+ if has_background and do_blend:
1102
+ fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
1103
+ image = blend(image, background[0], fg_mask)
1104
+ else:
1105
+ image = T.ToPILImage()(image)
1106
+ return image
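A hedged usage sketch of the region-based `__call__` above; the mask file names and the prompt set are assumptions made only for illustration:

import torch
from PIL import Image

device = torch.device('cuda:0')
smd = StableMultiDiffusionPipeline(device)

# Two hand-drawn boolean masks (dark strokes mark each region).
masks = [Image.open('mask_sky.png'), Image.open('mask_castle.png')]
image = smd(
    prompts=['clear blue sky', 'a medieval castle on a hill'],
    negative_prompts=['worst quality', 'worst quality'],
    background_prompt='a watercolor landscape painting',
    masks=masks,
    mask_stds=16.0,   # Gaussian feathering of the region boundaries
    height=512,
    width=768,
)
image.save('regional_generation.png')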
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch==2.0.1
2
+ torchvision
3
+ xformers==0.0.22
4
+ einops
5
+ diffusers
6
+ transformers
7
+ huggingface_hub[torch]
8
+ gradio
9
+ Pillow
10
+ emoji
11
+ numpy
12
+ tqdm
13
+ jupyterlab
14
+ spaces
util.py ADDED
@@ -0,0 +1,289 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import concurrent.futures
22
+ import time
23
+ from typing import Any, Callable, List, Tuple, Union
24
+
25
+ from PIL import Image
26
+ import numpy as np
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.cuda.amp as amp
31
+ import torchvision.transforms as T
32
+ import torchvision.transforms.functional as TF
33
+
34
+
35
+ def seed_everything(seed: int) -> None:
36
+ torch.manual_seed(seed)
37
+ torch.cuda.manual_seed(seed)
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.backends.cudnn.benchmark = True
40
+
41
+
42
+ def get_cutoff(cutoff: float = None, scale: float = None) -> float:
43
+ if cutoff is not None:
44
+ return cutoff
45
+
46
+ if scale is not None and cutoff is None:
47
+ return 0.5 / scale
48
+
49
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
50
+
51
+
52
+ def get_scale(cutoff: float = None, scale: float = None) -> float:
53
+ if scale is not None:
54
+ return scale
55
+
56
+ if cutoff is not None and scale is None:
57
+ return 0.5 / cutoff
58
+
59
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
60
+
61
+
62
+ def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
63
+ assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
64
+ # assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
65
+
66
+ b, c, h, w = x.shape
67
+ ks = k.shape[-1]
68
+ k = k.view(1, 1, -1).repeat(c, 1, 1)
69
+
70
+ x = x.permute(0, 2, 1, 3)
71
+ x = x.reshape(b * h, c, w)
72
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
73
+ x = F.conv1d(x, k, groups=c)
74
+ x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
75
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
76
+ x = F.conv1d(x, k, groups=c)
77
+ x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
78
+ return x
79
+
80
+
81
+ def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
82
+ assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
83
+
84
+ x = F.pad(x, (
85
+ k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
86
+ k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
87
+ ), mode='replicate')
88
+
89
+ b, c, _, _ = x.shape
90
+ if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
91
+ k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
92
+ x = F.conv2d(x, k, groups=c)
93
+ elif len(k.shape) == 3:
94
+ assert k.shape[0] == b, \
95
+ 'The number of kernels should match the batch size.'
96
+
97
+ k = k.unsqueeze(1)
98
+ x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
99
+ return x
100
+
101
+
102
+ @amp.autocast(False)
103
+ def filter_by_kernel(
104
+ x: torch.Tensor,
105
+ k: torch.Tensor,
106
+ is_batch: bool = False,
107
+ ) -> torch.Tensor:
108
+ k_dim = len(k.shape)
109
+ if k_dim == 1 or k_dim == 2 and is_batch:
110
+ return filter_2d_by_kernel_1d(x, k)
111
+ elif k_dim == 2 or k_dim == 3 and is_batch:
112
+ return filter_2d_by_kernel_2d(x, k)
113
+ else:
114
+ raise ValueError('Kernel size should be one of (1, 2, 3).')
115
+
116
+
117
+ def gen_gauss_lowpass_filter_2d(
118
+ std: torch.Tensor,
119
+ window_size: int = None,
120
+ ) -> torch.Tensor:
121
+ # Gaussian kernel size is odd in order to preserve the center.
122
+ if window_size is None:
123
+ window_size = (
124
+ 2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
125
+
126
+ y = torch.arange(
127
+ window_size, dtype=std.dtype, device=std.device
128
+ ).view(-1, 1).repeat(1, window_size)
129
+ grid = torch.stack((y.t(), y), dim=-1)
130
+ grid -= 0.5 * (window_size - 1) # (W, W)
131
+ var = (std * std).unsqueeze(-1).unsqueeze(-1)
132
+ distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
133
+ k = torch.exp(-0.5 * distsq / var)
134
+ k /= k.sum(dim=(-2, -1), keepdim=True)
135
+ return k
136
+
137
+
138
+ def gaussian_lowpass(
139
+ x: torch.Tensor,
140
+ std: Union[float, Tuple[float], torch.Tensor] = None,
141
+ cutoff: Union[float, torch.Tensor] = None,
142
+ scale: Union[float, torch.Tensor] = None,
143
+ ) -> torch.Tensor:
144
+ if std is None:
145
+ cutoff = get_cutoff(cutoff, scale)
146
+ std = 0.5 / (np.pi * cutoff)
147
+ if isinstance(std, (float, int)):
148
+ std = (std, std)
149
+ if isinstance(std, torch.Tensor):
150
+ """Using nn.functional.conv2d with Gaussian kernels built in runtime is
151
+ 80% faster than transforms.functional.gaussian_blur for individual
152
+ items.
153
+
154
+ (in GPU); However, in CPU, the result is exactly opposite. But you
155
+ won't gonna run this on CPU, right?
156
+ """
157
+ if len(list(s for s in std.shape if s != 1)) >= 2:
158
+ raise NotImplementedError(
159
+ 'Anisotropic Gaussian filter is not currently available.')
160
+
161
+ # k.shape == (B, W, W).
162
+ k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
163
+ if k.shape[0] == 1:
164
+ return filter_by_kernel(x, k[0], False)
165
+ else:
166
+ return filter_by_kernel(x, k, True)
167
+ else:
168
+ # Gaussian kernel size is odd in order to preserve the center.
169
+ window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
170
+ return TF.gaussian_blur(x, window_size, std)
171
+
172
+
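A small sketch of the per-item blurring path discussed in the comment above, where a 1D tensor of standard deviations blurs each batch item with its own kernel; the shapes are illustrative:

import torch
from util import gaussian_lowpass

# Four single-channel masks, each blurred with a different standard deviation.
masks = torch.rand(4, 1, 64, 64)
std = torch.tensor([2.0, 4.0, 8.0, 16.0])
blurred = gaussian_lowpass(masks, std)   # same shape as `masks`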
173
+ def blend(
174
+ fg: Union[torch.Tensor, Image.Image],
175
+ bg: Union[torch.Tensor, Image.Image],
176
+ mask: Union[torch.Tensor, Image.Image],
177
+ std: float = 0.0,
178
+ ) -> Image.Image:
179
+ if not isinstance(fg, torch.Tensor):
180
+ fg = T.ToTensor()(fg)
181
+ if not isinstance(bg, torch.Tensor):
182
+ bg = T.ToTensor()(bg)
183
+ if not isinstance(mask, torch.Tensor):
184
+ mask = (T.ToTensor()(mask) < 0.5).float()[:1]
185
+ if std > 0:
186
+ mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
187
+ return T.ToPILImage()(fg * mask + bg * (1 - mask))
188
+
189
+
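A brief usage sketch of `blend` above; the file names are assumptions, and the three images are expected to share the same size:

from PIL import Image
from util import blend

fg = Image.open('foreground.png')    # generated foreground
bg = Image.open('background.png')    # original background
mask = Image.open('mask.png')        # dark pixels (< 50% gray) mark the foreground region
out = blend(fg, bg, mask, std=8.0)   # std > 0 feathers the boundary
out.save('blended.png')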
190
+ def get_panorama_views(
191
+ panorama_height: int,
192
+ panorama_width: int,
193
+ window_size: int = 64,
194
+ ) -> tuple[List[Tuple[int]], torch.Tensor]:
195
+ stride = window_size // 2
196
+ is_horizontal = panorama_width > panorama_height
197
+ num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
198
+ num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
199
+ total_num_blocks = num_blocks_height * num_blocks_width
200
+
201
+ half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
202
+ half_rev = half_fwd.flip(0)
203
+ if window_size % 2 == 1:
204
+ half_rev = half_rev[1:]
205
+ c = torch.cat((half_fwd, half_rev))
206
+ one = torch.ones_like(c)
207
+ f = c.clone()
208
+ f[:window_size // 2] = 1
209
+ b = c.clone()
210
+ b[-(window_size // 2):] = 1
211
+
212
+ h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
213
+ w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
214
+
215
+ views = []
216
+ masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
217
+ for i in range(total_num_blocks):
218
+ hi, wi = i // num_blocks_width, i % num_blocks_width
219
+ h_start = hi * stride
220
+ h_end = min(h_start + window_size, panorama_height)
221
+ w_start = wi * stride
222
+ w_end = min(w_start + window_size, panorama_width)
223
+ views.append((h_start, h_end, w_start, w_end))
224
+
225
+ h_width = h_end - h_start
226
+ w_width = w_end - w_start
227
+ masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
228
+
229
+ # Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
230
+ return views, masks[None] # (1, n, h, w)
231
+
232
+
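A quick sketch checking the partition-of-unity property stated in the comment above, using illustrative latent-space dimensions:

import torch
from util import get_panorama_views

views, masks = get_panorama_views(64, 256, window_size=64)   # latent-space h=64, w=256
# The per-view weights must sum to one at every pixel.
assert torch.allclose(masks.sum(dim=1), torch.ones(1, 64, 256), atol=1e-5)
print(len(views), 'overlapping views')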
233
+ def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> torch.Tensor:
234
+ h, w = mask.shape[-2:]
235
+ device = mask.device
236
+ mask = mask.reshape(-1, h, w)
237
+ # assert mask.shape[0] == im.shape[0]
238
+ h_occupied = mask.sum(dim=-2) > 0
239
+ w_occupied = mask.sum(dim=-1) > 0
240
+ l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
241
+ r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
242
+ t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
243
+ b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
244
+ tb = (t + b + 1) // 2
245
+ lr = (l + r + 1) // 2
246
+ shifts = (tb - (h // 2), lr - (w // 2))
247
+ shifts = torch.cat(shifts, dim=1) # (p, 2)
248
+ if reverse:
249
+ shifts = shifts * -1
250
+ return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
251
+
252
+
253
+ class Streamer:
254
+ def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
255
+ self.fn = fn
256
+ self.ema_alpha = ema_alpha
257
+
258
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
259
+ self.future = self.executor.submit(fn)
260
+ self.image = None
261
+
262
+ self.prev_exec_time = 0
263
+ self.ema_exec_time = 0
264
+
265
+ @property
266
+ def throughput(self) -> float:
267
+ return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
268
+
269
+ def timed_fn(self) -> Any:
270
+ start = time.time()
271
+ res = self.fn()
272
+ end = time.time()
273
+ self.prev_exec_time = end - start
274
+ self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
275
+ return res
276
+
277
+ def __call__(self) -> Any:
278
+ if self.future.done() or self.image is None:
279
+ # get the result (the new image) and start a new task
280
+ image = self.future.result()
281
+ self.future = self.executor.submit(self.timed_fn)
282
+ self.image = image
283
+ return image
284
+ else:
285
+ # if self.fn() is not ready yet, use the previous image
286
+ # NOTE: This assumes that we have access to a previously generated image here.
287
+ # If there's no previous image (i.e., this is the first invocation), you could fall
288
+ # back to some default image or handle it differently based on your requirements.
289
+ return self.image
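A minimal sketch of driving the `Streamer` above from a UI refresh loop; the `generate` function below is a stand-in assumption for one pass of the pipeline:

import time
from util import Streamer

def generate():
    time.sleep(0.5)          # stand-in for one sampling pass
    return 'frame'

streamer = Streamer(generate)
for _ in range(10):
    frame = streamer()       # returns the latest finished result without waiting for a new one
    print(frame, f'{streamer.throughput:.2f} it/s')
    time.sleep(0.1)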