Spaces:
Runtime error
Runtime error
added first version
Browse files- .gitignore +3 -0
- README.md +7 -4
- app.py +819 -0
- examples/prompt_background.txt +8 -0
- examples/prompt_background_advanced.txt +0 -0
- examples/prompt_boy.txt +15 -0
- examples/prompt_girl.txt +16 -0
- examples/prompt_props.txt +43 -0
- model.py +1106 -0
- requirements.txt +14 -0
- util.py +289 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.*.sw*
|
3 |
+
.ipynb_checkpoints/
|
README.md
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
---
|
2 |
title: SemanticPalette
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.21.0
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: SemanticPalette
|
3 |
+
emoji: 🧠🎨
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.21.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: mit
|
11 |
+
suggested_hardware: t4-small
|
12 |
+
suggested_storage: small
|
13 |
+
models: ironjr/BlazingDriveV11m
|
14 |
---
|
15 |
|
16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
import argparse
|
22 |
+
import random
|
23 |
+
import time
|
24 |
+
import json
|
25 |
+
import os
|
26 |
+
import glob
|
27 |
+
import pathlib
|
28 |
+
from functools import partial
|
29 |
+
from pprint import pprint
|
30 |
+
|
31 |
+
import numpy as np
|
32 |
+
from PIL import Image
|
33 |
+
import torch
|
34 |
+
|
35 |
+
import spaces
|
36 |
+
import gradio as gr
|
37 |
+
from huggingface_hub import snapshot_download
|
38 |
+
|
39 |
+
from model import StableMultiDiffusionPipeline
|
40 |
+
from util import seed_everything
|
41 |
+
|
42 |
+
|
43 |
+
### Utils
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
def log_state(state):
|
49 |
+
pprint(vars(opt))
|
50 |
+
if isinstance(state, gr.State):
|
51 |
+
state = state.value
|
52 |
+
pprint(vars(state))
|
53 |
+
|
54 |
+
|
55 |
+
def is_empty_image(im: Image.Image) -> bool:
|
56 |
+
if im is None:
|
57 |
+
return True
|
58 |
+
im = np.array(im)
|
59 |
+
has_alpha = (im.shape[2] == 4)
|
60 |
+
if not has_alpha:
|
61 |
+
return False
|
62 |
+
elif im.sum() == 0:
|
63 |
+
return True
|
64 |
+
else:
|
65 |
+
return False
|
66 |
+
|
67 |
+
|
68 |
+
### Argument passing
|
69 |
+
|
70 |
+
parser = argparse.ArgumentParser(description='Semantic drawing demo powered by StreamMultiDiffusion.')
|
71 |
+
parser.add_argument('-H', '--height', type=int, default=768)
|
72 |
+
parser.add_argument('-W', '--width', type=int, default=1920)
|
73 |
+
parser.add_argument('--model', type=str, default=None)
|
74 |
+
parser.add_argument('--bootstrap_steps', type=int, default=1)
|
75 |
+
parser.add_argument('--seed', type=int, default=-1)
|
76 |
+
parser.add_argument('--device', type=int, default=0)
|
77 |
+
parser.add_argument('--port', type=int, default=8000)
|
78 |
+
opt = parser.parse_args()
|
79 |
+
|
80 |
+
|
81 |
+
### Global variables and data structures
|
82 |
+
|
83 |
+
device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
|
84 |
+
|
85 |
+
|
86 |
+
model_dict = {
|
87 |
+
'Blazing Drive V11m': 'ironjr/BlazingDriveV11m',
|
88 |
+
'Real Cartoon Pixar V5': 'ironjr/RealCartoon-PixarV5',
|
89 |
+
'Kohaku V2.1': 'KBlueLeaf/kohaku-v2.1',
|
90 |
+
'Realistic Vision V5.1': 'ironjr/RealisticVisionV5-1',
|
91 |
+
'Stable Diffusion V1.5': 'runwayml/stable-diffusion-v1-5',
|
92 |
+
}
|
93 |
+
|
94 |
+
models = {
|
95 |
+
k: StableMultiDiffusionPipeline(device, sd_version='1.5', hf_key=v)
|
96 |
+
for k, v in model_dict.items()
|
97 |
+
}
|
98 |
+
|
99 |
+
|
100 |
+
prompt_suggestions = [
|
101 |
+
'1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
|
102 |
+
'1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
|
103 |
+
'1girl, arima kana, oshi no ko, solo, upper body, from behind',
|
104 |
+
]
|
105 |
+
|
106 |
+
opt.max_palettes = 5
|
107 |
+
opt.default_prompt_strength = 1.0
|
108 |
+
opt.default_mask_strength = 1.0
|
109 |
+
opt.default_mask_std = 0.0
|
110 |
+
opt.default_negative_prompt = (
|
111 |
+
'nsfw, worst quality, bad quality, normal quality, cropped, framed'
|
112 |
+
)
|
113 |
+
opt.verbose = True
|
114 |
+
opt.colors = [
|
115 |
+
'#000000',
|
116 |
+
'#2692F3',
|
117 |
+
'#F89E12',
|
118 |
+
'#16C232',
|
119 |
+
'#F92F6C',
|
120 |
+
'#AC6AEB',
|
121 |
+
# '#92C62C',
|
122 |
+
# '#92C6EC',
|
123 |
+
# '#FECAC0',
|
124 |
+
]
|
125 |
+
|
126 |
+
|
127 |
+
### Event handlers
|
128 |
+
|
129 |
+
def add_palette(state):
|
130 |
+
old_actives = state.active_palettes
|
131 |
+
state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)
|
132 |
+
|
133 |
+
if opt.verbose:
|
134 |
+
log_state(state)
|
135 |
+
|
136 |
+
if state.active_palettes != old_actives:
|
137 |
+
return [state] + [
|
138 |
+
gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
|
139 |
+
] + [
|
140 |
+
gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
|
141 |
+
for i in range(opt.max_palettes)
|
142 |
+
]
|
143 |
+
else:
|
144 |
+
return [state] + [gr.update() for i in range(opt.max_palettes + 1)]
|
145 |
+
|
146 |
+
|
147 |
+
def select_palette(state, button, idx):
|
148 |
+
if idx < 0 or idx > opt.max_palettes:
|
149 |
+
idx = 0
|
150 |
+
old_idx = state.current_palette
|
151 |
+
if old_idx == idx:
|
152 |
+
return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]
|
153 |
+
|
154 |
+
state.current_palette = idx
|
155 |
+
|
156 |
+
if opt.verbose:
|
157 |
+
log_state(state)
|
158 |
+
|
159 |
+
updates = [state] + [
|
160 |
+
gr.update() if i not in (idx, old_idx) else
|
161 |
+
gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
|
162 |
+
for i in range(opt.max_palettes + 1)
|
163 |
+
]
|
164 |
+
label = 'Background' if idx == 0 else f'Palette {idx}'
|
165 |
+
updates.extend([
|
166 |
+
gr.update(value=button, interactive=(idx > 0)),
|
167 |
+
gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
|
168 |
+
gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
|
169 |
+
(
|
170 |
+
gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
|
171 |
+
gr.update(value=opt.default_mask_strength, interactive=False)
|
172 |
+
),
|
173 |
+
(
|
174 |
+
gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
|
175 |
+
gr.update(value=opt.default_prompt_strength, interactive=False)
|
176 |
+
),
|
177 |
+
(
|
178 |
+
gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
|
179 |
+
gr.update(value=opt.default_mask_std, interactive=False)
|
180 |
+
),
|
181 |
+
])
|
182 |
+
return updates
|
183 |
+
|
184 |
+
|
185 |
+
def change_prompt_strength(state, strength):
|
186 |
+
if state.current_palette == 0:
|
187 |
+
return state
|
188 |
+
|
189 |
+
state.prompt_strengths[state.current_palette - 1] = strength
|
190 |
+
if opt.verbose:
|
191 |
+
log_state(state)
|
192 |
+
|
193 |
+
return state
|
194 |
+
|
195 |
+
|
196 |
+
def change_std(state, std):
|
197 |
+
if state.current_palette == 0:
|
198 |
+
return state
|
199 |
+
|
200 |
+
state.mask_stds[state.current_palette - 1] = std
|
201 |
+
if opt.verbose:
|
202 |
+
log_state(state)
|
203 |
+
|
204 |
+
return state
|
205 |
+
|
206 |
+
|
207 |
+
def change_mask_strength(state, strength):
|
208 |
+
if state.current_palette == 0:
|
209 |
+
return state
|
210 |
+
|
211 |
+
state.mask_strengths[state.current_palette - 1] = strength
|
212 |
+
if opt.verbose:
|
213 |
+
log_state(state)
|
214 |
+
|
215 |
+
return state
|
216 |
+
|
217 |
+
|
218 |
+
def reset_seed(state, seed):
|
219 |
+
state.seed = seed
|
220 |
+
if opt.verbose:
|
221 |
+
log_state(state)
|
222 |
+
|
223 |
+
return state
|
224 |
+
|
225 |
+
def rename_prompt(state, name):
|
226 |
+
state.prompt_names[state.current_palette] = name
|
227 |
+
if opt.verbose:
|
228 |
+
log_state(state)
|
229 |
+
|
230 |
+
return [state] + [
|
231 |
+
gr.update() if i != state.current_palette else gr.update(value=name)
|
232 |
+
for i in range(opt.max_palettes + 1)
|
233 |
+
]
|
234 |
+
|
235 |
+
|
236 |
+
def change_prompt(state, prompt):
|
237 |
+
state.prompts[state.current_palette] = prompt
|
238 |
+
if opt.verbose:
|
239 |
+
log_state(state)
|
240 |
+
|
241 |
+
return state
|
242 |
+
|
243 |
+
|
244 |
+
def change_neg_prompt(state, neg_prompt):
|
245 |
+
state.neg_prompts[state.current_palette] = neg_prompt
|
246 |
+
if opt.verbose:
|
247 |
+
log_state(state)
|
248 |
+
|
249 |
+
return state
|
250 |
+
|
251 |
+
|
252 |
+
def select_model(state, model_id):
|
253 |
+
state.model_id = model_id
|
254 |
+
if opt.verbose:
|
255 |
+
log_state(state)
|
256 |
+
|
257 |
+
return state
|
258 |
+
|
259 |
+
|
260 |
+
def import_state(state, json_text):
|
261 |
+
current_palette = state.current_palette
|
262 |
+
# active_palettes = state.active_palettes
|
263 |
+
state = argparse.Namespace(**json.loads(json_text))
|
264 |
+
state.active_palettes = opt.max_palettes
|
265 |
+
return [state] + [
|
266 |
+
gr.update(value=v, visible=True) for v in state.prompt_names
|
267 |
+
] + [
|
268 |
+
state.model_id,
|
269 |
+
state.prompts[current_palette],
|
270 |
+
state.prompt_names[current_palette],
|
271 |
+
state.neg_prompts[current_palette],
|
272 |
+
state.prompt_strengths[current_palette - 1],
|
273 |
+
state.mask_strengths[current_palette - 1],
|
274 |
+
state.mask_stds[current_palette - 1],
|
275 |
+
state.seed,
|
276 |
+
]
|
277 |
+
|
278 |
+
|
279 |
+
### Main worker
|
280 |
+
|
281 |
+
@spaces.GPU
|
282 |
+
def generate(state, *args, **kwargs):
|
283 |
+
return models[state.model_id](*args, **kwargs)
|
284 |
+
|
285 |
+
|
286 |
+
|
287 |
+
def run(state, drawpad):
|
288 |
+
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
289 |
+
print('Generate!')
|
290 |
+
|
291 |
+
background = drawpad['background'].convert('RGBA')
|
292 |
+
inpainting_mode = np.asarray(background).sum() != 0
|
293 |
+
print('Inpainting mode: ', inpainting_mode)
|
294 |
+
|
295 |
+
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
296 |
+
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
297 |
+
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
298 |
+
|
299 |
+
palette = torch.tensor([
|
300 |
+
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
301 |
+
for s in opt.colors[1:]
|
302 |
+
]) # (N, 3)
|
303 |
+
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
304 |
+
has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
305 |
+
print('Has mask: ', has_masks)
|
306 |
+
masks = masks * foreground_mask
|
307 |
+
masks = masks[has_masks]
|
308 |
+
|
309 |
+
# if inpainting_mode:
|
310 |
+
# prompts = state.prompts[1:len(masks)+1]
|
311 |
+
# negative_prompts = state.neg_prompts[1:len(masks)+1]
|
312 |
+
# mask_strengths = state.mask_strengths[:len(masks)]
|
313 |
+
# mask_stds = state.mask_stds[:len(masks)]
|
314 |
+
# prompt_strengths = state.prompt_strengths[:len(masks)]
|
315 |
+
# else:
|
316 |
+
# masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
|
317 |
+
# prompts = state.prompts[:len(masks)+1]
|
318 |
+
# negative_prompts = state.neg_prompts[:len(masks)+1]
|
319 |
+
# mask_strengths = [1] + state.mask_strengths[:len(masks)]
|
320 |
+
# mask_stds = [0] + [state.mask_stds[:len(masks)]
|
321 |
+
# prompt_strengths = [1] + state.prompt_strengths[:len(masks)]
|
322 |
+
|
323 |
+
if inpainting_mode:
|
324 |
+
prompts = [state.prompts[v + 1] for v in has_masks]
|
325 |
+
negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
|
326 |
+
mask_strengths = [state.mask_strengths[v] for v in has_masks]
|
327 |
+
mask_stds = [state.mask_stds[v] for v in has_masks]
|
328 |
+
prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
|
329 |
+
else:
|
330 |
+
masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
|
331 |
+
prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
|
332 |
+
negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
|
333 |
+
mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
334 |
+
mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
335 |
+
prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]
|
336 |
+
|
337 |
+
return generate(
|
338 |
+
state,
|
339 |
+
prompts,
|
340 |
+
negative_prompts,
|
341 |
+
masks=masks,
|
342 |
+
mask_strengths=mask_strengths,
|
343 |
+
mask_stds=mask_stds,
|
344 |
+
prompt_strengths=prompt_strengths,
|
345 |
+
background=background.convert('RGB'),
|
346 |
+
background_prompt=state.prompts[0],
|
347 |
+
background_negative_prompt=state.neg_prompts[0],
|
348 |
+
height=opt.height,
|
349 |
+
width=opt.width,
|
350 |
+
bootstrap_steps=2,
|
351 |
+
)
|
352 |
+
|
353 |
+
|
354 |
+
|
355 |
+
### Load examples
|
356 |
+
|
357 |
+
|
358 |
+
root = pathlib.Path(__file__).parent
|
359 |
+
example_root = os.path.join(root, 'examples')
|
360 |
+
example_images = glob.glob(os.path.join(example_root, '*.png'))
|
361 |
+
example_images = [Image.open(i) for i in example_images]
|
362 |
+
|
363 |
+
with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
|
364 |
+
prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']
|
365 |
+
|
366 |
+
with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
|
367 |
+
prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']
|
368 |
+
|
369 |
+
with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
|
370 |
+
prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']
|
371 |
+
|
372 |
+
with open(os.path.join(example_root, 'prompt_props.txt')) as f:
|
373 |
+
prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
|
374 |
+
prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}
|
375 |
+
|
376 |
+
prompt_background = lambda: random.choice(prompts_background)
|
377 |
+
prompt_girl = lambda: random.choice(prompts_girl)
|
378 |
+
prompt_boy = lambda: random.choice(prompts_boy)
|
379 |
+
prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()
|
380 |
+
|
381 |
+
|
382 |
+
### Main application
|
383 |
+
|
384 |
+
css = f"""
|
385 |
+
#run-button {{
|
386 |
+
font-size: 30pt;
|
387 |
+
background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
|
388 |
+
margin: 0;
|
389 |
+
padding: 15px 45px;
|
390 |
+
text-align: center;
|
391 |
+
text-transform: uppercase;
|
392 |
+
transition: 0.5s;
|
393 |
+
background-size: 200% auto;
|
394 |
+
color: white;
|
395 |
+
box-shadow: 0 0 20px #eee;
|
396 |
+
border-radius: 10px;
|
397 |
+
display: block;
|
398 |
+
background-position: right center;
|
399 |
+
}}
|
400 |
+
|
401 |
+
#run-button:hover {{
|
402 |
+
background-position: left center;
|
403 |
+
color: #fff;
|
404 |
+
text-decoration: none;
|
405 |
+
}}
|
406 |
+
|
407 |
+
#semantic-palette {{
|
408 |
+
border-style: solid;
|
409 |
+
border-width: 0.2em;
|
410 |
+
border-color: #eee;
|
411 |
+
}}
|
412 |
+
|
413 |
+
#semantic-palette:hover {{
|
414 |
+
box-shadow: 0 0 20px #eee;
|
415 |
+
}}
|
416 |
+
|
417 |
+
#output-screen {{
|
418 |
+
width: 100%;
|
419 |
+
aspect-ratio: {opt.width} / {opt.height};
|
420 |
+
}}
|
421 |
+
|
422 |
+
.layer-wrap {{
|
423 |
+
display: none;
|
424 |
+
}}
|
425 |
+
"""
|
426 |
+
|
427 |
+
for i in range(opt.max_palettes + 1):
|
428 |
+
css = css + f"""
|
429 |
+
.secondary#semantic-palette-{i} {{
|
430 |
+
background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
|
431 |
+
}}
|
432 |
+
|
433 |
+
.primary#semantic-palette-{i} {{
|
434 |
+
background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
|
435 |
+
}}
|
436 |
+
"""
|
437 |
+
|
438 |
+
|
439 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
440 |
+
|
441 |
+
iface = argparse.Namespace()
|
442 |
+
|
443 |
+
def _define_state():
|
444 |
+
state = argparse.Namespace()
|
445 |
+
|
446 |
+
# Cursor.
|
447 |
+
state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
|
448 |
+
state.model_id = list(model_dict.keys())[0]
|
449 |
+
|
450 |
+
# State variables (one-hot).
|
451 |
+
state.active_palettes = 1
|
452 |
+
|
453 |
+
# Front-end initialized to the default values.
|
454 |
+
prompt_props_ = prompt_props()
|
455 |
+
# state.prompt_names = [
|
456 |
+
# '🌄 Background',
|
457 |
+
# '👧 Girl',
|
458 |
+
# '👦 Boy',
|
459 |
+
# '🐶 Dog',
|
460 |
+
# '🚗 Car',
|
461 |
+
# '💐 Garden',
|
462 |
+
# ] + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
|
463 |
+
# state.prompts = [
|
464 |
+
# 'Maximalism, best quality, high quality, city lights, times square',
|
465 |
+
# '1girl, looking at viewer, pink hair, leather jacket',
|
466 |
+
# '1boy, looking at viewer, brown hair, casual shirt',
|
467 |
+
# 'Doggy body part',
|
468 |
+
# 'Car',
|
469 |
+
# 'Flower garden',
|
470 |
+
# ] + ['' for _ in range(opt.max_palettes - 5)]
|
471 |
+
state.prompt_names = [
|
472 |
+
'🌄 Background',
|
473 |
+
'👧 Girl',
|
474 |
+
'👦 Boy',
|
475 |
+
] + prompt_props_ + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
|
476 |
+
state.prompts = [
|
477 |
+
prompt_background(),
|
478 |
+
prompt_girl(),
|
479 |
+
prompt_boy(),
|
480 |
+
] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
|
481 |
+
state.neg_prompts = [
|
482 |
+
opt.default_negative_prompt
|
483 |
+
+ (', humans, humans, humans' if i == 0 else '')
|
484 |
+
for i in range(opt.max_palettes + 1)
|
485 |
+
]
|
486 |
+
state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
|
487 |
+
state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
|
488 |
+
state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
|
489 |
+
state.seed = opt.seed
|
490 |
+
return state
|
491 |
+
|
492 |
+
state = gr.State(value=_define_state)
|
493 |
+
|
494 |
+
|
495 |
+
### Demo user interface
|
496 |
+
|
497 |
+
gr.HTML(
|
498 |
+
"""
|
499 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
500 |
+
<div>
|
501 |
+
<h1>🧠 Semantic Paint 🎨</h1>
|
502 |
+
<h5 style="margin: 0;">powered by</h5>
|
503 |
+
<h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
|
504 |
+
<h5 style="margin: 0;">If you ❤️ our project, please visit our Github and give us a 🌟!</h5>
|
505 |
+
</br>
|
506 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
507 |
+
<a href='https://arxiv.org/abs/2403.09055'>
|
508 |
+
<img src="https://img.shields.io/badge/arXiv-2403.09055-red">
|
509 |
+
</a>
|
510 |
+
|
511 |
+
<a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
|
512 |
+
<img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
|
513 |
+
</a>
|
514 |
+
|
515 |
+
<a href='https://github.com/ironjr/StreamMultiDiffusion'>
|
516 |
+
<img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
|
517 |
+
</a>
|
518 |
+
|
519 |
+
<a href='https://twitter.com/_ironjr_'>
|
520 |
+
<img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
|
521 |
+
</a>
|
522 |
+
|
523 |
+
<a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
|
524 |
+
<img src='https://img.shields.io/badge/license-MIT-lightgrey'>
|
525 |
+
</a>
|
526 |
+
|
527 |
+
<a href='https://huggingface.co/papers/2403.09055'>
|
528 |
+
<img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Paper-yellow'>
|
529 |
+
</a>
|
530 |
+
</div>
|
531 |
+
</div>
|
532 |
+
</div>
|
533 |
+
<div>
|
534 |
+
</br>
|
535 |
+
</div>
|
536 |
+
"""
|
537 |
+
)
|
538 |
+
|
539 |
+
with gr.Row():
|
540 |
+
|
541 |
+
iface.image_slot = gr.Image(
|
542 |
+
interactive=False,
|
543 |
+
show_label=False,
|
544 |
+
show_download_button=True,
|
545 |
+
type='pil',
|
546 |
+
label='Generated Result',
|
547 |
+
elem_id='output-screen',
|
548 |
+
show_share_button=True,
|
549 |
+
value=lambda: random.choice(example_images),
|
550 |
+
)
|
551 |
+
|
552 |
+
with gr.Row():
|
553 |
+
|
554 |
+
with gr.Column(scale=1):
|
555 |
+
|
556 |
+
with gr.Group(elem_id='semantic-palette'):
|
557 |
+
|
558 |
+
gr.HTML(
|
559 |
+
"""
|
560 |
+
<div style="justify-content: center; align-items: center;">
|
561 |
+
<br/>
|
562 |
+
<h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
|
563 |
+
<br/>
|
564 |
+
</div>
|
565 |
+
"""
|
566 |
+
)
|
567 |
+
|
568 |
+
iface.btn_semantics = [gr.Button(
|
569 |
+
value=state.value.prompt_names[0],
|
570 |
+
variant='primary',
|
571 |
+
elem_id='semantic-palette-0',
|
572 |
+
)]
|
573 |
+
for i in range(opt.max_palettes):
|
574 |
+
iface.btn_semantics.append(gr.Button(
|
575 |
+
value=state.value.prompt_names[i + 1],
|
576 |
+
variant='secondary',
|
577 |
+
visible=(i < state.value.active_palettes),
|
578 |
+
elem_id=f'semantic-palette-{i + 1}'
|
579 |
+
))
|
580 |
+
|
581 |
+
iface.btn_add_palette = gr.Button(
|
582 |
+
value='Create New Semantic Brush',
|
583 |
+
variant='primary',
|
584 |
+
)
|
585 |
+
|
586 |
+
with gr.Accordion(label='Import/Export Semantic Palette', open=False):
|
587 |
+
iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
|
588 |
+
iface.json_state_export = gr.JSON(label='Exported Palette')
|
589 |
+
iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
|
590 |
+
iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')
|
591 |
+
|
592 |
+
gr.HTML(
|
593 |
+
"""
|
594 |
+
<div>
|
595 |
+
</br>
|
596 |
+
</div>
|
597 |
+
<div style="justify-content: center; align-items: center;">
|
598 |
+
<h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
|
599 |
+
</br>
|
600 |
+
<div style="justify-content: center; align-items: left; text-align: left;">
|
601 |
+
<p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
|
602 |
+
<p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
|
603 |
+
<p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
|
604 |
+
<p>2-1. If you are willing to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
|
605 |
+
<p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
|
606 |
+
<p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
|
607 |
+
</div>
|
608 |
+
</div>
|
609 |
+
"""
|
610 |
+
)
|
611 |
+
|
612 |
+
gr.HTML(
|
613 |
+
"""
|
614 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
615 |
+
<h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
|
616 |
+
</div>
|
617 |
+
"""
|
618 |
+
)
|
619 |
+
|
620 |
+
gr.DuplicateButton()
|
621 |
+
|
622 |
+
with gr.Column(scale=4):
|
623 |
+
|
624 |
+
with gr.Row():
|
625 |
+
|
626 |
+
with gr.Column(scale=3):
|
627 |
+
|
628 |
+
iface.ctrl_semantic = gr.ImageEditor(
|
629 |
+
image_mode='RGBA',
|
630 |
+
sources=['upload', 'clipboard', 'webcam'],
|
631 |
+
transforms=['crop'],
|
632 |
+
crop_size=(opt.width, opt.height),
|
633 |
+
brush=gr.Brush(
|
634 |
+
colors=opt.colors[1:],
|
635 |
+
color_mode="fixed",
|
636 |
+
),
|
637 |
+
type='pil',
|
638 |
+
label='Semantic Drawpad',
|
639 |
+
elem_id='drawpad',
|
640 |
+
show_share_button=True,
|
641 |
+
)
|
642 |
+
|
643 |
+
with gr.Column(scale=1):
|
644 |
+
|
645 |
+
iface.btn_generate = gr.Button(
|
646 |
+
value='Generate!',
|
647 |
+
variant='primary',
|
648 |
+
# scale=1,
|
649 |
+
elem_id='run-button'
|
650 |
+
)
|
651 |
+
|
652 |
+
iface.model_select = gr.Radio(
|
653 |
+
list(model_dict.keys()),
|
654 |
+
label='Stable Diffusion Checkpoint',
|
655 |
+
info='Choose your favorite style.',
|
656 |
+
value=state.value.model_id,
|
657 |
+
)
|
658 |
+
|
659 |
+
with gr.Group(elem_id='control-panel'):
|
660 |
+
|
661 |
+
with gr.Row():
|
662 |
+
iface.tbox_prompt = gr.Textbox(
|
663 |
+
label='Edit Prompt for Background',
|
664 |
+
info='What do you want to draw?',
|
665 |
+
value=state.value.prompts[0],
|
666 |
+
placeholder=lambda: random.choice(prompt_suggestions),
|
667 |
+
scale=2,
|
668 |
+
)
|
669 |
+
|
670 |
+
iface.tbox_name = gr.Textbox(
|
671 |
+
label='Edit Brush Name',
|
672 |
+
info='Just for your convenience.',
|
673 |
+
value=state.value.prompt_names[0],
|
674 |
+
placeholder='🌄 Background',
|
675 |
+
scale=1,
|
676 |
+
)
|
677 |
+
|
678 |
+
with gr.Row():
|
679 |
+
iface.tbox_neg_prompt = gr.Textbox(
|
680 |
+
label='Edit Negative Prompt for Background',
|
681 |
+
info='Add unwanted objects for this semantic brush.',
|
682 |
+
value=opt.default_negative_prompt,
|
683 |
+
scale=2,
|
684 |
+
)
|
685 |
+
|
686 |
+
iface.slider_strength = gr.Slider(
|
687 |
+
label='Prompt Strength',
|
688 |
+
info='Blends fg & bg in the prompt level, >0.8 Preferred.',
|
689 |
+
minimum=0.5,
|
690 |
+
maximum=1.0,
|
691 |
+
value=opt.default_prompt_strength,
|
692 |
+
scale=1,
|
693 |
+
)
|
694 |
+
|
695 |
+
with gr.Row():
|
696 |
+
iface.slider_alpha = gr.Slider(
|
697 |
+
label='Mask Alpha',
|
698 |
+
info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
|
699 |
+
minimum=0.5,
|
700 |
+
maximum=1.0,
|
701 |
+
value=opt.default_mask_strength,
|
702 |
+
)
|
703 |
+
|
704 |
+
iface.slider_std = gr.Slider(
|
705 |
+
label='Mask Blur STD',
|
706 |
+
info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
|
707 |
+
minimum=0.0001,
|
708 |
+
maximum=100.0,
|
709 |
+
value=opt.default_mask_std,
|
710 |
+
)
|
711 |
+
|
712 |
+
iface.slider_seed = gr.Slider(
|
713 |
+
label='Seed',
|
714 |
+
info='The global seed.',
|
715 |
+
minimum=-1,
|
716 |
+
maximum=2147483647,
|
717 |
+
step=1,
|
718 |
+
value=opt.seed,
|
719 |
+
)
|
720 |
+
|
721 |
+
### Attach event handlers
|
722 |
+
|
723 |
+
for idx, btn in enumerate(iface.btn_semantics):
|
724 |
+
btn.click(
|
725 |
+
fn=partial(select_palette, idx=idx),
|
726 |
+
inputs=[state, btn],
|
727 |
+
outputs=[state] + iface.btn_semantics + [
|
728 |
+
iface.tbox_name,
|
729 |
+
iface.tbox_prompt,
|
730 |
+
iface.tbox_neg_prompt,
|
731 |
+
iface.slider_alpha,
|
732 |
+
iface.slider_strength,
|
733 |
+
iface.slider_std,
|
734 |
+
],
|
735 |
+
api_name=f'select_palette_{idx}',
|
736 |
+
)
|
737 |
+
|
738 |
+
iface.btn_add_palette.click(
|
739 |
+
fn=add_palette,
|
740 |
+
inputs=state,
|
741 |
+
outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
|
742 |
+
api_name='create_new',
|
743 |
+
)
|
744 |
+
|
745 |
+
iface.btn_generate.click(
|
746 |
+
fn=run,
|
747 |
+
inputs=[state, iface.ctrl_semantic],
|
748 |
+
outputs=iface.image_slot,
|
749 |
+
api_name='run',
|
750 |
+
)
|
751 |
+
|
752 |
+
iface.slider_alpha.input(
|
753 |
+
fn=change_mask_strength,
|
754 |
+
inputs=[state, iface.slider_alpha],
|
755 |
+
outputs=state,
|
756 |
+
api_name='change_alpha',
|
757 |
+
)
|
758 |
+
iface.slider_std.input(
|
759 |
+
fn=change_std,
|
760 |
+
inputs=[state, iface.slider_std],
|
761 |
+
outputs=state,
|
762 |
+
api_name='change_std',
|
763 |
+
)
|
764 |
+
iface.slider_strength.input(
|
765 |
+
fn=change_prompt_strength,
|
766 |
+
inputs=[state, iface.slider_strength],
|
767 |
+
outputs=state,
|
768 |
+
api_name='change_strength',
|
769 |
+
)
|
770 |
+
iface.slider_seed.input(
|
771 |
+
fn=reset_seed,
|
772 |
+
inputs=[state, iface.slider_seed],
|
773 |
+
outputs=state,
|
774 |
+
api_name='reset_seed',
|
775 |
+
)
|
776 |
+
|
777 |
+
iface.tbox_name.input(
|
778 |
+
fn=rename_prompt,
|
779 |
+
inputs=[state, iface.tbox_name],
|
780 |
+
outputs=[state] + iface.btn_semantics,
|
781 |
+
api_name='prompt_rename',
|
782 |
+
)
|
783 |
+
iface.tbox_prompt.input(
|
784 |
+
fn=change_prompt,
|
785 |
+
inputs=[state, iface.tbox_prompt],
|
786 |
+
outputs=state,
|
787 |
+
api_name='prompt_edit',
|
788 |
+
)
|
789 |
+
iface.tbox_neg_prompt.input(
|
790 |
+
fn=change_neg_prompt,
|
791 |
+
inputs=[state, iface.tbox_neg_prompt],
|
792 |
+
outputs=state,
|
793 |
+
api_name='neg_prompt_edit',
|
794 |
+
)
|
795 |
+
|
796 |
+
iface.model_select.change(
|
797 |
+
fn=select_model,
|
798 |
+
inputs=[state, iface.model_select],
|
799 |
+
outputs=state,
|
800 |
+
api_name='model_select',
|
801 |
+
)
|
802 |
+
|
803 |
+
iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
|
804 |
+
iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
|
805 |
+
state,
|
806 |
+
*iface.btn_semantics,
|
807 |
+
iface.model_select,
|
808 |
+
iface.tbox_prompt,
|
809 |
+
iface.tbox_name,
|
810 |
+
iface.tbox_neg_prompt,
|
811 |
+
iface.slider_strength,
|
812 |
+
iface.slider_alpha,
|
813 |
+
iface.slider_std,
|
814 |
+
iface.slider_seed,
|
815 |
+
])
|
816 |
+
|
817 |
+
|
818 |
+
if __name__ == '__main__':
|
819 |
+
demo..queue(max_size=20).launch()
|
examples/prompt_background.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Maximalism, best quality, high quality, no humans, background, clear sky, ㅠblack sky, starry universe, planets
|
2 |
+
Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
|
3 |
+
Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
|
4 |
+
Maximalism, best quality, high quality, no humans, background, galaxy
|
5 |
+
Maximalism, best quality, high quality, no humans, background, sky, daylight
|
6 |
+
Maximalism, best quality, high quality, no humans, background, skyscrappers, rooftop, city of light, helicopters, bright night, sky
|
7 |
+
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
|
8 |
+
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
|
examples/prompt_background_advanced.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/prompt_boy.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
1boy, looking at viewer, brown hair, blue shirt
|
2 |
+
1boy, looking at viewer, brown hair, red shirt
|
3 |
+
1boy, looking at viewer, brown hair, purple shirt
|
4 |
+
1boy, looking at viewer, brown hair, orange shirt
|
5 |
+
1boy, looking at viewer, brown hair, yellow shirt
|
6 |
+
1boy, looking at viewer, brown hair, green shirt
|
7 |
+
1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
|
8 |
+
1boy, looking back, short hair, renaissance cloths, noble boy
|
9 |
+
1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
|
10 |
+
1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
|
11 |
+
1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
|
12 |
+
1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
|
13 |
+
1boy, looking at viewer, black haired, old eastern cloth
|
14 |
+
1boy, looking back, messy hair, suit, short beard, noir
|
15 |
+
1boy, looking at viewer, cute face, light smile, starry eyes, jeans
|
examples/prompt_girl.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
|
2 |
+
1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
|
3 |
+
1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
|
4 |
+
1girl, looking at viewer, fantasy adventurer, backpack
|
5 |
+
1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
|
6 |
+
1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
|
7 |
+
1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
|
8 |
+
1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
|
9 |
+
1girl, looking at viewer, evil smile, very short hair, suit, evil genius
|
10 |
+
1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
|
11 |
+
1girl, looking at viewer, purple hair, happy face, black leather jacket
|
12 |
+
1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
|
13 |
+
1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
|
14 |
+
1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
|
15 |
+
1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
|
16 |
+
1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
|
examples/prompt_props.txt
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
🏯 Palace, Gyeongbokgung palace
|
2 |
+
🌳 Garden, Chinese garden
|
3 |
+
🏛️ Rome, Ancient city of Rome
|
4 |
+
🧱 Wall, Castle wall
|
5 |
+
🔴 Mars, Martian desert, Red rocky desert
|
6 |
+
🌻 Grassland, Grasslands
|
7 |
+
🏡 Village, A fantasy village
|
8 |
+
🐉 Dragon, a flying chinese dragon
|
9 |
+
🌏 Earth, Earth seen from ISS
|
10 |
+
🚀 Space Station, the international space station
|
11 |
+
🪻 Grassland, Rusty grassland with flowers
|
12 |
+
🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
|
13 |
+
🏙️ City Ruin, city, ruins, ruins, ruins, deserted
|
14 |
+
🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
|
15 |
+
🌷 Flowers, Flower garden
|
16 |
+
🌼 Flowers, Flower garden, spring garden
|
17 |
+
🌹 Flowers, Flowers flowers, flowers
|
18 |
+
⛰️ Dolomites Mountains, Dolomites
|
19 |
+
⛰️ Himalayas Mountains, Himalayas
|
20 |
+
⛰️ Alps Mountains, Alps
|
21 |
+
⛰️ Mountains, Mountains
|
22 |
+
❄️⛰️ Mountains, Winter mountains
|
23 |
+
🌷⛰️ Mountains, Spring mountains
|
24 |
+
🌞⛰️ Mountains, Summer mountains
|
25 |
+
🌵 Desert, A sandy desert, dunes
|
26 |
+
🪨🌵 Desert, A rocky desert
|
27 |
+
💦 Waterfall, A giant waterfall
|
28 |
+
🌊 Ocean, Ocean
|
29 |
+
⛱️ Seashore, Seashore
|
30 |
+
🌅 Sea Horizon, Sea horizon
|
31 |
+
🌊 Lake, Clear blue lake
|
32 |
+
💻 Computer, A giant supecomputer
|
33 |
+
🌳 Tree, A giant tree
|
34 |
+
🌳 Forest, A forest
|
35 |
+
🌳🌳 Forest, A dense forest
|
36 |
+
🌲 Forest, Winter forest
|
37 |
+
🌴 Forest, Summer forest, tropical forest
|
38 |
+
👒 Hat, A hat
|
39 |
+
🐶 Dog, Doggy body parts
|
40 |
+
😻 Cat, A cat
|
41 |
+
🦉 Owl, A small sitting owl
|
42 |
+
🦅 Eagle, A small sitting eagle
|
43 |
+
🚀 Rocket, A flying rocket
|
model.py
ADDED
@@ -0,0 +1,1106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
22 |
+
from diffusers import DiffusionPipeline, LCMScheduler, DDIMScheduler, AutoencoderTiny
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn as nn
|
26 |
+
import torch.nn.functional as F
|
27 |
+
import torchvision.transforms as T
|
28 |
+
from einops import rearrange
|
29 |
+
|
30 |
+
from typing import Tuple, List, Literal, Optional, Union
|
31 |
+
from tqdm import tqdm
|
32 |
+
from PIL import Image
|
33 |
+
|
34 |
+
from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
35 |
+
|
36 |
+
|
37 |
+
class StableMultiDiffusionPipeline(nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
device: torch.device,
|
41 |
+
dtype: torch.dtype = torch.float16,
|
42 |
+
sd_version: Literal['1.5', '2.0', '2.1', 'xl'] = '1.5',
|
43 |
+
hf_key: Optional[str] = None,
|
44 |
+
lora_key: Optional[str] = None,
|
45 |
+
load_from_local: bool = False, # Turn on if you have already downloaed LoRA & Hugging Face hub is down.
|
46 |
+
default_mask_std: float = 1.0, # 8.0
|
47 |
+
default_mask_strength: float = 1.0,
|
48 |
+
default_prompt_strength: float = 1.0, # 8.0
|
49 |
+
default_bootstrap_steps: int = 1,
|
50 |
+
default_boostrap_mix_steps: float = 1.0,
|
51 |
+
default_bootstrap_leak_sensitivity: float = 0.2,
|
52 |
+
default_preprocess_mask_cover_alpha: float = 0.3,
|
53 |
+
t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # [0, 12, 25, 37], # Magic number.
|
54 |
+
mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
|
55 |
+
) -> None:
|
56 |
+
r"""Stabilized MultiDiffusion for fast sampling.
|
57 |
+
|
58 |
+
Accelrated region-based text-to-image synthesis with Latent Consistency
|
59 |
+
Model while preserving mask fidelity and quality.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
device (torch.device): Specify CUDA device.
|
63 |
+
dtype (torch.dtype): Default precision used in the sampling
|
64 |
+
process. By default, it is FP16.
|
65 |
+
sd_version (Literal['1.5', '2.0', '2.1', 'xl']): StableDiffusion
|
66 |
+
version. Currently, only 1.5 is supported.
|
67 |
+
hf_key (Optional[str]): Custom StableDiffusion checkpoint for
|
68 |
+
stylized generation.
|
69 |
+
lora_key (Optional[str]): Custom LCM LoRA for acceleration.
|
70 |
+
load_from_local (bool): Turn on if you have already downloaed LoRA
|
71 |
+
& Hugging Face hub is down.
|
72 |
+
default_mask_std (float): Preprocess mask with Gaussian blur with
|
73 |
+
specified standard deviation.
|
74 |
+
default_mask_strength (float): Preprocess mask by multiplying it
|
75 |
+
globally with the specified variable. Caution: extremely
|
76 |
+
sensitive. Recommended range: 0.98-1.
|
77 |
+
default_prompt_strength (float): Preprocess foreground prompts
|
78 |
+
globally by linearly interpolating its embedding with the
|
79 |
+
background prompt embeddint with specified mix ratio. Useful
|
80 |
+
control handle for foreground blending. Recommended range:
|
81 |
+
0.5-1.
|
82 |
+
default_bootstrap_steps (int): Bootstrapping stage steps to
|
83 |
+
encourage region separation. Recommended range: 1-3.
|
84 |
+
default_boostrap_mix_steps (float): Bootstrapping background is a
|
85 |
+
linear interpolation between background latent and the white
|
86 |
+
image latent. This handle controls the mix ratio. Available
|
87 |
+
range: 0-(number of bootstrapping inference steps). For
|
88 |
+
example, 2.3 means that for the first two steps, white image
|
89 |
+
is used as a bootstrapping background and in the third step,
|
90 |
+
mixture of white (0.3) and registered background (0.7) is used
|
91 |
+
as a bootstrapping background.
|
92 |
+
default_bootstrap_leak_sensitivity (float): Postprocessing at each
|
93 |
+
inference step by masking away the remaining bootstrap
|
94 |
+
backgrounds t Recommended range: 0-1.
|
95 |
+
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
96 |
+
where each mask covered by other masks is reduced in its alpha
|
97 |
+
value by this specified factor.
|
98 |
+
t_index_list (List[int]): The default scheduling for LCM scheduler.
|
99 |
+
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
100 |
+
defines the mask quantization modes. Details in the codes of
|
101 |
+
`self.process_mask`. Basically, this (subtly) controls the
|
102 |
+
smoothness of foreground-background blending. More continuous
|
103 |
+
means more blending, but smaller generated patch depending on
|
104 |
+
the mask standard deviation.
|
105 |
+
"""
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.device = device
|
109 |
+
self.dtype = dtype
|
110 |
+
self.sd_version = sd_version
|
111 |
+
|
112 |
+
self.default_mask_std = default_mask_std
|
113 |
+
self.default_mask_strength = default_mask_strength
|
114 |
+
self.default_prompt_strength = default_prompt_strength
|
115 |
+
self.default_t_list = t_index_list
|
116 |
+
self.default_bootstrap_steps = default_bootstrap_steps
|
117 |
+
self.default_boostrap_mix_steps = default_boostrap_mix_steps
|
118 |
+
self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
|
119 |
+
self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
|
120 |
+
self.mask_type = mask_type
|
121 |
+
|
122 |
+
print(f'[INFO] Loading Stable Diffusion...')
|
123 |
+
variant = None
|
124 |
+
lora_weight_name = None
|
125 |
+
if self.sd_version == '1.5':
|
126 |
+
if hf_key is not None:
|
127 |
+
print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
|
128 |
+
model_key = hf_key
|
129 |
+
else:
|
130 |
+
model_key = 'runwayml/stable-diffusion-v1-5'
|
131 |
+
variant = 'fp16'
|
132 |
+
lora_key = 'latent-consistency/lcm-lora-sdv1-5'
|
133 |
+
lora_weight_name = 'pytorch_lora_weights.safetensors'
|
134 |
+
# elif self.sd_version == 'xl':
|
135 |
+
# model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
|
136 |
+
# lora_key = 'latent-consistency/lcm-lora-sdxl'
|
137 |
+
# variant = 'fp16'
|
138 |
+
# lora_weight_name = 'pytorch_lora_weights.safetensors'
|
139 |
+
else:
|
140 |
+
raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
|
141 |
+
|
142 |
+
# Create model
|
143 |
+
self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
|
144 |
+
self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
|
145 |
+
|
146 |
+
self.pipe = DiffusionPipeline.from_pretrained(model_key, variant=variant, torch_dtype=dtype).to(self.device)
|
147 |
+
if lora_key is None:
|
148 |
+
print(f'[INFO] LCM LoRA is not available for SD version {sd_version}. Using DDIM Scheduler instead...')
|
149 |
+
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
|
150 |
+
self.scheduler = self.pipe.scheduler
|
151 |
+
self.default_num_inference_steps = 50
|
152 |
+
self.default_guidance_scale = 7.5
|
153 |
+
else:
|
154 |
+
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
155 |
+
self.scheduler = self.pipe.scheduler
|
156 |
+
self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
|
157 |
+
self.default_num_inference_steps = 4
|
158 |
+
self.default_guidance_scale = 1.0
|
159 |
+
|
160 |
+
self.prepare_lcm_schedule(t_index_list, 50)
|
161 |
+
|
162 |
+
self.vae = self.pipe.vae
|
163 |
+
self.tokenizer = self.pipe.tokenizer
|
164 |
+
self.text_encoder = self.pipe.text_encoder
|
165 |
+
self.unet = self.pipe.unet
|
166 |
+
self.vae_scale_factor = self.pipe.vae_scale_factor
|
167 |
+
|
168 |
+
# Prepare white background for bootstrapping.
|
169 |
+
self.get_white_background(768, 768)
|
170 |
+
|
171 |
+
print(f'[INFO] Model is loaded!')
|
172 |
+
|
173 |
+
def prepare_lcm_schedule(
|
174 |
+
self,
|
175 |
+
t_index_list: Optional[List[int]] = None,
|
176 |
+
num_inference_steps: Optional[int] = None,
|
177 |
+
) -> None:
|
178 |
+
r"""Set up different inference schedule for the diffusion model.
|
179 |
+
|
180 |
+
You do not have to run this explicitly if you want to use the default
|
181 |
+
setting, but if you want other time schedules, run this function
|
182 |
+
between the module initialization and the main call.
|
183 |
+
|
184 |
+
Note:
|
185 |
+
- Recommended t_index_lists for LCMs:
|
186 |
+
- [0, 12, 25, 37]: Default schedule for 4 steps. Best for
|
187 |
+
panorama. Not recommended if you want to use bootstrapping.
|
188 |
+
Because bootstrapping stage affects the initial structuring
|
189 |
+
of the generated image & in this four step LCM, this is done
|
190 |
+
with only at the first step, the structure may be distorted.
|
191 |
+
- [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
|
192 |
+
strapping. Default initialization in this implementation.
|
193 |
+
- [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
|
194 |
+
bootstrapping.
|
195 |
+
- Due to the characteristic of SD1.5 LCM LoRA, setting
|
196 |
+
`num_inference_steps` larger than 20 may results in overly blurry
|
197 |
+
and unrealistic images. Beware!
|
198 |
+
|
199 |
+
Args:
|
200 |
+
t_index_list (Optional[List[int]]): The specified scheduling step
|
201 |
+
regarding the maximum timestep as `num_inference_steps`, which
|
202 |
+
is by default, 50. That means that
|
203 |
+
`t_index_list=[0, 12, 25, 37]` is a relative time indices basd
|
204 |
+
on the full scale of 50. If None, reinitialize the module with
|
205 |
+
the default value.
|
206 |
+
num_inference_steps (Optional[int]): The maximum timestep of the
|
207 |
+
sampler. Defines relative scale of the `t_index_list`. Rarely
|
208 |
+
used in practice. If None, reinitialize the module with the
|
209 |
+
default value.
|
210 |
+
"""
|
211 |
+
if t_index_list is None:
|
212 |
+
t_index_list = self.default_t_list
|
213 |
+
if num_inference_steps is None:
|
214 |
+
num_inference_steps = self.default_num_inference_steps
|
215 |
+
|
216 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
217 |
+
self.timesteps = torch.as_tensor([
|
218 |
+
self.scheduler.timesteps[t] for t in t_index_list
|
219 |
+
], dtype=torch.long)
|
220 |
+
|
221 |
+
shape = (len(t_index_list), 1, 1, 1)
|
222 |
+
|
223 |
+
c_skip_list = []
|
224 |
+
c_out_list = []
|
225 |
+
for timestep in self.timesteps:
|
226 |
+
c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
|
227 |
+
c_skip_list.append(c_skip)
|
228 |
+
c_out_list.append(c_out)
|
229 |
+
self.c_skip = torch.stack(c_skip_list).view(*shape).to(dtype=self.dtype, device=self.device)
|
230 |
+
self.c_out = torch.stack(c_out_list).view(*shape).to(dtype=self.dtype, device=self.device)
|
231 |
+
|
232 |
+
alpha_prod_t_sqrt_list = []
|
233 |
+
beta_prod_t_sqrt_list = []
|
234 |
+
for timestep in self.timesteps:
|
235 |
+
alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
|
236 |
+
beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
|
237 |
+
alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
|
238 |
+
beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
|
239 |
+
alpha_prod_t_sqrt = (torch.stack(alpha_prod_t_sqrt_list).view(*shape)
|
240 |
+
.to(dtype=self.dtype, device=self.device))
|
241 |
+
beta_prod_t_sqrt = (torch.stack(beta_prod_t_sqrt_list).view(*shape)
|
242 |
+
.to(dtype=self.dtype, device=self.device))
|
243 |
+
self.alpha_prod_t_sqrt = alpha_prod_t_sqrt
|
244 |
+
self.beta_prod_t_sqrt = beta_prod_t_sqrt
|
245 |
+
|
246 |
+
noise_lvs = (1 - self.scheduler.alphas_cumprod[self.timesteps].to(self.device)) ** 0.5
|
247 |
+
self.noise_lvs = noise_lvs[None, :, None, None, None]
|
248 |
+
self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def get_text_embeds(self, prompt: str, negative_prompt: str) -> Tuple[torch.Tensor]:
|
252 |
+
r"""Text embeddings from string text prompts.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
prompt (str): A text prompt string.
|
256 |
+
negative_prompt: An optional negative text prompt string. Good for
|
257 |
+
high-quality generation.
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
A tuple of (negative, positive) prompt embeddings of (1, 77, 768).
|
261 |
+
"""
|
262 |
+
kwargs = dict(padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
|
263 |
+
|
264 |
+
# Tokenize text and get embeddings.
|
265 |
+
text_input = self.tokenizer(prompt, truncation=True, **kwargs)
|
266 |
+
text_embeds = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
267 |
+
uncond_input = self.tokenizer(negative_prompt, **kwargs)
|
268 |
+
uncond_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
269 |
+
return uncond_embeds, text_embeds
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def get_text_prompts(self, image: Image.Image) -> str:
|
273 |
+
r"""A convenient method to extract text prompt from an image.
|
274 |
+
|
275 |
+
This is called if the user does not provide background prompt but only
|
276 |
+
the background image. We use BLIP-2 to automatically generate prompts.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
image (Image.Image): A PIL image.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
A single string of text prompt.
|
283 |
+
"""
|
284 |
+
question = 'Question: What are in the image? Answer:'
|
285 |
+
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
286 |
+
out = self.i2t_model.generate(**inputs, max_new_tokens=77)
|
287 |
+
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
288 |
+
return prompt
|
289 |
+
|
290 |
+
@torch.no_grad()
|
291 |
+
def encode_imgs(
|
292 |
+
self,
|
293 |
+
imgs: torch.Tensor,
|
294 |
+
generator: Optional[torch.Generator] = None,
|
295 |
+
vae: Optional[nn.Module] = None,
|
296 |
+
) -> torch.Tensor:
|
297 |
+
r"""A wrapper function for VAE encoder of the latent diffusion model.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
imgs (torch.Tensor): An image to get StableDiffusion latents.
|
301 |
+
Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
|
302 |
+
generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
|
303 |
+
vae (Optional[nn.Module]): Explicitly specify VAE (used for
|
304 |
+
the demo application with TinyVAE).
|
305 |
+
|
306 |
+
Returns:
|
307 |
+
An image latent embedding with 1/8 size (depending on the auto-
|
308 |
+
encoder. Shape: (B, 4, H//8, W//8).
|
309 |
+
"""
|
310 |
+
def _retrieve_latents(
|
311 |
+
encoder_output: torch.Tensor,
|
312 |
+
generator: Optional[torch.Generator] = None,
|
313 |
+
sample_mode: str = 'sample',
|
314 |
+
):
|
315 |
+
if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
|
316 |
+
return encoder_output.latent_dist.sample(generator)
|
317 |
+
elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
|
318 |
+
return encoder_output.latent_dist.mode()
|
319 |
+
elif hasattr(encoder_output, 'latents'):
|
320 |
+
return encoder_output.latents
|
321 |
+
else:
|
322 |
+
raise AttributeError('Could not access latents of provided encoder_output')
|
323 |
+
|
324 |
+
vae = self.vae if vae is None else vae
|
325 |
+
imgs = 2 * imgs - 1
|
326 |
+
latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
|
327 |
+
return latents
|
328 |
+
|
329 |
+
@torch.no_grad()
|
330 |
+
def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
|
331 |
+
r"""A wrapper function for VAE decoder of the latent diffusion model.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
latents (torch.Tensor): An image latent to get associated images.
|
335 |
+
Expected shape: (B, 4, H//8, W//8).
|
336 |
+
vae (Optional[nn.Module]): Explicitly specify VAE (used for
|
337 |
+
the demo application with TinyVAE).
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
An image latent embedding with 1/8 size (depending on the auto-
|
341 |
+
encoder. Shape: (B, 3, H, W).
|
342 |
+
"""
|
343 |
+
vae = self.vae if vae is None else vae
|
344 |
+
latents = 1 / vae.config.scaling_factor * latents
|
345 |
+
imgs = vae.decode(latents).sample
|
346 |
+
imgs = (imgs / 2 + 0.5).clip_(0, 1)
|
347 |
+
return imgs
|
348 |
+
|
349 |
+
@torch.no_grad()
|
350 |
+
def get_white_background(self, height: int, width: int) -> torch.Tensor:
|
351 |
+
r"""White background image latent for bootstrapping or in case of
|
352 |
+
absent background.
|
353 |
+
|
354 |
+
Additionally stores the maximally-sized white latent for fast retrieval
|
355 |
+
in the future. By default, we initially call this with 768x768 sized
|
356 |
+
white image, so the function is rarely visited twice.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
height (int): The height of the white *image*, not its latent.
|
360 |
+
width (int): The width of the white *image*, not its latent.
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
A white image latent of size (1, 4, height//8, width//8). A cropped
|
364 |
+
version of the stored white latent is returned if the requested
|
365 |
+
size is smaller than what we already have created.
|
366 |
+
"""
|
367 |
+
if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
|
368 |
+
white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
|
369 |
+
self.white = self.encode_imgs(white)
|
370 |
+
return self.white
|
371 |
+
return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
|
372 |
+
|
373 |
+
@torch.no_grad()
|
374 |
+
def process_mask(
|
375 |
+
self,
|
376 |
+
masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
|
377 |
+
strength: Optional[Union[torch.Tensor, float]] = None,
|
378 |
+
std: Optional[Union[torch.Tensor, float]] = None,
|
379 |
+
height: int = 512,
|
380 |
+
width: int = 512,
|
381 |
+
use_boolean_mask: bool = True,
|
382 |
+
timesteps: Optional[torch.Tensor] = None,
|
383 |
+
preprocess_mask_cover_alpha: Optional[float] = None,
|
384 |
+
) -> Tuple[torch.Tensor]:
|
385 |
+
r"""Fast preprocess of masks for region-based generation with fine-
|
386 |
+
grained controls.
|
387 |
+
|
388 |
+
Mask preprocessing is done in four steps:
|
389 |
+
1. Resizing: Resize the masks into the specified width and height by
|
390 |
+
nearest neighbor interpolation.
|
391 |
+
2. (Optional) Ordering: Masks with higher indices are considered to
|
392 |
+
cover the masks with smaller indices. Covered masks are decayed
|
393 |
+
in its alpha value by the specified factor of
|
394 |
+
`preprocess_mask_cover_alpha`.
|
395 |
+
3. Blurring: Gaussian blur is applied to the mask with the specified
|
396 |
+
standard deviation (isotropic). This results in gradual increase of
|
397 |
+
masked region as the timesteps evolve, naturally blending fore-
|
398 |
+
ground and the predesignated background. Not strictly required if
|
399 |
+
you want to produce images from scratch withoout background.
|
400 |
+
4. Quantization: Split the real-numbered masks of value between [0, 1]
|
401 |
+
into predefined noise levels for each quantized scheduling step of
|
402 |
+
the diffusion sampler. For example, if the diffusion model sampler
|
403 |
+
has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
|
404 |
+
is the default noise level of this module with schedule [0, 4, 12,
|
405 |
+
25, 37], the masks are split into binary masks whose values are
|
406 |
+
greater than these levels. This results in tradual increase of mask
|
407 |
+
region as the timesteps increase. Details are described in our
|
408 |
+
paper at https://arxiv.org/pdf/2403.09055.pdf.
|
409 |
+
|
410 |
+
On the Three Modes of `mask_type`:
|
411 |
+
`self.mask_type` is predefined at the initialization stage of this
|
412 |
+
pipeline. Three possible modes are available: 'discrete', 'semi-
|
413 |
+
continuous', and 'continuous'. These define the mask quantization
|
414 |
+
modes we use. Basically, this (subtly) controls the smoothness of
|
415 |
+
foreground-background blending. Continuous modes produces nonbinary
|
416 |
+
masks to further blend foreground and background latents by linear-
|
417 |
+
ly interpolating between them. Semi-continuous masks only applies
|
418 |
+
continuous mask at the last step of the LCM sampler. Due to the
|
419 |
+
large step size of the LCM scheduler, we find that our continuous
|
420 |
+
blending helps generating seamless inpainting and editing results.
|
421 |
+
|
422 |
+
Args:
|
423 |
+
masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
|
424 |
+
strength (Optional[Union[torch.Tensor, float]]): Mask strength that
|
425 |
+
overrides the default value. A globally multiplied factor to
|
426 |
+
the mask at the initial stage of processing. Can be applied
|
427 |
+
seperately for each mask.
|
428 |
+
std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
|
429 |
+
kernel's standard deviation. Overrides the default value. Can
|
430 |
+
be applied seperately for each mask.
|
431 |
+
height (int): The height of the expected generation. Mask is
|
432 |
+
resized to (height//8, width//8) with nearest neighbor inter-
|
433 |
+
polation.
|
434 |
+
width (int): The width of the expected generation. Mask is resized
|
435 |
+
to (height//8, width//8) with nearest neighbor interpolation.
|
436 |
+
use_boolean_mask (bool): Specify this to treat the mask image as
|
437 |
+
a boolean tensor. The retion with dark part darker than 0.5 of
|
438 |
+
the maximal pixel value (that is, 127.5) is considered as the
|
439 |
+
designated mask.
|
440 |
+
timesteps (Optional[torch.Tensor]): Defines the scheduler noise
|
441 |
+
levels that acts as bins of mask quantization.
|
442 |
+
preprocess_mask_cover_alpha (Optional[float]): Optional pre-
|
443 |
+
processing where each mask covered by other masks is reduced in
|
444 |
+
its alpha value by this specified factor. Overrides the default
|
445 |
+
value.
|
446 |
+
|
447 |
+
Returns: A tuple of tensors.
|
448 |
+
- masks: Preprocessed (ordered, blurred, and quantized) binary/non-
|
449 |
+
binary masks (see the explanation on `mask_type` above) for
|
450 |
+
region-based image synthesis.
|
451 |
+
- masks_blurred: Gaussian blurred masks. Used for optionally
|
452 |
+
specified foreground-background blending after image
|
453 |
+
generation.
|
454 |
+
- std: Mask blur standard deviation. Used for optionally specified
|
455 |
+
foreground-background blending after image generation.
|
456 |
+
"""
|
457 |
+
if isinstance(masks, Image.Image):
|
458 |
+
masks = [masks]
|
459 |
+
if isinstance(masks, (tuple, list)):
|
460 |
+
# Assumes white background for Image.Image;
|
461 |
+
# inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
|
462 |
+
if use_boolean_mask:
|
463 |
+
proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
|
464 |
+
else:
|
465 |
+
proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
|
466 |
+
masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
|
467 |
+
masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
|
468 |
+
masks = masks.to(self.device)
|
469 |
+
|
470 |
+
# Background mask alpha is decayed by the specified factor where foreground masks covers it.
|
471 |
+
if preprocess_mask_cover_alpha is None:
|
472 |
+
preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
|
473 |
+
if preprocess_mask_cover_alpha > 0:
|
474 |
+
masks = torch.stack([
|
475 |
+
torch.where(
|
476 |
+
masks[i + 1:].sum(dim=0) > 0,
|
477 |
+
mask * preprocess_mask_cover_alpha,
|
478 |
+
mask,
|
479 |
+
) if i < len(masks) - 1 else mask
|
480 |
+
for i, mask in enumerate(masks)
|
481 |
+
], dim=0)
|
482 |
+
|
483 |
+
# Scheduler noise levels for mask quantization.
|
484 |
+
if timesteps is None:
|
485 |
+
noise_lvs = self.noise_lvs
|
486 |
+
next_noise_lvs = self.next_noise_lvs
|
487 |
+
else:
|
488 |
+
noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
|
489 |
+
noise_lvs = noise_lvs_[None, :, None, None, None]
|
490 |
+
next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
|
491 |
+
|
492 |
+
# Mask preprocessing parameters are fetched from the default settings.
|
493 |
+
if std is None:
|
494 |
+
std = self.default_mask_std
|
495 |
+
if isinstance(std, (int, float)):
|
496 |
+
std = [std] * len(masks)
|
497 |
+
if isinstance(std, (list, tuple)):
|
498 |
+
std = torch.as_tensor(std, dtype=torch.float, device=self.device)
|
499 |
+
|
500 |
+
if strength is None:
|
501 |
+
strength = self.default_mask_strength
|
502 |
+
if isinstance(strength, (int, float)):
|
503 |
+
strength = [strength] * len(masks)
|
504 |
+
if isinstance(strength, (list, tuple)):
|
505 |
+
strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
|
506 |
+
|
507 |
+
if (std > 0).any():
|
508 |
+
std = torch.where(std > 0, std, 1e-5)
|
509 |
+
masks = gaussian_lowpass(masks, std)
|
510 |
+
masks_blurred = masks
|
511 |
+
|
512 |
+
# NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
|
513 |
+
# gives unpleasant results.
|
514 |
+
masks = masks * strength[:, None, None, None]
|
515 |
+
masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
|
516 |
+
|
517 |
+
# Mask is quantized according to the current noise levels specified by the scheduler.
|
518 |
+
if self.mask_type == 'discrete':
|
519 |
+
# Discrete mode.
|
520 |
+
masks = masks > noise_lvs
|
521 |
+
elif self.mask_type == 'semi-continuous':
|
522 |
+
# Semi-continuous mode (continuous at the last step only).
|
523 |
+
masks = torch.cat((
|
524 |
+
masks[:, :-1] > noise_lvs[:, :-1],
|
525 |
+
(
|
526 |
+
(masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
|
527 |
+
).clip_(0, 1),
|
528 |
+
), dim=1)
|
529 |
+
elif self.mask_type == 'continuous':
|
530 |
+
# Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
|
531 |
+
# decreases continuously after the discrete mode boundary to become `0` at the
|
532 |
+
# next lower threshold.
|
533 |
+
masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
|
534 |
+
|
535 |
+
# NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
|
536 |
+
# fine-grained mask alpha channel tuning is available with this form.
|
537 |
+
# masks = masks * strength[None, :, None, None, None]
|
538 |
+
|
539 |
+
h = height // self.vae_scale_factor
|
540 |
+
w = width // self.vae_scale_factor
|
541 |
+
masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
|
542 |
+
masks = F.interpolate(masks, size=(h, w), mode='nearest')
|
543 |
+
masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
|
544 |
+
return masks, masks_blurred, std
|
545 |
+
|
546 |
+
def scheduler_step(
|
547 |
+
self,
|
548 |
+
noise_pred: torch.Tensor,
|
549 |
+
idx: int,
|
550 |
+
latent: torch.Tensor,
|
551 |
+
) -> torch.Tensor:
|
552 |
+
r"""Denoise-only step for reverse diffusion scheduler.
|
553 |
+
|
554 |
+
Designed to match the interface of the original `pipe.scheduler.step`,
|
555 |
+
which is a combination of this method and the following
|
556 |
+
`scheduler_add_noise`.
|
557 |
+
|
558 |
+
Args:
|
559 |
+
noise_pred (torch.Tensor): Noise prediction results from the U-Net.
|
560 |
+
idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
|
561 |
+
for the timesteps tensor (ranged in [0, len(timesteps)-1]).
|
562 |
+
latent (torch.Tensor): Noisy latent.
|
563 |
+
|
564 |
+
Returns:
|
565 |
+
A denoised tensor with the same size as latent.
|
566 |
+
"""
|
567 |
+
F_theta = (latent - self.beta_prod_t_sqrt[idx] * noise_pred) / self.alpha_prod_t_sqrt[idx]
|
568 |
+
return self.c_out[idx] * F_theta + self.c_skip[idx] * latent
|
569 |
+
|
570 |
+
def scheduler_add_noise(
|
571 |
+
self,
|
572 |
+
latent: torch.Tensor,
|
573 |
+
noise: Optional[torch.Tensor],
|
574 |
+
idx: int,
|
575 |
+
) -> torch.Tensor:
|
576 |
+
r"""Separated noise-add step for the reverse diffusion scheduler.
|
577 |
+
|
578 |
+
Designed to match the interface of the original
|
579 |
+
`pipe.scheduler.add_noise`.
|
580 |
+
|
581 |
+
Args:
|
582 |
+
latent (torch.Tensor): Denoised latent.
|
583 |
+
noise (torch.Tensor): Added noise. Can be None. If None, a random
|
584 |
+
noise is newly sampled for addition.
|
585 |
+
idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
|
586 |
+
for the timesteps tensor (ranged in [0, len(timesteps)-1]).
|
587 |
+
|
588 |
+
Returns:
|
589 |
+
A noisy tensor with the same size as latent.
|
590 |
+
"""
|
591 |
+
if idx >= len(self.alpha_prod_t_sqrt) or idx < 0:
|
592 |
+
# The last step does not require noise addition.
|
593 |
+
return latent
|
594 |
+
noise = torch.randn_like(latent) if noise is None else noise
|
595 |
+
return self.alpha_prod_t_sqrt[idx] * latent + self.beta_prod_t_sqrt[idx] * noise
|
596 |
+
|
597 |
+
@torch.no_grad()
|
598 |
+
def sample(
|
599 |
+
self,
|
600 |
+
prompts: Union[str, List[str]],
|
601 |
+
negative_prompts: Union[str, List[str]] = '',
|
602 |
+
height: int = 512,
|
603 |
+
width: int = 512,
|
604 |
+
num_inference_steps: Optional[int] = None,
|
605 |
+
guidance_scale: Optional[float] = None,
|
606 |
+
batch_size: int = 1,
|
607 |
+
) -> Image.Image:
|
608 |
+
r"""StableDiffusionPipeline for single-prompt single-tile generation.
|
609 |
+
|
610 |
+
Minimal Example:
|
611 |
+
>>> device = torch.device('cuda:0')
|
612 |
+
>>> smd = StableMultiDiffusionPipeline(device)
|
613 |
+
>>> image = smd.sample('A photo of the dolomites')
|
614 |
+
>>> image.save('my_creation.png')
|
615 |
+
|
616 |
+
Args:
|
617 |
+
prompts (Union[str, List[str]]): A text prompt.
|
618 |
+
negative_prompts (Union[str, List[str]]): A negative text prompt.
|
619 |
+
height (int): Height of a generated image.
|
620 |
+
width (int): Width of a generated image.
|
621 |
+
num_inference_steps (Optional[int]): Number of inference steps.
|
622 |
+
Default inference scheduling is used if none is specified.
|
623 |
+
guidance_scale (Optional[float]): Classifier guidance scale.
|
624 |
+
Default value is used if none is specified.
|
625 |
+
batch_size (int): Number of images to generate.
|
626 |
+
|
627 |
+
Returns: A PIL.Image image.
|
628 |
+
"""
|
629 |
+
if num_inference_steps is None:
|
630 |
+
num_inference_steps = self.default_num_inference_steps
|
631 |
+
if guidance_scale is None:
|
632 |
+
guidance_scale = self.default_guidance_scale
|
633 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
634 |
+
|
635 |
+
if isinstance(prompts, str):
|
636 |
+
prompts = [prompts]
|
637 |
+
if isinstance(negative_prompts, str):
|
638 |
+
negative_prompts = [negative_prompts]
|
639 |
+
|
640 |
+
# Calculate text embeddings.
|
641 |
+
uncond_embeds, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
|
642 |
+
text_embeds = torch.cat([uncond_embeds.mean(dim=0, keepdim=True), text_embeds.mean(dim=0, keepdim=True)])
|
643 |
+
h = height // self.vae_scale_factor
|
644 |
+
w = width // self.vae_scale_factor
|
645 |
+
latent = torch.randn((batch_size, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
|
646 |
+
|
647 |
+
with torch.autocast('cuda'):
|
648 |
+
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
|
649 |
+
# Expand the latents if we are doing classifier-free guidance.
|
650 |
+
latent_model_input = torch.cat([latent] * 2)
|
651 |
+
|
652 |
+
# Perform one step of the reverse diffusion.
|
653 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
|
654 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
655 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
656 |
+
latent = self.scheduler.step(noise_pred, t, latent)['prev_sample']
|
657 |
+
|
658 |
+
# Return PIL Image.
|
659 |
+
latent = latent.to(dtype=self.dtype)
|
660 |
+
imgs = [T.ToPILImage()(self.decode_latents(l[None])[0]) for l in latent]
|
661 |
+
return imgs
|
662 |
+
|
663 |
+
@torch.no_grad()
|
664 |
+
def sample_panorama(
|
665 |
+
self,
|
666 |
+
prompts: Union[str, List[str]],
|
667 |
+
negative_prompts: Union[str, List[str]] = '',
|
668 |
+
height: int = 512,
|
669 |
+
width: int = 2048,
|
670 |
+
num_inference_steps: Optional[int] = None,
|
671 |
+
guidance_scale: Optional[float] = None,
|
672 |
+
tile_size: Optional[int] = None,
|
673 |
+
) -> Image.Image:
|
674 |
+
r"""Large size image generation from a single set of prompts.
|
675 |
+
|
676 |
+
Minimal Example:
|
677 |
+
>>> device = torch.device('cuda:0')
|
678 |
+
>>> smd = StableMultiDiffusionPipeline(device)
|
679 |
+
>>> image = smd.sample_panorama(
|
680 |
+
>>> 'A photo of Alps', height=512, width=3072)
|
681 |
+
>>> image.save('my_panorama_creation.png')
|
682 |
+
|
683 |
+
Args:
|
684 |
+
prompts (Union[str, List[str]]): A text prompt.
|
685 |
+
negative_prompts (Union[str, List[str]]): A negative text prompt.
|
686 |
+
height (int): Height of a generated image. It is tiled if larger
|
687 |
+
than `tile_size`.
|
688 |
+
width (int): Width of a generated image. It is tiled if larger
|
689 |
+
than `tile_size`.
|
690 |
+
num_inference_steps (Optional[int]): Number of inference steps.
|
691 |
+
Default inference scheduling is used if none is specified.
|
692 |
+
guidance_scale (Optional[float]): Classifier guidance scale.
|
693 |
+
Default value is used if none is specified.
|
694 |
+
tile_size (Optional[int]): Tile size of the panorama generation.
|
695 |
+
Works best with the default training size of the Stable-
|
696 |
+
Diffusion model, i.e., 512 or 768 for SD1.5 and 1024 for SDXL.
|
697 |
+
|
698 |
+
Returns: A PIL.Image image of a panorama (large-size) image.
|
699 |
+
"""
|
700 |
+
if num_inference_steps is None:
|
701 |
+
num_inference_steps = self.default_num_inference_steps
|
702 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
703 |
+
timesteps = self.timesteps
|
704 |
+
use_custom_timesteps = False
|
705 |
+
else:
|
706 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
707 |
+
timesteps = self.scheduler.timesteps
|
708 |
+
use_custom_timesteps = True
|
709 |
+
if guidance_scale is None:
|
710 |
+
guidance_scale = self.default_guidance_scale
|
711 |
+
|
712 |
+
if isinstance(prompts, str):
|
713 |
+
prompts = [prompts]
|
714 |
+
if isinstance(negative_prompts, str):
|
715 |
+
negative_prompts = [negative_prompts]
|
716 |
+
|
717 |
+
# Calculate text embeddings.
|
718 |
+
uncond_embeds, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
|
719 |
+
text_embeds = torch.cat([uncond_embeds.mean(dim=0, keepdim=True), text_embeds.mean(dim=0, keepdim=True)])
|
720 |
+
|
721 |
+
# Define panorama grid and get views
|
722 |
+
h = height // self.vae_scale_factor
|
723 |
+
w = width // self.vae_scale_factor
|
724 |
+
latent = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
|
725 |
+
|
726 |
+
if tile_size is None:
|
727 |
+
tile_size = min(min(height, width), 512)
|
728 |
+
views, masks = get_panorama_views(h, w, tile_size // self.vae_scale_factor)
|
729 |
+
masks = masks.to(dtype=self.dtype, device=self.device)
|
730 |
+
value = torch.zeros_like(latent)
|
731 |
+
with torch.autocast('cuda'):
|
732 |
+
for i, t in enumerate(tqdm(timesteps)):
|
733 |
+
value.zero_()
|
734 |
+
|
735 |
+
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
|
736 |
+
# TODO we can support batches, and pass multiple views at once to the unet
|
737 |
+
latent_view = latent[:, :, h_start:h_end, w_start:w_end]
|
738 |
+
|
739 |
+
# Expand the latents if we are doing classifier-free guidance.
|
740 |
+
latent_model_input = torch.cat([latent_view] * 2)
|
741 |
+
|
742 |
+
# Perform one step of the reverse diffusion.
|
743 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
|
744 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
745 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
746 |
+
|
747 |
+
# Compute the denoising step.
|
748 |
+
latents_view_denoised = self.scheduler_step(noise_pred, i, latent_view) # (1, 4, h, w)
|
749 |
+
mask = masks[..., j:j + 1, h_start:h_end, w_start:w_end] # (1, 1, h, w)
|
750 |
+
value[..., h_start:h_end, w_start:w_end] += mask * latents_view_denoised # (1, 1, h, w)
|
751 |
+
|
752 |
+
# Update denoised latent.
|
753 |
+
latent = value.clone()
|
754 |
+
|
755 |
+
if i < len(timesteps) - 1:
|
756 |
+
latent = self.scheduler_add_noise(latent, None, i + 1)
|
757 |
+
|
758 |
+
# Return PIL Image.
|
759 |
+
imgs = self.decode_latents(latent)
|
760 |
+
img = T.ToPILImage()(imgs[0].cpu())
|
761 |
+
return img
|
762 |
+
|
763 |
+
@torch.no_grad()
|
764 |
+
def __call__(
|
765 |
+
self,
|
766 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
767 |
+
negative_prompts: Union[str, List[str]] = '',
|
768 |
+
suffix: Optional[str] = None, #', background is ',
|
769 |
+
background: Optional[Union[torch.Tensor, Image.Image]] = None,
|
770 |
+
background_prompt: Optional[str] = None,
|
771 |
+
background_negative_prompt: str = '',
|
772 |
+
height: int = 512,
|
773 |
+
width: int = 512,
|
774 |
+
num_inference_steps: Optional[int] = None,
|
775 |
+
guidance_scale: Optional[float] = None,
|
776 |
+
prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
777 |
+
masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
778 |
+
mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
779 |
+
mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
780 |
+
use_boolean_mask: bool = True,
|
781 |
+
do_blend: bool = True,
|
782 |
+
tile_size: int = 768,
|
783 |
+
bootstrap_steps: Optional[int] = None,
|
784 |
+
boostrap_mix_steps: Optional[float] = None,
|
785 |
+
bootstrap_leak_sensitivity: Optional[float] = None,
|
786 |
+
preprocess_mask_cover_alpha: Optional[float] = None,
|
787 |
+
) -> Image.Image:
|
788 |
+
r"""Arbitrary-size image generation from multiple pairs of (regional)
|
789 |
+
text prompt-mask pairs.
|
790 |
+
|
791 |
+
This is a main routine for this pipeline.
|
792 |
+
|
793 |
+
Example:
|
794 |
+
>>> device = torch.device('cuda:0')
|
795 |
+
>>> smd = StableMultiDiffusionPipeline(device)
|
796 |
+
>>> prompts = {... specify prompts}
|
797 |
+
>>> masks = {... specify mask tensors}
|
798 |
+
>>> height, width = masks.shape[-2:]
|
799 |
+
>>> image = smd(
|
800 |
+
>>> prompts, masks=masks.float(), height=height, width=width)
|
801 |
+
>>> image.save('my_beautiful_creation.png')
|
802 |
+
|
803 |
+
Args:
|
804 |
+
prompts (Union[str, List[str]]): A text prompt.
|
805 |
+
negative_prompts (Union[str, List[str]]): A negative text prompt.
|
806 |
+
suffix (Optional[str]): One option for blending foreground prompts
|
807 |
+
with background prompts by simply appending background prompt
|
808 |
+
to the end of each foreground prompt with this `middle word` in
|
809 |
+
between. For example, if you set this as `, background is`,
|
810 |
+
then the foreground prompt will be changed into
|
811 |
+
`(fg), background is (bg)` before conditional generation.
|
812 |
+
background (Optional[Union[torch.Tensor, Image.Image]]): a
|
813 |
+
background image, if the user wants to draw in front of the
|
814 |
+
specified image. Background prompt will automatically generated
|
815 |
+
with a BLIP-2 model.
|
816 |
+
background_prompt (Optional[str]): The background prompt is used
|
817 |
+
for preprocessing foreground prompt embeddings to blend
|
818 |
+
foreground and background.
|
819 |
+
background_negative_prompt (Optional[str]): The negative background
|
820 |
+
prompt.
|
821 |
+
height (int): Height of a generated image. It is tiled if larger
|
822 |
+
than `tile_size`.
|
823 |
+
width (int): Width of a generated image. It is tiled if larger
|
824 |
+
than `tile_size`.
|
825 |
+
num_inference_steps (Optional[int]): Number of inference steps.
|
826 |
+
Default inference scheduling is used if none is specified.
|
827 |
+
guidance_scale (Optional[float]): Classifier guidance scale.
|
828 |
+
Default value is used if none is specified.
|
829 |
+
prompt_strength (float): Overrides default value. Preprocess
|
830 |
+
foreground prompts globally by linearly interpolating its
|
831 |
+
embedding with the background prompt embeddint with specified
|
832 |
+
mix ratio. Useful control handle for foreground blending.
|
833 |
+
Recommended range: 0.5-1.
|
834 |
+
masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
|
835 |
+
mask images. Each mask associates with each of the text prompts
|
836 |
+
and each of the negative prompts. If specified as an image, it
|
837 |
+
regards the image as a boolean mask. Also accepts torch.Tensor
|
838 |
+
masks, which can have nonbinary values for fine-grained
|
839 |
+
controls in mixing regional generations.
|
840 |
+
mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
|
841 |
+
Overrides the default value. an be assigned for each mask
|
842 |
+
separately. Preprocess mask by multiplying it globally with the
|
843 |
+
specified variable. Caution: extremely sensitive. Recommended
|
844 |
+
range: 0.98-1.
|
845 |
+
mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
|
846 |
+
Overrides the default value. Can be assigned for each mask
|
847 |
+
separately. Preprocess mask with Gaussian blur with specified
|
848 |
+
standard deviation. Recommended range: 0-64.
|
849 |
+
use_boolean_mask (bool): Turn this off if you want to treat the
|
850 |
+
mask image as nonbinary one. The module will use the last
|
851 |
+
channel of the given image in `masks` as the mask value.
|
852 |
+
do_blend (bool): Blend the generated foreground and the optionally
|
853 |
+
predefined background by smooth boundary obtained from Gaussian
|
854 |
+
blurs of the foreground `masks` with the given `mask_stds`.
|
855 |
+
tile_size (Optional[int]): Tile size of the panorama generation.
|
856 |
+
Works best with the default training size of the Stable-
|
857 |
+
Diffusion model, i.e., 512 or 768 for SD1.5 and 1024 for SDXL.
|
858 |
+
bootstrap_steps (int): Overrides the default value. Bootstrapping
|
859 |
+
stage steps to encourage region separation. Recommended range:
|
860 |
+
1-3.
|
861 |
+
boostrap_mix_steps (float): Overrides the default value.
|
862 |
+
Bootstrapping background is a linear interpolation between
|
863 |
+
background latent and the white image latent. This handle
|
864 |
+
controls the mix ratio. Available range: 0-(number of
|
865 |
+
bootstrapping inference steps). For example, 2.3 means that for
|
866 |
+
the first two steps, white image is used as a bootstrapping
|
867 |
+
background and in the third step, mixture of white (0.3) and
|
868 |
+
registered background (0.7) is used as a bootstrapping
|
869 |
+
background.
|
870 |
+
bootstrap_leak_sensitivity (float): Overrides the default value.
|
871 |
+
Postprocessing at each inference step by masking away the
|
872 |
+
remaining bootstrap backgrounds t Recommended range: 0-1.
|
873 |
+
preprocess_mask_cover_alpha (float): Overrides the default value.
|
874 |
+
Optional preprocessing where each mask covered by other masks
|
875 |
+
is reduced in its alpha value by this specified factor.
|
876 |
+
|
877 |
+
Returns: A PIL.Image image of a panorama (large-size) image.
|
878 |
+
"""
|
879 |
+
|
880 |
+
### Simplest cases
|
881 |
+
|
882 |
+
# prompts is None: return background.
|
883 |
+
# masks is None but prompts is not None: return prompts
|
884 |
+
# masks is not None and prompts is not None: Do StableMultiDiffusion.
|
885 |
+
|
886 |
+
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
887 |
+
if background is None and background_prompt is not None:
|
888 |
+
return sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
|
889 |
+
return background
|
890 |
+
elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
|
891 |
+
return sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
|
892 |
+
|
893 |
+
|
894 |
+
### Prepare generation
|
895 |
+
|
896 |
+
if num_inference_steps is not None:
|
897 |
+
self.prepare_lcm_schedule(list(range(num_inference_steps)), num_inference_steps)
|
898 |
+
|
899 |
+
if guidance_scale is None:
|
900 |
+
guidance_scale = self.default_guidance_scale
|
901 |
+
|
902 |
+
|
903 |
+
### Prompts & Masks
|
904 |
+
|
905 |
+
# asserts #m > 0 and #p > 0.
|
906 |
+
# #m == #p == #n > 0: We happily generate according to the prompts & masks.
|
907 |
+
# #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
|
908 |
+
# #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
|
909 |
+
|
910 |
+
if isinstance(masks, Image.Image):
|
911 |
+
masks = [masks]
|
912 |
+
if isinstance(prompts, str):
|
913 |
+
prompts = [prompts]
|
914 |
+
if isinstance(negative_prompts, str):
|
915 |
+
negative_prompts = [negative_prompts]
|
916 |
+
num_masks = len(masks)
|
917 |
+
num_prompts = len(prompts)
|
918 |
+
num_nprompts = len(negative_prompts)
|
919 |
+
assert num_prompts in (num_masks, 1), \
|
920 |
+
f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
|
921 |
+
assert num_nprompts in (num_prompts, 1), \
|
922 |
+
f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'
|
923 |
+
|
924 |
+
fg_masks, masks_g, std = self.process_mask(
|
925 |
+
masks,
|
926 |
+
mask_strengths,
|
927 |
+
mask_stds,
|
928 |
+
height=height,
|
929 |
+
width=width,
|
930 |
+
use_boolean_mask=use_boolean_mask,
|
931 |
+
timesteps=self.timesteps,
|
932 |
+
preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
|
933 |
+
) # (p, t, 1, H, W)
|
934 |
+
bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1) # (T, 1, h, w)
|
935 |
+
has_background = bg_masks.sum() > 0
|
936 |
+
|
937 |
+
h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
|
938 |
+
w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor
|
939 |
+
|
940 |
+
|
941 |
+
### Background
|
942 |
+
|
943 |
+
# background == None && background_prompt == None: Initialize with white background.
|
944 |
+
# background == None && background_prompt != None: Generate background *along with other prompts*.
|
945 |
+
# background != None && background_prompt == None: Retrieve text prompt using BLIP.
|
946 |
+
# background != None && background_prompt != None: Use the given arguments.
|
947 |
+
|
948 |
+
# not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
|
949 |
+
# has_background && prompt_strength != 1: mix only for this case.
|
950 |
+
|
951 |
+
bg_latent = None
|
952 |
+
if has_background:
|
953 |
+
if background is None and background_prompt is not None:
|
954 |
+
fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
|
955 |
+
if suffix is not None:
|
956 |
+
prompts = [p + suffix + background_prompt for p in prompts]
|
957 |
+
prompts = [background_prompt] + prompts
|
958 |
+
negative_prompts = [background_negative_prompt] + negative_prompts
|
959 |
+
has_background = False # Regard that background does not exist.
|
960 |
+
else:
|
961 |
+
if background is None and background_prompt is None:
|
962 |
+
background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
|
963 |
+
background_prompt = 'simple white background image'
|
964 |
+
elif background is not None and background_prompt is None:
|
965 |
+
background_prompt = self.get_text_prompts(background)
|
966 |
+
if suffix is not None:
|
967 |
+
prompts = [p + suffix + background_prompt for p in prompts]
|
968 |
+
prompts = [background_prompt] + prompts
|
969 |
+
negative_prompts = [background_negative_prompt] + negative_prompts
|
970 |
+
if isinstance(background, Image.Image):
|
971 |
+
background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
|
972 |
+
background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
|
973 |
+
bg_latent = self.encode_imgs(background)
|
974 |
+
|
975 |
+
# Bootstrapping stage preparation.
|
976 |
+
|
977 |
+
if bootstrap_steps is None:
|
978 |
+
bootstrap_steps = self.default_bootstrap_steps
|
979 |
+
if boostrap_mix_steps is None:
|
980 |
+
boostrap_mix_steps = self.default_boostrap_mix_steps
|
981 |
+
if bootstrap_leak_sensitivity is None:
|
982 |
+
bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
|
983 |
+
if bootstrap_steps > 0:
|
984 |
+
height_ = min(height, tile_size)
|
985 |
+
width_ = min(width, tile_size)
|
986 |
+
white = self.get_white_background(height, width) # (1, 4, h, w)
|
987 |
+
|
988 |
+
|
989 |
+
### Prepare text embeddings (optimized for the minimal encoder batch size)
|
990 |
+
|
991 |
+
uncond_embeds, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2 * len(prompts), 77, 768]
|
992 |
+
if has_background:
|
993 |
+
# First channel is background prompt text embeds. Background prompt itself is not used for generation.
|
994 |
+
s = prompt_strengths
|
995 |
+
if prompt_strengths is None:
|
996 |
+
s = self.default_prompt_strength
|
997 |
+
if isinstance(s, (int, float)):
|
998 |
+
s = [s] * num_prompts
|
999 |
+
if isinstance(s, (list, tuple)):
|
1000 |
+
assert len(s) == num_prompts, \
|
1001 |
+
f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
|
1002 |
+
s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
|
1003 |
+
s = s[:, None, None]
|
1004 |
+
|
1005 |
+
be = text_embeds[:1]
|
1006 |
+
bu = uncond_embeds[:1]
|
1007 |
+
fe = text_embeds[1:]
|
1008 |
+
fu = uncond_embeds[1:]
|
1009 |
+
if num_prompts > num_nprompts:
|
1010 |
+
# # negative prompts = 1; # prompts > 1.
|
1011 |
+
assert fu.shape[0] == 1 and fe.shape == num_prompts
|
1012 |
+
fu = fu.repeat(num_prompts, 1, 1)
|
1013 |
+
text_embeds = torch.lerp(be, fe, s) # (p, 77, 768)
|
1014 |
+
uncond_embeds = torch.lerp(bu, fu, s) # (n, 77, 768)
|
1015 |
+
elif num_prompts > num_nprompts:
|
1016 |
+
# # negative prompts = 1; # prompts > 1.
|
1017 |
+
assert uncond_embeds.shape[0] == 1 and text_embeds.shape[0] == num_prompts
|
1018 |
+
uncond_embeds = uncond_embeds.repeat(num_prompts, 1, 1)
|
1019 |
+
assert uncond_embeds.shape[0] == text_embeds.shape[0] == num_prompts
|
1020 |
+
if num_masks > num_prompts:
|
1021 |
+
assert masks.shape[0] == num_masks and num_prompts == 1
|
1022 |
+
text_embeds = text_embeds.repeat(num_masks, 1, 1)
|
1023 |
+
uncond_embeds = uncond_embeds.repeat(num_masks, 1, 1)
|
1024 |
+
text_embeds = torch.cat([uncond_embeds, text_embeds])
|
1025 |
+
|
1026 |
+
|
1027 |
+
### Run
|
1028 |
+
|
1029 |
+
# Latent initialization.
|
1030 |
+
if self.timesteps[0] < 999 and has_background:
|
1031 |
+
latent = self.scheduler_add_noise(bg_latent, None, 0)
|
1032 |
+
else:
|
1033 |
+
latent = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
|
1034 |
+
|
1035 |
+
# Tiling (if needed).
|
1036 |
+
if height > tile_size or width > tile_size:
|
1037 |
+
t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
|
1038 |
+
views, tile_masks = get_panorama_views(h, w, t)
|
1039 |
+
tile_masks = tile_masks.to(self.device)
|
1040 |
+
else:
|
1041 |
+
views = [(0, h, 0, w)]
|
1042 |
+
tile_masks = latent.new_ones((1, 1, h, w))
|
1043 |
+
value = torch.zeros_like(latent)
|
1044 |
+
count_all = torch.zeros_like(latent)
|
1045 |
+
|
1046 |
+
with torch.autocast('cuda'):
|
1047 |
+
for i, t in enumerate(tqdm(self.timesteps)):
|
1048 |
+
fg_mask = fg_masks[:, i]
|
1049 |
+
bg_mask = bg_masks[i:i + 1]
|
1050 |
+
|
1051 |
+
value.zero_()
|
1052 |
+
count_all.zero_()
|
1053 |
+
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
|
1054 |
+
fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
|
1055 |
+
latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
|
1056 |
+
|
1057 |
+
# Bootstrap for tight background.
|
1058 |
+
if i < bootstrap_steps:
|
1059 |
+
mix_ratio = min(1, max(0, boostrap_mix_steps - i))
|
1060 |
+
# Treat the first foreground latent as the background latent if one does not exist.
|
1061 |
+
bg_latent_ = bg_latent[..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
|
1062 |
+
white_ = white[..., h_start:h_end, w_start:w_end]
|
1063 |
+
bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
|
1064 |
+
bg_latent_ = self.scheduler_add_noise(bg_latent_, None, i)
|
1065 |
+
latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
|
1066 |
+
|
1067 |
+
# Centering.
|
1068 |
+
latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
|
1069 |
+
|
1070 |
+
# Perform one step of the reverse diffusion.
|
1071 |
+
noise_pred = self.unet(torch.cat([latent_] * 2), t, encoder_hidden_states=text_embeds)['sample']
|
1072 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
1073 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1074 |
+
latent_ = self.scheduler_step(noise_pred, i, latent_)
|
1075 |
+
|
1076 |
+
if i < bootstrap_steps:
|
1077 |
+
# Uncentering.
|
1078 |
+
latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
|
1079 |
+
|
1080 |
+
# Remove leakage (optional).
|
1081 |
+
leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
|
1082 |
+
leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
|
1083 |
+
fg_mask_ = fg_mask_ * leak_sigmoid
|
1084 |
+
|
1085 |
+
# Mix the latents.
|
1086 |
+
fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
|
1087 |
+
value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
|
1088 |
+
count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
|
1089 |
+
|
1090 |
+
latent = torch.where(count_all > 0, value / count_all, value)
|
1091 |
+
bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
|
1092 |
+
if has_background:
|
1093 |
+
latent = (1 - bg_mask) * latent + bg_mask * bg_latent
|
1094 |
+
|
1095 |
+
# Noise is added after mixing.
|
1096 |
+
if i < len(self.timesteps) - 1:
|
1097 |
+
latent = self.scheduler_add_noise(latent, None, i + 1)
|
1098 |
+
|
1099 |
+
# Return PIL Image.
|
1100 |
+
image = self.decode_latents(latent.to(dtype=self.dtype))[0]
|
1101 |
+
if has_background and do_blend:
|
1102 |
+
fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
|
1103 |
+
image = blend(image, background[0], fg_mask)
|
1104 |
+
else:
|
1105 |
+
image = T.ToPILImage()(image)
|
1106 |
+
return image
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
torchvision
|
3 |
+
xformers==0.0.22
|
4 |
+
einops
|
5 |
+
diffusers
|
6 |
+
transformers
|
7 |
+
huggingface_hub[torch]
|
8 |
+
gradio
|
9 |
+
Pillow
|
10 |
+
emoji
|
11 |
+
numpy
|
12 |
+
tqdm
|
13 |
+
jupyterlab
|
14 |
+
spaces
|
util.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
import concurrent.futures
|
22 |
+
import time
|
23 |
+
from typing import Any, Callable, List, Tuple, Union
|
24 |
+
|
25 |
+
from PIL import Image
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.nn.functional as F
|
30 |
+
import torch.cuda.amp as amp
|
31 |
+
import torchvision.transforms as T
|
32 |
+
import torchvision.transforms.functional as TF
|
33 |
+
|
34 |
+
|
35 |
+
def seed_everything(seed: int) -> None:
|
36 |
+
torch.manual_seed(seed)
|
37 |
+
torch.cuda.manual_seed(seed)
|
38 |
+
torch.backends.cudnn.deterministic = True
|
39 |
+
torch.backends.cudnn.benchmark = True
|
40 |
+
|
41 |
+
|
42 |
+
def get_cutoff(cutoff: float = None, scale: float = None) -> float:
|
43 |
+
if cutoff is not None:
|
44 |
+
return cutoff
|
45 |
+
|
46 |
+
if scale is not None and cutoff is None:
|
47 |
+
return 0.5 / scale
|
48 |
+
|
49 |
+
raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
|
50 |
+
|
51 |
+
|
52 |
+
def get_scale(cutoff: float = None, scale: float = None) -> float:
|
53 |
+
if scale is not None:
|
54 |
+
return scale
|
55 |
+
|
56 |
+
if cutoff is not None and scale is None:
|
57 |
+
return 0.5 / cutoff
|
58 |
+
|
59 |
+
raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
|
60 |
+
|
61 |
+
|
62 |
+
def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
|
63 |
+
assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
|
64 |
+
# assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
|
65 |
+
|
66 |
+
b, c, h, w = x.shape
|
67 |
+
ks = k.shape[-1]
|
68 |
+
k = k.view(1, 1, -1).repeat(c, 1, 1)
|
69 |
+
|
70 |
+
x = x.permute(0, 2, 1, 3)
|
71 |
+
x = x.reshape(b * h, c, w)
|
72 |
+
x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
|
73 |
+
x = F.conv1d(x, k, groups=c)
|
74 |
+
x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
|
75 |
+
x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
|
76 |
+
x = F.conv1d(x, k, groups=c)
|
77 |
+
x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
|
82 |
+
assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
|
83 |
+
|
84 |
+
x = F.pad(x, (
|
85 |
+
k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
|
86 |
+
k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
|
87 |
+
), mode='replicate')
|
88 |
+
|
89 |
+
b, c, _, _ = x.shape
|
90 |
+
if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
|
91 |
+
k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
|
92 |
+
x = F.conv2d(x, k, groups=c)
|
93 |
+
elif len(k.shape) == 3:
|
94 |
+
assert k.shape[0] == b, \
|
95 |
+
'The number of kernels should match the batch size.'
|
96 |
+
|
97 |
+
k = k.unsqueeze(1)
|
98 |
+
x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
|
99 |
+
return x
|
100 |
+
|
101 |
+
|
102 |
+
@amp.autocast(False)
|
103 |
+
def filter_by_kernel(
|
104 |
+
x: torch.Tensor,
|
105 |
+
k: torch.Tensor,
|
106 |
+
is_batch: bool = False,
|
107 |
+
) -> torch.Tensor:
|
108 |
+
k_dim = len(k.shape)
|
109 |
+
if k_dim == 1 or k_dim == 2 and is_batch:
|
110 |
+
return filter_2d_by_kernel_1d(x, k)
|
111 |
+
elif k_dim == 2 or k_dim == 3 and is_batch:
|
112 |
+
return filter_2d_by_kernel_2d(x, k)
|
113 |
+
else:
|
114 |
+
raise ValueError('Kernel size should be one of (1, 2, 3).')
|
115 |
+
|
116 |
+
|
117 |
+
def gen_gauss_lowpass_filter_2d(
|
118 |
+
std: torch.Tensor,
|
119 |
+
window_size: int = None,
|
120 |
+
) -> torch.Tensor:
|
121 |
+
# Gaussian kernel size is odd in order to preserve the center.
|
122 |
+
if window_size is None:
|
123 |
+
window_size = (
|
124 |
+
2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
|
125 |
+
|
126 |
+
y = torch.arange(
|
127 |
+
window_size, dtype=std.dtype, device=std.device
|
128 |
+
).view(-1, 1).repeat(1, window_size)
|
129 |
+
grid = torch.stack((y.t(), y), dim=-1)
|
130 |
+
grid -= 0.5 * (window_size - 1) # (W, W)
|
131 |
+
var = (std * std).unsqueeze(-1).unsqueeze(-1)
|
132 |
+
distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
|
133 |
+
k = torch.exp(-0.5 * distsq / var)
|
134 |
+
k /= k.sum(dim=(-2, -1), keepdim=True)
|
135 |
+
return k
|
136 |
+
|
137 |
+
|
138 |
+
def gaussian_lowpass(
|
139 |
+
x: torch.Tensor,
|
140 |
+
std: Union[float, Tuple[float], torch.Tensor] = None,
|
141 |
+
cutoff: Union[float, torch.Tensor] = None,
|
142 |
+
scale: Union[float, torch.Tensor] = None,
|
143 |
+
) -> torch.Tensor:
|
144 |
+
if std is None:
|
145 |
+
cutoff = get_cutoff(cutoff, scale)
|
146 |
+
std = 0.5 / (np.pi * cutoff)
|
147 |
+
if isinstance(std, (float, int)):
|
148 |
+
std = (std, std)
|
149 |
+
if isinstance(std, torch.Tensor):
|
150 |
+
"""Using nn.functional.conv2d with Gaussian kernels built in runtime is
|
151 |
+
80% faster than transforms.functional.gaussian_blur for individual
|
152 |
+
items.
|
153 |
+
|
154 |
+
(in GPU); However, in CPU, the result is exactly opposite. But you
|
155 |
+
won't gonna run this on CPU, right?
|
156 |
+
"""
|
157 |
+
if len(list(s for s in std.shape if s != 1)) >= 2:
|
158 |
+
raise NotImplementedError(
|
159 |
+
'Anisotropic Gaussian filter is not currently available.')
|
160 |
+
|
161 |
+
# k.shape == (B, W, W).
|
162 |
+
k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
|
163 |
+
if k.shape[0] == 1:
|
164 |
+
return filter_by_kernel(x, k[0], False)
|
165 |
+
else:
|
166 |
+
return filter_by_kernel(x, k, True)
|
167 |
+
else:
|
168 |
+
# Gaussian kernel size is odd in order to preserve the center.
|
169 |
+
window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
|
170 |
+
return TF.gaussian_blur(x, window_size, std)
|
171 |
+
|
172 |
+
|
173 |
+
def blend(
|
174 |
+
fg: Union[torch.Tensor, Image.Image],
|
175 |
+
bg: Union[torch.Tensor, Image.Image],
|
176 |
+
mask: Union[torch.Tensor, Image.Image],
|
177 |
+
std: float = 0.0,
|
178 |
+
) -> Image.Image:
|
179 |
+
if not isinstance(fg, torch.Tensor):
|
180 |
+
fg = T.ToTensor()(fg)
|
181 |
+
if not isinstance(bg, torch.Tensor):
|
182 |
+
bg = T.ToTensor()(bg)
|
183 |
+
if not isinstance(mask, torch.Tensor):
|
184 |
+
mask = (T.ToTensor()(mask) < 0.5).float()[:1]
|
185 |
+
if std > 0:
|
186 |
+
mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
|
187 |
+
return T.ToPILImage()(fg * mask + bg * (1 - mask))
|
188 |
+
|
189 |
+
|
190 |
+
def get_panorama_views(
|
191 |
+
panorama_height: int,
|
192 |
+
panorama_width: int,
|
193 |
+
window_size: int = 64,
|
194 |
+
) -> tuple[List[Tuple[int]], torch.Tensor]:
|
195 |
+
stride = window_size // 2
|
196 |
+
is_horizontal = panorama_width > panorama_height
|
197 |
+
num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
|
198 |
+
num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
|
199 |
+
total_num_blocks = num_blocks_height * num_blocks_width
|
200 |
+
|
201 |
+
half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
|
202 |
+
half_rev = half_fwd.flip(0)
|
203 |
+
if window_size % 2 == 1:
|
204 |
+
half_rev = half_rev[1:]
|
205 |
+
c = torch.cat((half_fwd, half_rev))
|
206 |
+
one = torch.ones_like(c)
|
207 |
+
f = c.clone()
|
208 |
+
f[:window_size // 2] = 1
|
209 |
+
b = c.clone()
|
210 |
+
b[-(window_size // 2):] = 1
|
211 |
+
|
212 |
+
h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
|
213 |
+
w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
|
214 |
+
|
215 |
+
views = []
|
216 |
+
masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
|
217 |
+
for i in range(total_num_blocks):
|
218 |
+
hi, wi = i // num_blocks_width, i % num_blocks_width
|
219 |
+
h_start = hi * stride
|
220 |
+
h_end = min(h_start + window_size, panorama_height)
|
221 |
+
w_start = wi * stride
|
222 |
+
w_end = min(w_start + window_size, panorama_width)
|
223 |
+
views.append((h_start, h_end, w_start, w_end))
|
224 |
+
|
225 |
+
h_width = h_end - h_start
|
226 |
+
w_width = w_end - w_start
|
227 |
+
masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
|
228 |
+
|
229 |
+
# Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
|
230 |
+
return views, masks[None] # (1, n, h, w)
|
231 |
+
|
232 |
+
|
233 |
+
def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> List[int]:
|
234 |
+
h, w = mask.shape[-2:]
|
235 |
+
device = mask.device
|
236 |
+
mask = mask.reshape(-1, h, w)
|
237 |
+
# assert mask.shape[0] == im.shape[0]
|
238 |
+
h_occupied = mask.sum(dim=-2) > 0
|
239 |
+
w_occupied = mask.sum(dim=-1) > 0
|
240 |
+
l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
|
241 |
+
r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
|
242 |
+
t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
|
243 |
+
b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
|
244 |
+
tb = (t + b + 1) // 2
|
245 |
+
lr = (l + r + 1) // 2
|
246 |
+
shifts = (tb - (h // 2), lr - (w // 2))
|
247 |
+
shifts = torch.cat(shifts, dim=1) # (p, 2)
|
248 |
+
if reverse:
|
249 |
+
shifts = shifts * -1
|
250 |
+
return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
|
251 |
+
|
252 |
+
|
253 |
+
class Streamer:
|
254 |
+
def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
|
255 |
+
self.fn = fn
|
256 |
+
self.ema_alpha = ema_alpha
|
257 |
+
|
258 |
+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
259 |
+
self.future = self.executor.submit(fn)
|
260 |
+
self.image = None
|
261 |
+
|
262 |
+
self.prev_exec_time = 0
|
263 |
+
self.ema_exec_time = 0
|
264 |
+
|
265 |
+
@property
|
266 |
+
def throughput(self) -> float:
|
267 |
+
return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
|
268 |
+
|
269 |
+
def timed_fn(self) -> Any:
|
270 |
+
start = time.time()
|
271 |
+
res = self.fn()
|
272 |
+
end = time.time()
|
273 |
+
self.prev_exec_time = end - start
|
274 |
+
self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
|
275 |
+
return res
|
276 |
+
|
277 |
+
def __call__(self) -> Any:
|
278 |
+
if self.future.done() or self.image is None:
|
279 |
+
# get the result (the new image) and start a new task
|
280 |
+
image = self.future.result()
|
281 |
+
self.future = self.executor.submit(self.timed_fn)
|
282 |
+
self.image = image
|
283 |
+
return image
|
284 |
+
else:
|
285 |
+
# if self.fn() is not ready yet, use the previous image
|
286 |
+
# NOTE: This assumes that we have access to a previously generated image here.
|
287 |
+
# If there's no previous image (i.e., this is the first invocation), you could fall
|
288 |
+
# back to some default image or handle it differently based on your requirements.
|
289 |
+
return self.image
|