first commit
Browse files
- README.md +3 -3
- app.py +948 -0
- examples/prompt_background.txt +8 -0
- examples/prompt_background_advanced.txt +0 -0
- examples/prompt_boy.txt +15 -0
- examples/prompt_girl.txt +16 -0
- examples/prompt_props.txt +43 -0
- model.py +1095 -0
- prompt_util.py +154 -0
- requirements.txt +16 -0
- share_btn.py +70 -0
- util.py +315 -0
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title:
+title: Semantic Palette with Stable Diffusion 3
-emoji:
+emoji: 🧠🎨3️
 colorFrom: red
 colorTo: yellow
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
-pinned:
+pinned: true
 license: mit
 ---
app.py
ADDED
@@ -0,0 +1,948 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys

sys.path.append('../../src')

import argparse
import random
import time
import json
import os
import glob
import pathlib
from functools import partial
from pprint import pprint

import numpy as np
from PIL import Image
import torch

import gradio as gr
from huggingface_hub import snapshot_download

from model import StableMultiDiffusion3Pipeline
from util import seed_everything
from prompt_util import preprocess_prompts, _quality_dict, _style_dict


### Utils


def log_state(state):
    pprint(vars(opt))
    if isinstance(state, gr.State):
        state = state.value
    pprint(vars(state))


def is_empty_image(im: Image.Image) -> bool:
    if im is None:
        return True
    im = np.array(im)
    has_alpha = (im.shape[2] == 4)
    if not has_alpha:
        return False
    elif im.sum() == 0:
        return True
    else:
        return False


### Argument passing

parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SD3 support.')
parser.add_argument('-H', '--height', type=int, default=1024)
parser.add_argument('-W', '--width', type=int, default=2560)
parser.add_argument('--model', type=str, default=None, help='Hugging face model repository or local path for a SD1.5 model checkpoint to run.')
parser.add_argument('--bootstrap_steps', type=int, default=2)
parser.add_argument('--seed', type=int, default=-1)
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--port', type=int, default=8000)
opt = parser.parse_args()


### Global variables and data structures

device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'

if opt.model is None:
    model_dict = {
        'Stable Diffusion 3': 'stabilityai/stable-diffusion-3-medium-diffusers',
    }
else:
    if opt.model.endswith('.safetensors'):
        opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
    model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}

dtype = torch.float32 if device == 'cpu' else torch.float16
models = {
    k: StableMultiDiffusion3Pipeline(device, dtype=dtype, hf_key=v, has_i2t=False)
    for k, v in model_dict.items()
}


prompt_suggestions = [
    '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
    '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
    '1girl, arima kana, oshi no ko, solo, upper body, from behind',
]

opt.max_palettes = 4
opt.default_prompt_strength = 1.0
opt.default_mask_strength = 1.0
opt.default_mask_std = 0.0
opt.default_negative_prompt = (
    'nsfw, worst quality, bad quality, normal quality, cropped, framed'
)
opt.verbose = True
opt.colors = [
    '#000000',
    '#2692F3',
    '#F89E12',
    '#16C232',
    '#F92F6C',
    # '#AC6AEB',
    # '#92C62C',
    # '#92C6EC',
    # '#FECAC0',
]


### Event handlers

def add_palette(state):
    old_actives = state.active_palettes
    state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)

    if opt.verbose:
        log_state(state)

    if state.active_palettes != old_actives:
        return [state] + [
            gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
        ] + [
            gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
            for i in range(opt.max_palettes)
        ]
    else:
        return [state] + [gr.update() for i in range(opt.max_palettes + 1)]


def select_palette(state, button, idx):
    if idx < 0 or idx > opt.max_palettes:
        idx = 0
    old_idx = state.current_palette
    if old_idx == idx:
        return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]

    state.current_palette = idx

    if opt.verbose:
        log_state(state)

    updates = [state] + [
        gr.update() if i not in (idx, old_idx) else
        gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
        for i in range(opt.max_palettes + 1)
    ]
    label = 'Background' if idx == 0 else f'Palette {idx}'
    updates.extend([
        gr.update(value=button, interactive=(idx > 0)),
        gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
        gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
        (
            gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_mask_strength, interactive=False)
        ),
        (
            gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_prompt_strength, interactive=False)
        ),
        (
            gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_mask_std, interactive=False)
        ),
    ])
    return updates


def change_prompt_strength(state, strength):
    if state.current_palette == 0:
        return state

    state.prompt_strengths[state.current_palette - 1] = strength
    if opt.verbose:
        log_state(state)

    return state


def change_std(state, std):
    if state.current_palette == 0:
        return state

    state.mask_stds[state.current_palette - 1] = std
    if opt.verbose:
        log_state(state)

    return state


def change_mask_strength(state, strength):
    if state.current_palette == 0:
        return state

    state.mask_strengths[state.current_palette - 1] = strength
    if opt.verbose:
        log_state(state)

    return state


def reset_seed(state, seed):
    state.seed = seed
    if opt.verbose:
        log_state(state)

    return state


def rename_prompt(state, name):
    state.prompt_names[state.current_palette] = name
    if opt.verbose:
        log_state(state)

    return [state] + [
        gr.update() if i != state.current_palette else gr.update(value=name)
        for i in range(opt.max_palettes + 1)
    ]


def change_prompt(state, prompt):
    state.prompts[state.current_palette] = prompt
    if opt.verbose:
        log_state(state)

    return state


def change_neg_prompt(state, neg_prompt):
    state.neg_prompts[state.current_palette] = neg_prompt
    if opt.verbose:
        log_state(state)

    return state


def select_model(state, model_id):
    state.model_id = model_id
    if opt.verbose:
        log_state(state)

    return state


def select_style(state, style_name):
    state.style_name = style_name
    if opt.verbose:
        log_state(state)

    return state


def select_quality(state, quality_name):
    state.quality_name = quality_name
    if opt.verbose:
        log_state(state)

    return state


def import_state(state, json_text):
    current_palette = state.current_palette
    # active_palettes = state.active_palettes
    state = argparse.Namespace(**json.loads(json_text))
    state.active_palettes = opt.max_palettes
    return [state] + [
        gr.update(value=v, visible=True) for v in state.prompt_names
    ] + [
        # state.model_id,
        # state.style_name,
        # state.quality_name,
        state.prompts[current_palette],
        state.prompt_names[current_palette],
        state.neg_prompts[current_palette],
        state.prompt_strengths[current_palette - 1],
        state.mask_strengths[current_palette - 1],
        state.mask_stds[current_palette - 1],
        state.seed,
    ]


### Main worker

def generate(state, *args, **kwargs):
    return models[state.model_id](*args, **kwargs)


def run(state, drawpad):
    seed_everything(state.seed if state.seed >= 0 else np.random.randint(2147483647))
    print('Generate!')

    background = drawpad['background'].convert('RGBA')
    inpainting_mode = np.asarray(background).sum() != 0
    print('Inpainting mode: ', inpainting_mode)

    user_input = np.asarray(drawpad['layers'][0])  # (H, W, 4)
    foreground_mask = torch.tensor(user_input[..., -1])[None, None]  # (1, 1, H, W)
    user_input = torch.tensor(user_input[..., :-1])  # (H, W, 3)

    palette = torch.tensor([
        tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
        for s in opt.colors[1:]
    ])  # (N, 3)
    masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...]  # (N, 1, H, W)
    has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
    print('Has mask: ', has_masks)
    masks = masks * foreground_mask
    masks = masks[has_masks]

    if inpainting_mode:
        prompts = [state.prompts[v + 1] for v in has_masks]
        negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
        mask_strengths = [state.mask_strengths[v] for v in has_masks]
        mask_stds = [state.mask_stds[v] for v in has_masks]
        prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
    else:
        masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
        prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
        negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
        mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
        mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
        prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]

    prompts, negative_prompts = preprocess_prompts(
        prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)

    return generate(
        state,
        prompts,
        negative_prompts,
        masks=masks,
        mask_strengths=mask_strengths,
        mask_stds=mask_stds,
        prompt_strengths=prompt_strengths,
        background=background.convert('RGB'),
        background_prompt=state.prompts[0],
        background_negative_prompt=state.neg_prompts[0],
        height=opt.height,
        width=opt.width,
        bootstrap_steps=2,
        guidance_scale=0,
    )


### Load examples

root = pathlib.Path(__file__).parent
print(root)
example_root = os.path.join(root, 'examples')
example_images = glob.glob(os.path.join(example_root, '*.webp'))
example_images = [Image.open(i) for i in example_images]

with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
    prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']

with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
    prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']

with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
    prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']

with open(os.path.join(example_root, 'prompt_props.txt')) as f:
    prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
    prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}

prompt_background = lambda: random.choice(prompts_background)
prompt_girl = lambda: random.choice(prompts_girl)
prompt_boy = lambda: random.choice(prompts_boy)
prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()


### Main application

css = f"""
#run-button {{
    font-size: 30pt;
    background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
    margin: 0;
    padding: 15px 45px;
    text-align: center;
    text-transform: uppercase;
    transition: 0.5s;
    background-size: 200% auto;
    color: white;
    box-shadow: 0 0 20px #eee;
    border-radius: 10px;
    display: block;
    background-position: right center;
}}

#run-button:hover {{
    background-position: left center;
    color: #fff;
    text-decoration: none;
}}

#semantic-palette {{
    border-style: solid;
    border-width: 0.2em;
    border-color: #eee;
}}

#semantic-palette:hover {{
    box-shadow: 0 0 20px #eee;
}}

#output-screen {{
    width: 100%;
    aspect-ratio: {opt.width} / {opt.height};
}}

.layer-wrap {{
    display: none;
}}

.rainbow {{
    text-align: center;
    text-decoration: underline;
    font-size: 32px;
    font-family: monospace;
    letter-spacing: 5px;
}}
.rainbow_text_animated {{
    background: linear-gradient(to right, #6666ff, #0099ff , #00ff00, #ff3399, #6666ff);
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent;
    animation: rainbow_animation 6s ease-in-out infinite;
    background-size: 400% 100%;
}}

@keyframes rainbow_animation {{
    0%,100% {{
        background-position: 0 0;
    }}

    50% {{
        background-position: 100% 0;
    }}
}}

.gallery {{
    --z: 16px; /* control the zig-zag */
    --s: 144px; /* control the size */
    --g: 4px; /* control the gap */

    display: grid;
    gap: var(--g);
    width: calc(2*var(--s) + var(--g));
    grid-auto-flow: column;
}}
.gallery > a {{
    width: 0;
    min-width: calc(100% + var(--z)/2);
    height: var(--s);
    object-fit: cover;
    -webkit-mask: var(--mask);
    mask: var(--mask);
    cursor: pointer;
    transition: .5s;
}}
.gallery > a:hover {{
    width: calc(var(--s)/2);
}}
.gallery > a:first-child {{
    place-self: start;
    clip-path: polygon(calc(2*var(--z)) 0,100% 0,100% 100%,0 100%);
    --mask:
        conic-gradient(from -135deg at right,#0000,#000 1deg 89deg,#0000 90deg)
        50%/100% calc(2*var(--z)) repeat-y;
}}
.gallery > a:last-child {{
    place-self: end;
    clip-path: polygon(0 0,100% 0,calc(100% - 2*var(--z)) 100%,0 100%);
    --mask:
        conic-gradient(from 45deg at left ,#0000,#000 1deg 89deg,#0000 90deg)
        50% calc(50% - var(--z))/100% calc(2*var(--z)) repeat-y;
}}
"""

for i in range(opt.max_palettes + 1):
    css = css + f"""
.secondary#semantic-palette-{i} {{
    background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
    color: white;
}}

.primary#semantic-palette-{i} {{
    background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
    color: white;
}}
"""


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:

    iface = argparse.Namespace()

    def _define_state():
        state = argparse.Namespace()

        # Cursor.
        state.current_palette = 0  # 0: Background; 1,2,3,...: Layers
        state.model_id = list(model_dict.keys())[0]
        state.style_name = '(None)'
        state.quality_name = '(None)'  # 'Standard v3.1'

        # State variables (one-hot).
        state.active_palettes = 1

        # Front-end initialized to the default values.
        prompt_props_ = prompt_props()
        state.prompt_names = [
            '🌄 Background',
            '👧 Girl',
            '👦 Boy',
        ] + prompt_props_ + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
        state.prompts = [
            prompt_background(),
            prompt_girl(),
            prompt_boy(),
        ] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
        state.neg_prompts = [
            opt.default_negative_prompt
            + (', humans, humans, humans' if i == 0 else '')
            for i in range(opt.max_palettes + 1)
        ]
        state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
        state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
        state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
        state.seed = opt.seed
        return state

    state = gr.State(value=_define_state)


    ### Demo user interface

    gr.HTML(
        """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <div>
        <h1>🧠 Semantic Palette with <font class="rainbow rainbow_text_animated">Stable Diffusion 3</font> 🎨</h1>
        <h5 style="margin: 0;">powered by</h5>
        <h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
        <h5 style="margin: 0;">If you ❤️ our project, please visit our Github and give us a 🌟!</h5>
        </br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
                <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
            </a>
            <a href='https://arxiv.org/abs/2403.09055'>
                <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
            </a>
            <a href='https://github.com/ironjr/StreamMultiDiffusion'>
                <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
            </a>
            <a href='https://twitter.com/_ironjr_'>
                <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
            </a>
            <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
                <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
            </a>
            <a href='https://huggingface.co/spaces/ironjr/StreamMultiDiffusion'>
                <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-StreamMultiDiffusion-yellow'>
            </a>
            <a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
                <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SD1.5-yellow'>
            </a>
            <a href='https://huggingface.co/spaces/ironjr/SemanticPaletteXL'>
                <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SDXL-yellow'>
            </a>
            <a href='https://huggingface.co/spaces/ironjr/SemanticPalette3'>
                <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SD3-yellow'>
            </a>
        </div>
    </div>
</div>
<div>
    </br>
</div>
        """
    )

    with gr.Row():

        iface.image_slot = gr.Image(
            interactive=False,
            show_label=False,
            show_download_button=True,
            type='pil',
            label='Generated Result',
            elem_id='output-screen',
            value=lambda: random.choice(example_images),
        )

    with gr.Row():

        with gr.Column(scale=1):

            with gr.Group(elem_id='semantic-palette'):

                gr.HTML(
                    """
<div style="justify-content: center; align-items: center;">
    <br/>
    <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
    <br/>
</div>
                    """
                )

                iface.btn_semantics = [gr.Button(
                    value=state.value.prompt_names[0],
                    variant='primary',
                    elem_id='semantic-palette-0',
                )]
                for i in range(opt.max_palettes):
                    iface.btn_semantics.append(gr.Button(
                        value=state.value.prompt_names[i + 1],
                        variant='secondary',
                        visible=(i < state.value.active_palettes),
                        elem_id=f'semantic-palette-{i + 1}'
                    ))

                iface.btn_add_palette = gr.Button(
                    value='Create New Semantic Brush',
                    variant='primary',
                )

            with gr.Accordion(label='Import/Export Semantic Palette', open=False):
                iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
                iface.json_state_export = gr.JSON(label='Exported Palette')
                iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
                iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')

            gr.HTML(
                """
<div>
    </br>
</div>
<div style="justify-content: center; align-items: center;">
    <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
    </br>
    <div style="justify-content: center; align-items: left; text-align: left;">
        <p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
        <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
        <p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
        <p>2-1. If you are willing to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
        <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
        <p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
    </div>
</div>
                """
            )

            gr.HTML(
                """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
</div>
                """
            )

            gr.DuplicateButton()

        with gr.Column(scale=4):

            with gr.Row():

                with gr.Column(scale=3):

                    iface.ctrl_semantic = gr.ImageEditor(
                        image_mode='RGBA',
                        sources=['upload', 'clipboard', 'webcam'],
                        transforms=['crop'],
                        crop_size=(opt.width, opt.height),
                        brush=gr.Brush(
                            colors=opt.colors[1:],
                            color_mode="fixed",
                        ),
                        layers=False,
                        canvas_size=(opt.width, opt.height),
                        type='pil',
                        label='Semantic Drawpad',
                        elem_id='drawpad',
                    )

                with gr.Column(scale=1):

                    iface.btn_generate = gr.Button(
                        value='Generate!',
                        variant='primary',
                        # scale=1,
                        elem_id='run-button'
                    )

                    gr.HTML(
                        """
<h3 style="text-align: center;">Try other demos in HF 🤗 Space!</h3>
<div style="display: flex; justify-content: center; text-align: center;">
    <div><b style="color: #2692F3">Semantic Palette<br>Animagine XL 3.1</b></div>
    <div style="margin-left: 10px; margin-right: 10px; margin-top: 8px">or</div>
    <div><b style="color: #F89E12">Official Demo of<br>StreamMultiDiffusion</b></div>
</div>
<div style="display: inline-block; margin-top: 10px">
    <div class="gallery">
        <a href="https://huggingface.co/spaces/ironjr/SemanticPaletteXL" target="_blank">
            <img alt="AnimagineXL3.1 Demo" src="https://github.com/ironjr/StreamMultiDiffusion/blob/main/demo/semantic_palette_sd3/examples/icons/sdxl.webp?raw=true">
        </a>
        <a href="https://huggingface.co/spaces/ironjr/StreamMultiDiffusion" target="_blank">
            <img alt="StreamMultiDiffusion Demo" src="https://github.com/ironjr/StreamMultiDiffusion/blob/main/demo/semantic_palette_sd3/examples/icons/smd.gif?raw=true">
        </a>
    </div>
</div>
                        """
                    )

            # iface.model_select = gr.Radio(
            #     list(model_dict.keys()),
            #     label='Stable Diffusion Checkpoint',
            #     info='Choose your favorite style.',
            #     value=state.value.model_id,
            # )

            # with gr.Accordion(label='Prompt Engineering', open=True):
            #     iface.quality_select = gr.Dropdown(
            #         label='Quality Presets',
            #         interactive=True,
            #         choices=list(_quality_dict.keys()),
            #         value='Standard v3.1',
            #     )
            #     iface.style_select = gr.Radio(
            #         label='Style Preset',
            #         container=True,
            #         interactive=True,
            #         choices=list(_style_dict.keys()),
            #         value='(None)',
            #     )

            with gr.Group(elem_id='control-panel'):

                with gr.Row():
                    iface.tbox_prompt = gr.Textbox(
                        label='Edit Prompt for Background',
                        info='What do you want to draw?',
                        value=state.value.prompts[0],
                        placeholder=lambda: random.choice(prompt_suggestions),
                        scale=2,
                    )

                    iface.tbox_name = gr.Textbox(
                        label='Edit Brush Name',
                        info='Just for your convenience.',
                        value=state.value.prompt_names[0],
                        placeholder='🌄 Background',
                        scale=1,
                    )

                with gr.Row():
                    iface.tbox_neg_prompt = gr.Textbox(
                        label='Edit Negative Prompt for Background',
                        info='Add unwanted objects for this semantic brush.',
                        value=opt.default_negative_prompt,
                        scale=2,
                    )

                    iface.slider_strength = gr.Slider(
                        label='Prompt Strength',
                        info='Blends fg & bg in the prompt level, >0.8 Preferred.',
                        minimum=0.5,
                        maximum=1.0,
                        value=opt.default_prompt_strength,
                        scale=1,
                    )

                with gr.Row():
                    iface.slider_alpha = gr.Slider(
                        label='Mask Alpha',
                        info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
                        minimum=0.5,
                        maximum=1.0,
                        value=opt.default_mask_strength,
                    )

                    iface.slider_std = gr.Slider(
                        label='Mask Blur STD',
                        info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
                        minimum=0.0001,
                        maximum=100.0,
                        value=opt.default_mask_std,
                    )

                    iface.slider_seed = gr.Slider(
                        label='Seed',
                        info='The global seed.',
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        value=opt.seed,
                    )

    ### Attach event handlers

    for idx, btn in enumerate(iface.btn_semantics):
        btn.click(
            fn=partial(select_palette, idx=idx),
            inputs=[state, btn],
            outputs=[state] + iface.btn_semantics + [
                iface.tbox_name,
                iface.tbox_prompt,
                iface.tbox_neg_prompt,
                iface.slider_alpha,
                iface.slider_strength,
                iface.slider_std,
            ],
            api_name=f'select_palette_{idx}',
        )

    iface.btn_add_palette.click(
        fn=add_palette,
        inputs=state,
        outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
        api_name='create_new',
    )

    iface.btn_generate.click(
        fn=run,
        inputs=[state, iface.ctrl_semantic],
        outputs=iface.image_slot,
        api_name='run',
    )

    iface.slider_alpha.input(
        fn=change_mask_strength,
        inputs=[state, iface.slider_alpha],
        outputs=state,
        api_name='change_alpha',
    )
    iface.slider_std.input(
        fn=change_std,
        inputs=[state, iface.slider_std],
        outputs=state,
        api_name='change_std',
    )
    iface.slider_strength.input(
        fn=change_prompt_strength,
        inputs=[state, iface.slider_strength],
        outputs=state,
        api_name='change_strength',
    )
    iface.slider_seed.input(
        fn=reset_seed,
        inputs=[state, iface.slider_seed],
        outputs=state,
        api_name='reset_seed',
    )

    iface.tbox_name.input(
        fn=rename_prompt,
        inputs=[state, iface.tbox_name],
        outputs=[state] + iface.btn_semantics,
        api_name='prompt_rename',
    )
    iface.tbox_prompt.input(
        fn=change_prompt,
        inputs=[state, iface.tbox_prompt],
        outputs=state,
        api_name='prompt_edit',
    )
    iface.tbox_neg_prompt.input(
        fn=change_neg_prompt,
        inputs=[state, iface.tbox_neg_prompt],
        outputs=state,
        api_name='neg_prompt_edit',
    )

    # iface.model_select.change(
    #     fn=select_model,
    #     inputs=[state, iface.model_select],
    #     outputs=state,
    #     api_name='model_select',
    # )
    # iface.style_select.change(
    #     fn=select_style,
    #     inputs=[state, iface.style_select],
    #     outputs=state,
    #     api_name='style_select',
    # )
    # iface.quality_select.change(
    #     fn=select_quality,
    #     inputs=[state, iface.quality_select],
    #     outputs=state,
    #     api_name='quality_select',
    # )

    iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
    iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
        state,
        *iface.btn_semantics,
        # iface.model_select,
        # iface.style_select,
        # iface.quality_select,
        iface.tbox_prompt,
        iface.tbox_name,
        iface.tbox_neg_prompt,
        iface.slider_strength,
        iface.slider_alpha,
        iface.slider_std,
        iface.slider_seed,
    ])


if __name__ == '__main__':
    demo.launch(server_port=opt.port)
examples/prompt_background.txt
ADDED
@@ -0,0 +1,8 @@
Maximalism, best quality, high quality, no humans, background, clear sky, black sky, starry universe, planets
Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
Maximalism, best quality, high quality, no humans, background, galaxy
Maximalism, best quality, high quality, no humans, background, sky, daylight
Maximalism, best quality, high quality, no humans, background, skyscrappers, rooftop, city of light, helicopters, bright night, sky
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt
ADDED
The diff for this file is too large to render.
See raw diff
examples/prompt_boy.txt
ADDED
@@ -0,0 +1,15 @@
1boy, looking at viewer, brown hair, blue shirt
1boy, looking at viewer, brown hair, red shirt
1boy, looking at viewer, brown hair, purple shirt
1boy, looking at viewer, brown hair, orange shirt
1boy, looking at viewer, brown hair, yellow shirt
1boy, looking at viewer, brown hair, green shirt
1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
1boy, looking back, short hair, renaissance cloths, noble boy
1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
1boy, looking at viewer, black haired, old eastern cloth
1boy, looking back, messy hair, suit, short beard, noir
1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt
ADDED
@@ -0,0 +1,16 @@
1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
1girl, looking at viewer, fantasy adventurer, backpack
1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
1girl, looking at viewer, evil smile, very short hair, suit, evil genius
1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
1girl, looking at viewer, purple hair, happy face, black leather jacket
1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt
ADDED
@@ -0,0 +1,43 @@
🏯 Palace, Gyeongbokgung palace
🌳 Garden, Chinese garden
🏛️ Rome, Ancient city of Rome
🧱 Wall, Castle wall
🔴 Mars, Martian desert, Red rocky desert
🌻 Grassland, Grasslands
🏡 Village, A fantasy village
🐉 Dragon, a flying chinese dragon
🌏 Earth, Earth seen from ISS
🚀 Space Station, the international space station
🪻 Grassland, Rusty grassland with flowers
🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
🏙️ City Ruin, city, ruins, ruins, ruins, deserted
🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
🌷 Flowers, Flower garden
🌼 Flowers, Flower garden, spring garden
🌹 Flowers, Flowers flowers, flowers
⛰️ Dolomites Mountains, Dolomites
⛰️ Himalayas Mountains, Himalayas
⛰️ Alps Mountains, Alps
⛰️ Mountains, Mountains
❄️⛰️ Mountains, Winter mountains
🌷⛰️ Mountains, Spring mountains
🌞⛰️ Mountains, Summer mountains
🌵 Desert, A sandy desert, dunes
🪨🌵 Desert, A rocky desert
💦 Waterfall, A giant waterfall
🌊 Ocean, Ocean
⛱️ Seashore, Seashore
🌅 Sea Horizon, Sea horizon
🌊 Lake, Clear blue lake
💻 Computer, A giant supecomputer
🌳 Tree, A giant tree
🌳 Forest, A forest
🌳🌳 Forest, A dense forest
🌲 Forest, Winter forest
🌴 Forest, Summer forest, tropical forest
👒 Hat, A hat
🐶 Dog, Doggy body parts
😻 Cat, A cat
🦉 Owl, A small sitting owl
🦅 Eagle, A small sitting eagle
🚀 Rocket, A flying rocket
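
For reference, the props file above follows a simple "name, prompt" convention that app.py's example loader exploits: everything before the first comma becomes the palette name, and the remainder becomes the prompt. A minimal replay of that parsing, using two lines from the file:

```python
# Each line is "<emoji name>, <prompt>"; split on the first comma only.
lines = [
    '🏯 Palace, Gyeongbokgung palace',
    '🔴 Mars, Martian desert, Red rocky desert',
]
prompts_props = {
    l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in lines
}
print(prompts_props['🔴 Mars'])  # -> 'Martian desert, Red rocky desert'
```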
model.py
ADDED
@@ -0,0 +1,1095 @@
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
import inspect
|
22 |
+
from typing import Any, Callable, Dict, List, Literal, Tuple, Optional, Union
|
23 |
+
from tqdm import tqdm
|
24 |
+
from PIL import Image
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torchvision.transforms as T
|
30 |
+
from einops import rearrange
|
31 |
+
|
32 |
+
from transformers import (
|
33 |
+
CLIPTextModelWithProjection,
|
34 |
+
CLIPTokenizer,
|
35 |
+
T5EncoderModel,
|
36 |
+
T5TokenizerFast,
|
37 |
+
)
|
38 |
+
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
39 |
+
|
40 |
+
from diffusers.image_processor import VaeImageProcessor
|
41 |
+
from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
42 |
+
from diffusers.models.attention_processor import (
|
43 |
+
AttnProcessor2_0,
|
44 |
+
FusedAttnProcessor2_0,
|
45 |
+
LoRAAttnProcessor2_0,
|
46 |
+
LoRAXFormersAttnProcessor,
|
47 |
+
XFormersAttnProcessor,
|
48 |
+
)
|
49 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
50 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
51 |
+
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3PipelineOutput
|
52 |
+
from diffusers.schedulers import (
|
53 |
+
FlowMatchEulerDiscreteScheduler,
|
54 |
+
FlashFlowMatchEulerDiscreteScheduler,
|
55 |
+
)
|
56 |
+
from diffusers.utils import (
|
57 |
+
is_torch_xla_available,
|
58 |
+
logging,
|
59 |
+
replace_example_docstring,
|
60 |
+
)
|
61 |
+
from diffusers.utils.torch_utils import randn_tensor
|
62 |
+
from diffusers import (
|
63 |
+
DiffusionPipeline,
|
64 |
+
StableDiffusion3Pipeline,
|
65 |
+
)
|
66 |
+
|
67 |
+
from peft import PeftModel
|
68 |
+
|
69 |
+
from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
70 |
+
|
71 |
+
|
72 |
+
if is_torch_xla_available():
|
73 |
+
import torch_xla.core.xla_model as xm
|
74 |
+
|
75 |
+
XLA_AVAILABLE = True
|
76 |
+
else:
|
77 |
+
XLA_AVAILABLE = False
|
78 |
+
|
79 |
+
|
80 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
81 |
+
|
82 |
+
EXAMPLE_DOC_STRING = """
|
83 |
+
Examples:
|
84 |
+
```py
|
85 |
+
>>> import torch
|
86 |
+
>>> from diffusers import StableDiffusion3Pipeline
|
87 |
+
|
88 |
+
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
89 |
+
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
90 |
+
... )
|
91 |
+
>>> pipe.to("cuda")
|
92 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
93 |
+
>>> image = pipe(prompt).images[0]
|
94 |
+
>>> image.save("sd3.png")
|
95 |
+
```
|
96 |
+
"""
|
97 |
+
|
98 |
+
|
99 |
+
class StableMultiDiffusion3Pipeline(nn.Module):
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
device: torch.device,
|
103 |
+
dtype: torch.dtype = torch.float16,
|
104 |
+
hf_key: Optional[str] = None,
|
105 |
+
lora_key: Optional[str] = None,
|
106 |
+
load_from_local: bool = False, # Turn on if you have already downloaed LoRA & Hugging Face hub is down.
|
107 |
+
default_mask_std: float = 1.0, # 8.0
|
108 |
+
default_mask_strength: float = 1.0,
|
109 |
+
default_prompt_strength: float = 1.0, # 8.0
|
110 |
+
default_bootstrap_steps: int = 1,
|
111 |
+
default_boostrap_mix_steps: float = 1.0,
|
112 |
+
default_bootstrap_leak_sensitivity: float = 0.2,
|
113 |
+
default_preprocess_mask_cover_alpha: float = 0.3,
|
114 |
+
t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # # [0, 12, 25, 37], # Magic number.
|
115 |
+
mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
|
116 |
+
has_i2t: bool = True,
|
117 |
+
lora_weight: float = 1.0,
|
118 |
+
) -> None:
|
119 |
+
r"""Stabilized MultiDiffusion for fast sampling.
|
120 |
+
|
121 |
+
Accelrated region-based text-to-image synthesis with Latent Consistency
|
122 |
+
Model while preserving mask fidelity and quality.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
device (torch.device): Specify CUDA device.
|
126 |
+
hf_key (Optional[str]): Custom StableDiffusion checkpoint for
|
127 |
+
stylized generation.
|
128 |
+
lora_key (Optional[str]): Custom Lightning LoRA for acceleration.
|
129 |
+
load_from_local (bool): Turn on if you have already downloaed LoRA
|
130 |
+
& Hugging Face hub is down.
|
131 |
+
default_mask_std (float): Preprocess mask with Gaussian blur with
|
132 |
+
specified standard deviation.
|
133 |
+
default_mask_strength (float): Preprocess mask by multiplying it
|
134 |
+
globally with the specified variable. Caution: extremely
|
135 |
+
sensitive. Recommended range: 0.98-1.
|
136 |
+
default_prompt_strength (float): Preprocess foreground prompts
|
137 |
+
globally by linearly interpolating its embedding with the
|
138 |
+
background prompt embeddint with specified mix ratio. Useful
|
139 |
+
control handle for foreground blending. Recommended range:
|
140 |
+
0.5-1.
|
141 |
+
default_bootstrap_steps (int): Bootstrapping stage steps to
|
142 |
+
encourage region separation. Recommended range: 1-3.
|
143 |
+
default_boostrap_mix_steps (float): Bootstrapping background is a
|
144 |
+
linear interpolation between background latent and the white
|
145 |
+
image latent. This handle controls the mix ratio. Available
|
146 |
+
range: 0-(number of bootstrapping inference steps). For
|
147 |
+
example, 2.3 means that for the first two steps, white image
|
148 |
+
is used as a bootstrapping background and in the third step,
|
149 |
+
mixture of white (0.3) and registered background (0.7) is used
|
150 |
+
as a bootstrapping background.
|
151 |
+
default_bootstrap_leak_sensitivity (float): Postprocessing at each
|
152 |
+
inference step by masking away the remaining bootstrap
|
153 |
+
backgrounds t Recommended range: 0-1.
|
154 |
+
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
155 |
+
where each mask covered by other masks is reduced in its alpha
|
156 |
+
value by this specified factor.
|
157 |
+
t_index_list (List[int]): The default scheduling for the scheduler.
|
158 |
+
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
159 |
+
defines the mask quantization modes. Details in the codes of
|
160 |
+
`self.process_mask`. Basically, this (subtly) controls the
|
161 |
+
smoothness of foreground-background blending. More continuous
|
162 |
+
means more blending, but smaller generated patch depending on
|
163 |
+
the mask standard deviation.
|
164 |
+
has_i2t (bool): Automatic background image to text prompt con-
|
165 |
+
version with BLIP-2 model. May not be necessary for the non-
|
166 |
+
streaming application.
|
167 |
+
lora_weight (float): Adjusts weight of the acceleration LoRA.
|
168 |
+
Heavily affects the overall quality!
|
169 |
+
"""
|
170 |
+
super().__init__()
|
171 |
+
|
172 |
+
self.device = device
|
173 |
+
self.dtype = dtype
|
174 |
+
|
175 |
+
self.default_mask_std = default_mask_std
|
176 |
+
self.default_mask_strength = default_mask_strength
|
177 |
+
self.default_prompt_strength = default_prompt_strength
|
178 |
+
self.default_t_list = t_index_list
|
179 |
+
self.default_bootstrap_steps = default_bootstrap_steps
|
180 |
+
self.default_boostrap_mix_steps = default_boostrap_mix_steps
|
181 |
+
self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
|
182 |
+
self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
|
183 |
+
self.mask_type = mask_type
|
184 |
+
|
185 |
+
# Create model.
|
186 |
+
print(f'[INFO] Loading Stable Diffusion...')
|
187 |
+
if hf_key is not None:
|
188 |
+
print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
|
189 |
+
else:
|
190 |
+
hf_key = "stabilityai/stable-diffusion-3-medium-diffusers"
|
191 |
+
|
192 |
+
transformer = SD3Transformer2DModel.from_pretrained(
|
193 |
+
hf_key,
|
194 |
+
subfolder="transformer",
|
195 |
+
torch_dtype=torch.float16,
|
196 |
+
).to(self.device)
|
197 |
+
|
198 |
+
transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-sd3").to(self.device)
|
199 |
+
|
200 |
+
self.pipe = StableDiffusion3Pipeline.from_pretrained(
|
201 |
+
"stabilityai/stable-diffusion-3-medium-diffusers",
|
202 |
+
transformer=transformer,
|
203 |
+
torch_dtype=torch.float16,
|
204 |
+
text_encoder_3=None,
|
205 |
+
tokenizer_3=None
|
206 |
+
).to(self.device)
|
207 |
+
|
208 |
+
# Create image-to-text captioning model (BLIP-2).
|
209 |
+
if has_i2t:
|
210 |
+
self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
|
211 |
+
self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
|
212 |
+
|
213 |
+
# Use the Flash-SD3 FlashFlowMatch scheduler by default.
|
214 |
+
self.pipe.scheduler = FlashFlowMatchEulerDiscreteScheduler.from_pretrained(
|
215 |
+
"stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler")
|
216 |
+
self.pipe = self.pipe.to(self.device)
|
217 |
+
|
218 |
+
self.scheduler = self.pipe.scheduler
|
219 |
+
self.default_num_inference_steps = 4
|
220 |
+
self.default_guidance_scale = 0.0
|
221 |
+
|
222 |
+
if t_index_list is None:
|
223 |
+
self.prepare_flashflowmatch_schedule(
|
224 |
+
list(range(self.default_num_inference_steps)),
|
225 |
+
self.default_num_inference_steps,
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
self.prepare_flashflowmatch_schedule(t_index_list, 50)
|
229 |
+
|
230 |
+
self.vae = self.pipe.vae
|
231 |
+
self.tokenizer = self.pipe.tokenizer
|
232 |
+
self.tokenizer_2 = self.pipe.tokenizer_2
|
233 |
+
self.tokenizer_3 = self.pipe.tokenizer_3
|
234 |
+
self.text_encoder = self.pipe.text_encoder
|
235 |
+
self.text_encoder_2 = self.pipe.text_encoder_2
|
236 |
+
self.text_encoder_3 = self.pipe.text_encoder_3
|
237 |
+
self.transformer = self.pipe.transformer
|
238 |
+
self.vae_scale_factor = self.pipe.vae_scale_factor
|
239 |
+
|
240 |
+
# Prepare white background for bootstrapping.
|
241 |
+
self.get_white_background(1024, 1024)
|
242 |
+
|
243 |
+
print(f'[INFO] Model is loaded!')
|
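# A minimal sketch of the `default_prompt_strength` blending described in the
# constructor docstring above: each foreground prompt embedding is linearly
# interpolated toward the background prompt embedding. All shapes and values
# below are hypothetical placeholders, not the pipeline's actual tensors.
import torch
bg_embed = torch.randn(1, 77, 4096)    # assumed background text embedding
fg_embeds = torch.randn(3, 77, 4096)   # assumed embeddings for 3 foreground prompts
strength = 0.8                         # 1.0 keeps the foreground prompts unchanged
blended = torch.lerp(bg_embed, fg_embeds, strength)
assert blended.shape == fg_embeds.shape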
244 |
+
|
245 |
+
def prepare_flashflowmatch_schedule(
|
246 |
+
self,
|
247 |
+
t_index_list: Optional[List[int]] = None,
|
248 |
+
num_inference_steps: Optional[int] = None,
|
249 |
+
) -> None:
|
250 |
+
r"""Set up different inference schedule for the diffusion model.
|
251 |
+
|
252 |
+
You do not have to run this explicitly if you want to use the default
|
253 |
+
setting, but if you want other time schedules, run this function
|
254 |
+
between the module initialization and the main call.
|
255 |
+
|
256 |
+
Note:
|
257 |
+
- Recommended t_index_lists for LCMs:
|
258 |
+
- [0, 12, 25, 37]: Default schedule for 4 steps. Best for
|
259 |
+
panorama. Not recommended if you want to use bootstrapping.
|
260 |
+
Because bootstrapping stage affects the initial structuring
|
261 |
+
of the generated image, and in this four-step LCM this is done
|
262 |
+
only at the first step, the structure may be distorted.
|
263 |
+
- [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
|
264 |
+
strapping. Default initialization in this implementation.
|
265 |
+
- [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
|
266 |
+
bootstrapping.
|
267 |
+
- Due to the characteristic of SD1.5 LCM LoRA, setting
|
268 |
+
`num_inference_steps` larger than 20 may result in overly blurry
|
269 |
+
and unrealistic images. Beware!
|
270 |
+
|
271 |
+
Args:
|
272 |
+
t_index_list (Optional[List[int]]): The specified scheduling step
|
273 |
+
regarding the maximum timestep as `num_inference_steps`, which
|
274 |
+
is 50 by default. That means that
|
275 |
+
`t_index_list=[0, 12, 25, 37]` specifies relative time indices based
|
276 |
+
on the full scale of 50. If None, reinitialize the module with
|
277 |
+
the default value.
|
278 |
+
num_inference_steps (Optional[int]): The maximum timestep of the
|
279 |
+
sampler. Defines relative scale of the `t_index_list`. Rarely
|
280 |
+
used in practice. If None, reinitialize the module with the
|
281 |
+
default value.
|
282 |
+
"""
|
283 |
+
if t_index_list is None:
|
284 |
+
t_index_list = self.default_t_list
|
285 |
+
if num_inference_steps is None:
|
286 |
+
num_inference_steps = self.default_num_inference_steps
|
287 |
+
|
288 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
289 |
+
self.timesteps = self.scheduler.timesteps[torch.tensor(t_index_list)].to(self.device)
|
290 |
+
|
291 |
+
# FlashFlowMatchEulerDiscreteScheduler
|
292 |
+
# https://github.com/initml/diffusers/blob/clement/feature/flash_sd3/src/diffusers/schedulers/scheduling_flash_flow_match_euler_discrete.py
|
293 |
+
|
294 |
+
self.sigmas = self.scheduler.sigmas[torch.tensor(t_index_list)].to(self.device)
|
295 |
+
self.sigmas_next = torch.cat([self.sigmas, self.sigmas.new_zeros(1)])[1:].to(self.device)
|
296 |
+
|
297 |
+
noise_lvs = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
|
298 |
+
self.noise_lvs = noise_lvs[None, :, None, None, None]
|
299 |
+
self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
|
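# A quick, self-contained illustration of the sigma-to-noise-level conversion
# used just above for mask quantization. The sigma values here are made-up
# placeholders; the real ones come from FlashFlowMatchEulerDiscreteScheduler.
import torch
sigmas = torch.tensor([1.00, 0.75, 0.50, 0.25, 0.10])   # hypothetical schedule
noise_lvs = sigmas * (sigmas ** 2 + 1) ** (-0.5)         # same formula as above
# noise_lvs decreases monotonically toward 0 as the schedule approaches the
# clean image, which is what the mask quantization in `process_mask` relies on.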
300 |
+
|
301 |
+
@torch.no_grad()
|
302 |
+
def get_text_prompts(self, image: Image.Image) -> str:
|
303 |
+
r"""A convenient method to extract text prompt from an image.
|
304 |
+
|
305 |
+
This is called if the user does not provide background prompt but only
|
306 |
+
the background image. We use BLIP-2 to automatically generate prompts.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
image (Image.Image): A PIL image.
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
A single string of text prompt.
|
313 |
+
"""
|
314 |
+
if hasattr(self, 'i2t_model'):
|
315 |
+
question = 'Question: What are in the image? Answer:'
|
316 |
+
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
317 |
+
out = self.i2t_model.generate(**inputs, max_new_tokens=77)
|
318 |
+
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
319 |
+
return prompt
|
320 |
+
else:
|
321 |
+
return ''
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def encode_imgs(
|
325 |
+
self,
|
326 |
+
imgs: torch.Tensor,
|
327 |
+
generator: Optional[torch.Generator] = None,
|
328 |
+
vae: Optional[nn.Module] = None,
|
329 |
+
) -> torch.Tensor:
|
330 |
+
r"""A wrapper function for VAE encoder of the latent diffusion model.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
imgs (torch.Tensor): An image to get StableDiffusion latents.
|
334 |
+
Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
|
335 |
+
generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
|
336 |
+
vae (Optional[nn.Module]): Explicitly specify VAE (used for
|
337 |
+
the demo application with TinyVAE).
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
An image latent embedding with 1/8 size (depending on the auto-
|
341 |
+
encoder). Shape: (B, 4, H//8, W//8).
|
342 |
+
"""
|
343 |
+
def _retrieve_latents(
|
344 |
+
encoder_output: torch.Tensor,
|
345 |
+
generator: Optional[torch.Generator] = None,
|
346 |
+
sample_mode: str = 'sample',
|
347 |
+
):
|
348 |
+
if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
|
349 |
+
return encoder_output.latent_dist.sample(generator)
|
350 |
+
elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
|
351 |
+
return encoder_output.latent_dist.mode()
|
352 |
+
elif hasattr(encoder_output, 'latents'):
|
353 |
+
return encoder_output.latents
|
354 |
+
else:
|
355 |
+
raise AttributeError('Could not access latents of provided encoder_output')
|
356 |
+
|
357 |
+
vae = self.vae if vae is None else vae
|
358 |
+
imgs = 2 * imgs - 1
|
359 |
+
latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
|
360 |
+
return latents
|
361 |
+
|
362 |
+
@torch.no_grad()
|
363 |
+
def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
|
364 |
+
r"""A wrapper function for VAE decoder of the latent diffusion model.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
latents (torch.Tensor): An image latent to get associated images.
|
368 |
+
Expected shape: (B, 4, H//8, W//8).
|
369 |
+
vae (Optional[nn.Module]): Explicitly specify VAE (used for
|
370 |
+
the demo application with TinyVAE).
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
A decoded image with pixel values scaled to [0, 1].
|
374 |
+
Shape: (B, 3, H, W).
|
375 |
+
"""
|
376 |
+
vae = self.vae if vae is None else vae
|
377 |
+
latents = 1 / vae.config.scaling_factor * latents
|
378 |
+
imgs = vae.decode(latents).sample
|
379 |
+
imgs = (imgs / 2 + 0.5).clip_(0, 1)
|
380 |
+
return imgs
|
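# A round-trip sketch for the two VAE wrappers above, assuming `smd` is an
# already-constructed instance of this pipeline (the name `smd` follows the
# docstring example further below and is not defined here).
import torch
def vae_roundtrip(smd):
    img = torch.rand(1, 3, 1024, 1024, device=smd.device, dtype=smd.dtype)  # pixels in [0, 1]
    lat = smd.encode_imgs(img)        # (1, C, 1024 // smd.vae_scale_factor, ...)
    out = smd.decode_latents(lat)     # (1, 3, 1024, 1024), clipped back to [0, 1]
    return out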
381 |
+
|
382 |
+
@torch.no_grad()
|
383 |
+
def get_white_background(self, height: int, width: int) -> torch.Tensor:
|
384 |
+
r"""White background image latent for bootstrapping or in case of
|
385 |
+
absent background.
|
386 |
+
|
387 |
+
Additionally stores the maximally-sized white latent for fast retrieval
|
388 |
+
in the future. By default, we initially call this with 1024x1024 sized
|
389 |
+
white image, so the function is rarely visited twice.
|
390 |
+
|
391 |
+
Args:
|
392 |
+
height (int): The height of the white *image*, not its latent.
|
393 |
+
width (int): The width of the white *image*, not its latent.
|
394 |
+
|
395 |
+
Returns:
|
396 |
+
A white image latent of size (1, 4, height//8, width//8). A cropped
|
397 |
+
version of the stored white latent is returned if the requested
|
398 |
+
size is smaller than what we already have created.
|
399 |
+
"""
|
400 |
+
if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
|
401 |
+
white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
|
402 |
+
self.white = self.encode_imgs(white)
|
403 |
+
return self.white
|
404 |
+
return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
|
405 |
+
|
406 |
+
@torch.no_grad()
|
407 |
+
def process_mask(
|
408 |
+
self,
|
409 |
+
masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
|
410 |
+
strength: Optional[Union[torch.Tensor, float]] = None,
|
411 |
+
std: Optional[Union[torch.Tensor, float]] = None,
|
412 |
+
height: int = 1024,
|
413 |
+
width: int = 1024,
|
414 |
+
use_boolean_mask: bool = True,
|
415 |
+
timesteps: Optional[torch.Tensor] = None,
|
416 |
+
preprocess_mask_cover_alpha: Optional[float] = None,
|
417 |
+
) -> Tuple[torch.Tensor]:
|
418 |
+
r"""Fast preprocess of masks for region-based generation with fine-
|
419 |
+
grained controls.
|
420 |
+
|
421 |
+
Mask preprocessing is done in four steps:
|
422 |
+
1. Resizing: Resize the masks into the specified width and height by
|
423 |
+
nearest neighbor interpolation.
|
424 |
+
2. (Optional) Ordering: Masks with higher indices are considered to
|
425 |
+
cover the masks with smaller indices. Covered masks are decayed
|
426 |
+
in its alpha value by the specified factor of
|
427 |
+
`preprocess_mask_cover_alpha`.
|
428 |
+
3. Blurring: Gaussian blur is applied to the mask with the specified
|
429 |
+
standard deviation (isotropic). This results in gradual increase of
|
430 |
+
masked region as the timesteps evolve, naturally blending fore-
|
431 |
+
ground and the predesignated background. Not strictly required if
|
432 |
+
you want to produce images from scratch without a background.
|
433 |
+
4. Quantization: Split the real-numbered masks of value between [0, 1]
|
434 |
+
into predefined noise levels for each quantized scheduling step of
|
435 |
+
the diffusion sampler. For example, if the diffusion model sampler
|
436 |
+
has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
|
437 |
+
is the default noise level of this module with schedule [0, 4, 12,
|
438 |
+
25, 37], the masks are split into binary masks whose values are
|
439 |
+
greater than these levels. This results in gradual increase of mask
|
440 |
+
region as the timesteps increase. Details are described in our
|
441 |
+
paper at https://arxiv.org/pdf/2403.09055.pdf.
|
442 |
+
|
443 |
+
On the Three Modes of `mask_type`:
|
444 |
+
`self.mask_type` is predefined at the initialization stage of this
|
445 |
+
pipeline. Three possible modes are available: 'discrete', 'semi-
|
446 |
+
continuous', and 'continuous'. These define the mask quantization
|
447 |
+
modes we use. Basically, this (subtly) controls the smoothness of
|
448 |
+
foreground-background blending. Continuous mode produces nonbinary
|
449 |
+
masks to further blend foreground and background latents by linear-
|
450 |
+
ly interpolating between them. Semi-continuous mode applies the
|
451 |
+
continuous mask at the last step of the LCM sampler. Due to the
|
452 |
+
large step size of the LCM scheduler, we find that our continuous
|
453 |
+
blending helps generating seamless inpainting and editing results.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
|
457 |
+
strength (Optional[Union[torch.Tensor, float]]): Mask strength that
|
458 |
+
overrides the default value. A globally multiplied factor to
|
459 |
+
the mask at the initial stage of processing. Can be applied
|
460 |
+
separately for each mask.
|
461 |
+
std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
|
462 |
+
kernel's standard deviation. Overrides the default value. Can
|
463 |
+
be applied separately for each mask.
|
464 |
+
height (int): The height of the expected generation. Mask is
|
465 |
+
resized to (height//8, width//8) with nearest neighbor inter-
|
466 |
+
polation.
|
467 |
+
width (int): The width of the expected generation. Mask is resized
|
468 |
+
to (height//8, width//8) with nearest neighbor interpolation.
|
469 |
+
use_boolean_mask (bool): Specify this to treat the mask image as
|
470 |
+
a boolean tensor. The region darker than 0.5 of
|
471 |
+
the maximal pixel value (that is, 127.5) is considered as the
|
472 |
+
designated mask.
|
473 |
+
timesteps (Optional[torch.Tensor]): Defines the scheduler noise
|
474 |
+
levels that act as bins for mask quantization.
|
475 |
+
preprocess_mask_cover_alpha (Optional[float]): Optional pre-
|
476 |
+
processing where each mask covered by other masks is reduced in
|
477 |
+
its alpha value by this specified factor. Overrides the default
|
478 |
+
value.
|
479 |
+
|
480 |
+
Returns: A tuple of tensors.
|
481 |
+
- masks: Preprocessed (ordered, blurred, and quantized) binary/non-
|
482 |
+
binary masks (see the explanation on `mask_type` above) for
|
483 |
+
region-based image synthesis.
|
484 |
+
- masks_blurred: Gaussian blurred masks. Used for optionally
|
485 |
+
specified foreground-background blending after image
|
486 |
+
generation.
|
487 |
+
- std: Mask blur standard deviation. Used for optionally specified
|
488 |
+
foreground-background blending after image generation.
|
489 |
+
"""
|
490 |
+
if isinstance(masks, Image.Image):
|
491 |
+
masks = [masks]
|
492 |
+
if isinstance(masks, (tuple, list)):
|
493 |
+
# Assumes white background for Image.Image;
|
494 |
+
# inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
|
495 |
+
if use_boolean_mask:
|
496 |
+
proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
|
497 |
+
else:
|
498 |
+
proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
|
499 |
+
masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
|
500 |
+
masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
|
501 |
+
masks = masks.to(self.device)
|
502 |
+
|
503 |
+
# Background mask alpha is decayed by the specified factor where foreground masks covers it.
|
504 |
+
if preprocess_mask_cover_alpha is None:
|
505 |
+
preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
|
506 |
+
if preprocess_mask_cover_alpha > 0:
|
507 |
+
masks = torch.stack([
|
508 |
+
torch.where(
|
509 |
+
masks[i + 1:].sum(dim=0) > 0,
|
510 |
+
mask * preprocess_mask_cover_alpha,
|
511 |
+
mask,
|
512 |
+
) if i < len(masks) - 1 else mask
|
513 |
+
for i, mask in enumerate(masks)
|
514 |
+
], dim=0)
|
515 |
+
|
516 |
+
# Scheduler noise levels for mask quantization.
|
517 |
+
if timesteps is None:
|
518 |
+
noise_lvs = self.noise_lvs
|
519 |
+
next_noise_lvs = self.next_noise_lvs
|
520 |
+
else:
|
521 |
+
noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
|
522 |
+
# noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
|
523 |
+
noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
|
524 |
+
next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
|
525 |
+
|
526 |
+
# Mask preprocessing parameters are fetched from the default settings.
|
527 |
+
if std is None:
|
528 |
+
std = self.default_mask_std
|
529 |
+
if isinstance(std, (int, float)):
|
530 |
+
std = [std] * len(masks)
|
531 |
+
if isinstance(std, (list, tuple)):
|
532 |
+
std = torch.as_tensor(std, dtype=torch.float, device=self.device)
|
533 |
+
|
534 |
+
if strength is None:
|
535 |
+
strength = self.default_mask_strength
|
536 |
+
if isinstance(strength, (int, float)):
|
537 |
+
strength = [strength] * len(masks)
|
538 |
+
if isinstance(strength, (list, tuple)):
|
539 |
+
strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
|
540 |
+
|
541 |
+
if (std > 0).any():
|
542 |
+
std = torch.where(std > 0, std, 1e-5)
|
543 |
+
masks = gaussian_lowpass(masks, std)
|
544 |
+
masks_blurred = masks
|
545 |
+
|
546 |
+
# NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
|
547 |
+
# gives unpleasant results.
|
548 |
+
masks = masks * strength[:, None, None, None]
|
549 |
+
masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
|
550 |
+
|
551 |
+
# Mask is quantized according to the current noise levels specified by the scheduler.
|
552 |
+
if self.mask_type == 'discrete':
|
553 |
+
# Discrete mode.
|
554 |
+
masks = masks > noise_lvs
|
555 |
+
elif self.mask_type == 'semi-continuous':
|
556 |
+
# Semi-continuous mode (continuous at the last step only).
|
557 |
+
masks = torch.cat((
|
558 |
+
masks[:, :-1] > noise_lvs[:, :-1],
|
559 |
+
(
|
560 |
+
(masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
|
561 |
+
).clip_(0, 1),
|
562 |
+
), dim=1)
|
563 |
+
elif self.mask_type == 'continuous':
|
564 |
+
# Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
|
565 |
+
# decreases continuously after the discrete mode boundary to become `0` at the
|
566 |
+
# next lower threshold.
|
567 |
+
masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
|
568 |
+
|
569 |
+
# NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
|
570 |
+
# fine-grained mask alpha channel tuning is available with this form.
|
571 |
+
# masks = masks * strength[None, :, None, None, None]
|
572 |
+
|
573 |
+
h = height // self.vae_scale_factor
|
574 |
+
w = width // self.vae_scale_factor
|
575 |
+
masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
|
576 |
+
masks = F.interpolate(masks, size=(h, w), mode='nearest')
|
577 |
+
masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
|
578 |
+
return masks, masks_blurred, std
|
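# Toy illustration of the 'discrete' quantization branch above: a blurred mask
# in [0, 1] is thresholded once per scheduling step against the per-step noise
# levels. The numbers are hypothetical and the spatial dims are flattened to 1-D.
import torch
mask = torch.tensor([0.99, 0.90, 0.60, 0.20])              # assumed blurred mask values
noise_lvs = torch.tensor([0.98, 0.90, 0.70, 0.30, 0.10])   # assumed per-step noise levels
quantized = mask[None, :] > noise_lvs[:, None]              # (steps, pixels) boolean masks
# Early (noisy) steps switch on only the strongest mask pixels; the active
# region then grows step by step, which is the gradual blending described in
# the docstring.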
579 |
+
|
580 |
+
def scheduler_step(
|
581 |
+
self,
|
582 |
+
noise_pred: torch.Tensor,
|
583 |
+
idx: int,
|
584 |
+
latent: torch.Tensor,
|
585 |
+
) -> torch.Tensor:
|
586 |
+
r"""Denoise-only step for reverse diffusion scheduler.
|
587 |
+
|
588 |
+
Designed to match the interface of the original `pipe.scheduler.step`,
|
589 |
+
which is a combination of this method and the following
|
590 |
+
`scheduler_add_noise`.
|
591 |
+
|
592 |
+
Args:
|
593 |
+
noise_pred (torch.Tensor): Noise prediction results from the U-Net.
|
594 |
+
idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
|
595 |
+
for the timesteps tensor (ranged in [0, len(timesteps)-1]).
|
596 |
+
latent (torch.Tensor): Noisy latent.
|
597 |
+
|
598 |
+
Returns:
|
599 |
+
A denoised tensor with the same size as latent.
|
600 |
+
"""
|
601 |
+
# Upcast to avoid precision issues when computing prev_sample.
|
602 |
+
latent = latent.to(torch.float32)
|
603 |
+
prev_sample = latent - noise_pred * self.sigmas[idx]
|
604 |
+
return prev_sample.to(self.dtype)
|
605 |
+
|
606 |
+
def scheduler_add_noise(
|
607 |
+
self,
|
608 |
+
latent: torch.Tensor,
|
609 |
+
noise: Optional[torch.Tensor],
|
610 |
+
idx: int,
|
611 |
+
) -> torch.Tensor:
|
612 |
+
r"""Separated noise-add step for the reverse diffusion scheduler.
|
613 |
+
|
614 |
+
Designed to match the interface of the original
|
615 |
+
`pipe.scheduler.add_noise`.
|
616 |
+
|
617 |
+
Args:
|
618 |
+
latent (torch.Tensor): Denoised latent.
|
619 |
+
noise (torch.Tensor): Added noise. Can be None. If None, a random
|
620 |
+
noise is newly sampled for addition.
|
621 |
+
idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
|
622 |
+
for the timesteps tensor (ranged in [0, len(timesteps)-1]).
|
623 |
+
|
624 |
+
Returns:
|
625 |
+
A noisy tensor with the same size as latent.
|
626 |
+
"""
|
627 |
+
if idx < len(self.sigmas) and idx >= 0:
|
628 |
+
noise = torch.randn_like(latent) if noise is None else noise
|
629 |
+
return (1.0 - self.sigmas[idx]) * latent + self.sigmas[idx] * noise
|
630 |
+
else:
|
631 |
+
return latent
|
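# Minimal sketch of how the two scheduler helpers above interact, using the
# flow-matching convention x_t = (1 - sigma) * x_0 + sigma * noise. Values are
# placeholders; the real per-step sigmas live in `self.sigmas`.
import torch
sigma = 0.5
x0 = torch.randn(1, 4, 8, 8)                     # hypothetical clean latent
noise = torch.randn_like(x0)
xt = (1.0 - sigma) * x0 + sigma * noise          # what `scheduler_add_noise` does
velocity = noise - x0                            # a perfect flow-match prediction
denoised = xt - velocity * sigma                 # what `scheduler_step` does
assert torch.allclose(denoised, x0, atol=1e-5)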
632 |
+
|
633 |
+
@torch.no_grad()
|
634 |
+
def __call__(
|
635 |
+
self,
|
636 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
637 |
+
negative_prompts: Union[str, List[str]] = '',
|
638 |
+
suffix: Optional[str] = None, #', background is ',
|
639 |
+
background: Optional[Union[torch.Tensor, Image.Image]] = None,
|
640 |
+
background_prompt: Optional[str] = None,
|
641 |
+
background_negative_prompt: str = '',
|
642 |
+
height: int = 1024,
|
643 |
+
width: int = 1024,
|
644 |
+
num_inference_steps: Optional[int] = None,
|
645 |
+
guidance_scale: Optional[float] = None,
|
646 |
+
prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
647 |
+
masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
648 |
+
mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
649 |
+
mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
650 |
+
use_boolean_mask: bool = True,
|
651 |
+
do_blend: bool = True,
|
652 |
+
tile_size: int = 1024,
|
653 |
+
bootstrap_steps: Optional[int] = None,
|
654 |
+
boostrap_mix_steps: Optional[float] = None,
|
655 |
+
bootstrap_leak_sensitivity: Optional[float] = None,
|
656 |
+
preprocess_mask_cover_alpha: Optional[float] = None,
|
657 |
+
# SD3 pipeline settings.
|
658 |
+
guidance_rescale: float = 0.7,
|
659 |
+
output_type = 'pil',
|
660 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
661 |
+
clip_skip: Optional[int] = None,
|
662 |
+
) -> Image.Image:
|
663 |
+
r"""Arbitrary-size image generation from multiple pairs of (regional)
|
664 |
+
text prompt-mask pairs.
|
665 |
+
|
666 |
+
This is the main routine of this pipeline.
|
667 |
+
|
668 |
+
Example:
|
669 |
+
>>> device = torch.device('cuda:0')
|
670 |
+
>>> smd = StableMultiDiffusionPipeline(device)
|
671 |
+
>>> prompts = {... specify prompts}
|
672 |
+
>>> masks = {... specify mask tensors}
|
673 |
+
>>> height, width = masks.shape[-2:]
|
674 |
+
>>> image = smd(
|
675 |
+
>>> prompts, masks=masks.float(), height=height, width=width)
|
676 |
+
>>> image.save('my_beautiful_creation.png')
|
677 |
+
|
678 |
+
Args:
|
679 |
+
prompts (Union[str, List[str]]): A text prompt.
|
680 |
+
negative_prompts (Union[str, List[str]]): A negative text prompt.
|
681 |
+
suffix (Optional[str]): One option for blending foreground prompts
|
682 |
+
with background prompts by simply appending background prompt
|
683 |
+
to the end of each foreground prompt with this `middle word` in
|
684 |
+
between. For example, if you set this as `, background is`,
|
685 |
+
then the foreground prompt will be changed into
|
686 |
+
`(fg), background is (bg)` before conditional generation.
|
687 |
+
background (Optional[Union[torch.Tensor, Image.Image]]): a
|
688 |
+
background image, if the user wants to draw in front of the
|
689 |
+
specified image. A background prompt will be automatically generated
|
690 |
+
with a BLIP-2 model.
|
691 |
+
background_prompt (Optional[str]): The background prompt is used
|
692 |
+
for preprocessing foreground prompt embeddings to blend
|
693 |
+
foreground and background.
|
694 |
+
background_negative_prompt (Optional[str]): The negative background
|
695 |
+
prompt.
|
696 |
+
height (int): Height of a generated image. It is tiled if larger
|
697 |
+
than `tile_size`.
|
698 |
+
width (int): Width of a generated image. It is tiled if larger
|
699 |
+
than `tile_size`.
|
700 |
+
num_inference_steps (Optional[int]): Number of inference steps.
|
701 |
+
Default inference scheduling is used if none is specified.
|
702 |
+
guidance_scale (Optional[float]): Classifier guidance scale.
|
703 |
+
Default value is used if none is specified.
|
704 |
+
prompt_strengths (float): Overrides the default value. Preprocess
|
705 |
+
foreground prompts globally by linearly interpolating its
|
706 |
+
embedding with the background prompt embedding with the specified
|
707 |
+
mix ratio. Useful control handle for foreground blending.
|
708 |
+
Recommended range: 0.5-1.
|
709 |
+
masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
|
710 |
+
mask images. Each mask associates with each of the text prompts
|
711 |
+
and each of the negative prompts. If specified as an image, it
|
712 |
+
regards the image as a boolean mask. Also accepts torch.Tensor
|
713 |
+
masks, which can have nonbinary values for fine-grained
|
714 |
+
controls in mixing regional generations.
|
715 |
+
mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
|
716 |
+
Overrides the default value. Can be assigned for each mask
|
717 |
+
separately. Preprocess mask by multiplying it globally with the
|
718 |
+
specified variable. Caution: extremely sensitive. Recommended
|
719 |
+
range: 0.98-1.
|
720 |
+
mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
|
721 |
+
Overrides the default value. Can be assigned for each mask
|
722 |
+
separately. Preprocess mask with Gaussian blur with specified
|
723 |
+
standard deviation. Recommended range: 0-64.
|
724 |
+
use_boolean_mask (bool): Turn this off if you want to treat the
|
725 |
+
mask image as a nonbinary one. The module will use the last
|
726 |
+
channel of the given image in `masks` as the mask value.
|
727 |
+
do_blend (bool): Blend the generated foreground and the optionally
|
728 |
+
predefined background by smooth boundary obtained from Gaussian
|
729 |
+
blurs of the foreground `masks` with the given `mask_stds`.
|
730 |
+
tile_size (Optional[int]): Tile size of the panorama generation.
|
731 |
+
Works best with the default training size of the Stable-
|
732 |
+
Diffusion model, i.e., 512 for SD1.5 and 1024 for SDXL or SD3.
|
733 |
+
bootstrap_steps (int): Overrides the default value. Bootstrapping
|
734 |
+
stage steps to encourage region separation. Recommended range:
|
735 |
+
1-3.
|
736 |
+
boostrap_mix_steps (float): Overrides the default value.
|
737 |
+
Bootstrapping background is a linear interpolation between
|
738 |
+
background latent and the white image latent. This handle
|
739 |
+
controls the mix ratio. Available range: 0-(number of
|
740 |
+
bootstrapping inference steps). For example, 2.3 means that for
|
741 |
+
the first two steps, white image is used as a bootstrapping
|
742 |
+
background and in the third step, mixture of white (0.3) and
|
743 |
+
registered background (0.7) is used as a bootstrapping
|
744 |
+
background.
|
745 |
+
bootstrap_leak_sensitivity (float): Overrides the default value.
|
746 |
+
Postprocessing at each inference step by masking away the
|
747 |
+
remaining bootstrap backgrounds. Recommended range: 0-1.
|
748 |
+
preprocess_mask_cover_alpha (float): Overrides the default value.
|
749 |
+
Optional preprocessing where each mask covered by other masks
|
750 |
+
is reduced in its alpha value by this specified factor.
|
751 |
+
|
752 |
+
Returns: A PIL.Image image of a panorama (large-size) image.
|
753 |
+
"""
|
754 |
+
|
755 |
+
### Simplest cases
|
756 |
+
|
757 |
+
# prompts is None: return background.
|
758 |
+
# masks is None but prompts is not None: return prompts
|
759 |
+
# masks is not None and prompts is not None: Do StableMultiDiffusion.
|
760 |
+
|
761 |
+
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
762 |
+
if background is None and background_prompt is not None:
|
763 |
+
return sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
|
764 |
+
return background
|
765 |
+
elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
|
766 |
+
return sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
|
767 |
+
|
768 |
+
|
769 |
+
### Prepare generation
|
770 |
+
|
771 |
+
if num_inference_steps is not None:
|
772 |
+
self.prepare_flashflowmatch_schedule(list(range(num_inference_steps)), num_inference_steps)
|
773 |
+
|
774 |
+
if guidance_scale is None:
|
775 |
+
guidance_scale = self.default_guidance_scale
|
776 |
+
self.pipe._guidance_scale = guidance_scale
|
777 |
+
self.pipe._clip_skip = clip_skip
|
778 |
+
self.pipe._joint_attention_kwargs = joint_attention_kwargs
|
779 |
+
self.pipe._interrupt = False
|
780 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
781 |
+
|
782 |
+
|
783 |
+
### Prompts & Masks
|
784 |
+
|
785 |
+
# asserts #m > 0 and #p > 0.
|
786 |
+
# #m == #p == #n > 0: We happily generate according to the prompts & masks.
|
787 |
+
# #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
|
788 |
+
# #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
|
789 |
+
|
790 |
+
if isinstance(masks, Image.Image):
|
791 |
+
masks = [masks]
|
792 |
+
if isinstance(prompts, str):
|
793 |
+
prompts = [prompts]
|
794 |
+
if isinstance(negative_prompts, str):
|
795 |
+
negative_prompts = [negative_prompts]
|
796 |
+
num_masks = len(masks)
|
797 |
+
num_prompts = len(prompts)
|
798 |
+
num_nprompts = len(negative_prompts)
|
799 |
+
assert num_prompts in (num_masks, 1), \
|
800 |
+
f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
|
801 |
+
assert num_nprompts in (num_prompts, 1), \
|
802 |
+
f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'
|
803 |
+
|
804 |
+
fg_masks, masks_g, std = self.process_mask(
|
805 |
+
masks,
|
806 |
+
mask_strengths,
|
807 |
+
mask_stds,
|
808 |
+
height=height,
|
809 |
+
width=width,
|
810 |
+
use_boolean_mask=use_boolean_mask,
|
811 |
+
timesteps=self.timesteps,
|
812 |
+
preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
|
813 |
+
) # (p, t, 1, H, W)
|
814 |
+
bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1) # (T, 1, h, w)
|
815 |
+
has_background = bg_masks.sum() > 0
|
816 |
+
|
817 |
+
h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
|
818 |
+
w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor
|
819 |
+
|
820 |
+
|
821 |
+
### Background
|
822 |
+
|
823 |
+
# background == None && background_prompt == None: Initialize with white background.
|
824 |
+
# background == None && background_prompt != None: Generate background *along with other prompts*.
|
825 |
+
# background != None && background_prompt == None: Retrieve text prompt using BLIP.
|
826 |
+
# background != None && background_prompt != None: Use the given arguments.
|
827 |
+
|
828 |
+
# not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
|
829 |
+
# has_background && prompt_strength != 1: mix only for this case.
|
830 |
+
|
831 |
+
bg_latent = None
|
832 |
+
if has_background:
|
833 |
+
if background is None and background_prompt is not None:
|
834 |
+
fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
|
835 |
+
if suffix is not None:
|
836 |
+
prompts = [p + suffix + background_prompt for p in prompts]
|
837 |
+
prompts = [background_prompt] + prompts
|
838 |
+
negative_prompts = [background_negative_prompt] + negative_prompts
|
839 |
+
has_background = False # Regard that background does not exist.
|
840 |
+
else:
|
841 |
+
if background is None and background_prompt is None:
|
842 |
+
background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
|
843 |
+
background_prompt = 'simple white background image'
|
844 |
+
elif background is not None and background_prompt is None:
|
845 |
+
background_prompt = self.get_text_prompts(background)
|
846 |
+
if suffix is not None:
|
847 |
+
prompts = [p + suffix + background_prompt for p in prompts]
|
848 |
+
prompts = [background_prompt] + prompts
|
849 |
+
negative_prompts = [background_negative_prompt] + negative_prompts
|
850 |
+
if isinstance(background, Image.Image):
|
851 |
+
background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
|
852 |
+
background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
|
853 |
+
bg_latent = self.encode_imgs(background)
|
854 |
+
|
855 |
+
# Bootstrapping stage preparation.
|
856 |
+
|
857 |
+
if bootstrap_steps is None:
|
858 |
+
bootstrap_steps = self.default_bootstrap_steps
|
859 |
+
if boostrap_mix_steps is None:
|
860 |
+
boostrap_mix_steps = self.default_boostrap_mix_steps
|
861 |
+
if bootstrap_leak_sensitivity is None:
|
862 |
+
bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
|
863 |
+
if bootstrap_steps > 0:
|
864 |
+
height_ = min(height, tile_size)
|
865 |
+
width_ = min(width, tile_size)
|
866 |
+
white = self.get_white_background(height, width) # (1, 4, h, w)
|
867 |
+
|
868 |
+
|
869 |
+
### Prepare text embeddings (optimized for the minimal encoder batch size)
|
870 |
+
|
871 |
+
# SD3 pipeline settings.
|
872 |
+
batch_size = 1
|
873 |
+
num_images_per_prompt = 1
|
874 |
+
|
875 |
+
original_size = (height, width)
|
876 |
+
target_size = (height, width)
|
877 |
+
crops_coords_top_left = (0, 0)
|
878 |
+
negative_original_size = None
|
879 |
+
negative_target_size = None
|
880 |
+
negative_crops_coords_top_left = (0, 0)
|
881 |
+
|
882 |
+
prompt_2 = None
|
883 |
+
prompt_3 = None
|
884 |
+
negative_prompt_2 = None
|
885 |
+
negative_prompt_3 = None
|
886 |
+
prompt_embeds = None
|
887 |
+
negative_prompt_embeds = None
|
888 |
+
pooled_prompt_embeds = None
|
889 |
+
negative_pooled_prompt_embeds = None
|
890 |
+
text_encoder_lora_scale = None
|
891 |
+
|
892 |
+
(
|
893 |
+
prompt_embeds,
|
894 |
+
negative_prompt_embeds,
|
895 |
+
pooled_prompt_embeds,
|
896 |
+
negative_pooled_prompt_embeds,
|
897 |
+
) = self.pipe.encode_prompt(
|
898 |
+
prompt=prompts,
|
899 |
+
prompt_2=prompt_2,
|
900 |
+
prompt_3=prompt_3,
|
901 |
+
negative_prompt=negative_prompts,
|
902 |
+
negative_prompt_2=negative_prompt_2,
|
903 |
+
negative_prompt_3=negative_prompt_3,
|
904 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
905 |
+
prompt_embeds=prompt_embeds,
|
906 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
907 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
908 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
909 |
+
device=self.device,
|
910 |
+
clip_skip=self.pipe.clip_skip,
|
911 |
+
num_images_per_prompt=num_images_per_prompt,
|
912 |
+
)
|
913 |
+
|
914 |
+
if has_background:
|
915 |
+
# First channel is background prompt text embeds. Background prompt itself is not used for generation.
|
916 |
+
s = prompt_strengths
|
917 |
+
if prompt_strengths is None:
|
918 |
+
s = self.default_prompt_strength
|
919 |
+
if isinstance(s, (int, float)):
|
920 |
+
s = [s] * num_prompts
|
921 |
+
if isinstance(s, (list, tuple)):
|
922 |
+
assert len(s) == num_prompts, \
|
923 |
+
f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
|
924 |
+
s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
|
925 |
+
s = s[:, None, None]
|
926 |
+
|
927 |
+
be = prompt_embeds[:1]
|
928 |
+
fe = prompt_embeds[1:]
|
929 |
+
prompt_embeds = torch.lerp(be, fe, s) # (p, 77, 1024)
|
930 |
+
|
931 |
+
if negative_prompt_embeds is not None:
|
932 |
+
bu = negative_prompt_embeds[:1]
|
933 |
+
fu = negative_prompt_embeds[1:]
|
934 |
+
if num_prompts > num_nprompts:
|
935 |
+
# # negative prompts = 1; # prompts > 1.
|
936 |
+
assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
|
937 |
+
fu = fu.repeat(num_prompts, 1, 1)
|
938 |
+
negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
|
939 |
+
|
940 |
+
be = pooled_prompt_embeds[:1]
|
941 |
+
fe = pooled_prompt_embeds[1:]
|
942 |
+
pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0]) # (p, 1280)
|
943 |
+
|
944 |
+
if negative_pooled_prompt_embeds is not None:
|
945 |
+
bu = negative_pooled_prompt_embeds[:1]
|
946 |
+
fu = negative_pooled_prompt_embeds[1:]
|
947 |
+
if num_prompts > num_nprompts:
|
948 |
+
# # negative prompts = 1; # prompts > 1.
|
949 |
+
assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
|
950 |
+
fu = fu.repeat(num_prompts, 1)
|
951 |
+
negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0]) # (n, 1280)
|
952 |
+
elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
|
953 |
+
# # negative prompts = 1; # prompts > 1.
|
954 |
+
assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
|
955 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
|
956 |
+
|
957 |
+
assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
|
958 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
|
959 |
+
# assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
|
960 |
+
if num_masks > num_prompts:
|
961 |
+
assert len(masks) == num_masks and num_prompts == 1
|
962 |
+
prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
|
963 |
+
if negative_prompt_embeds is not None:
|
964 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
|
965 |
+
|
966 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
|
967 |
+
if negative_pooled_prompt_embeds is not None:
|
968 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)
|
969 |
+
|
970 |
+
# SD3 pipeline settings.
|
971 |
+
if do_classifier_free_guidance:
|
972 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
973 |
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
974 |
+
del negative_prompt_embeds, negative_pooled_prompt_embeds
|
975 |
+
|
976 |
+
prompt_embeds = prompt_embeds.to(self.device)
|
977 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
|
978 |
+
|
979 |
+
|
980 |
+
### Run
|
981 |
+
|
982 |
+
# Latent initialization.
|
983 |
+
num_channels_latents = self.transformer.config.in_channels
|
984 |
+
noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
|
985 |
+
if self.timesteps[0] < 999 and has_background:
|
986 |
+
latent = self.scheduler_add_noise(bg_latent, noise, 0)
|
987 |
+
else:
|
988 |
+
noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
|
989 |
+
latent = noise
|
990 |
+
|
991 |
+
if has_background:
|
992 |
+
noise_bg_latents = [
|
993 |
+
self.scheduler_add_noise(bg_latent, noise, i) for i in range(len(self.timesteps))
|
994 |
+
] + [bg_latent]
|
995 |
+
|
996 |
+
# Tiling (if needed).
|
997 |
+
if height > tile_size or width > tile_size:
|
998 |
+
t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
|
999 |
+
views, tile_masks = get_panorama_views(h, w, t)
|
1000 |
+
tile_masks = tile_masks.to(self.device)
|
1001 |
+
else:
|
1002 |
+
views = [(0, h, 0, w)]
|
1003 |
+
tile_masks = latent.new_ones((1, 1, h, w))
|
1004 |
+
value = torch.zeros_like(latent)
|
1005 |
+
count_all = torch.zeros_like(latent)
|
1006 |
+
|
1007 |
+
with torch.autocast('cuda'):
|
1008 |
+
for i, t in enumerate(tqdm(self.timesteps)):
|
1009 |
+
if self.pipe.interrupt:
|
1010 |
+
continue
|
1011 |
+
|
1012 |
+
fg_mask = fg_masks[:, i]
|
1013 |
+
bg_mask = bg_masks[i:i + 1]
|
1014 |
+
|
1015 |
+
value.zero_()
|
1016 |
+
count_all.zero_()
|
1017 |
+
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
|
1018 |
+
fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
|
1019 |
+
latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
|
1020 |
+
|
1021 |
+
# Bootstrap for tight background.
|
1022 |
+
if i < bootstrap_steps:
|
1023 |
+
mix_ratio = min(1, max(0, boostrap_mix_steps - i))
|
1024 |
+
# Treat the first foreground latent as the background latent if one does not exist.
|
1025 |
+
bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
|
1026 |
+
white_ = white[..., h_start:h_end, w_start:w_end]
|
1027 |
+
white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i)
|
1028 |
+
bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
|
1029 |
+
latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
|
1030 |
+
|
1031 |
+
# Centering.
|
1032 |
+
latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
|
1033 |
+
|
1034 |
+
# expand the latents if we are doing classifier free guidance
|
1035 |
+
latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
|
1036 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1037 |
+
timestep = t.expand(latent_model_input.shape[0])
|
1038 |
+
|
1039 |
+
# Perform one step of the reverse diffusion.
|
1040 |
+
noise_pred = self.transformer(
|
1041 |
+
hidden_states=latent_model_input,
|
1042 |
+
timestep=timestep,
|
1043 |
+
encoder_hidden_states=prompt_embeds,
|
1044 |
+
pooled_projections=pooled_prompt_embeds,
|
1045 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
1046 |
+
return_dict=False,
|
1047 |
+
)[0]
|
1048 |
+
|
1049 |
+
if do_classifier_free_guidance:
|
1050 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
1051 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1052 |
+
|
1053 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
1054 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1055 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
|
1056 |
+
|
1057 |
+
latent_ = self.scheduler_step(noise_pred, i, latent_)
|
1058 |
+
|
1059 |
+
if i < bootstrap_steps:
|
1060 |
+
# Uncentering.
|
1061 |
+
latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
|
1062 |
+
|
1063 |
+
# Remove leakage (optional).
|
1064 |
+
leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
|
1065 |
+
leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
|
1066 |
+
fg_mask_ = fg_mask_ * leak_sigmoid
|
1067 |
+
|
1068 |
+
# Mix the latents.
|
1069 |
+
fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
|
1070 |
+
value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
|
1071 |
+
count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
|
1072 |
+
|
1073 |
+
latent = torch.where(count_all > 0, value / count_all, value)
|
1074 |
+
bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
|
1075 |
+
if has_background:
|
1076 |
+
latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1] # bg_latent
|
1077 |
+
|
1078 |
+
# Noise is added after mixing.
|
1079 |
+
if i < len(self.timesteps) - 1:
|
1080 |
+
latent = self.scheduler_add_noise(latent, None, i + 1)
|
1081 |
+
|
1082 |
+
if not output_type == "latent":
|
1083 |
+
latent = (latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
1084 |
+
image = self.vae.decode(latent, return_dict=False)[0]
|
1085 |
+
else:
|
1086 |
+
image = latent
|
1087 |
+
|
1088 |
+
# Return PIL Image.
|
1089 |
+
image = image[0].clip_(-1, 1) * 0.5 + 0.5
|
1090 |
+
if has_background and do_blend:
|
1091 |
+
fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
|
1092 |
+
image = blend(image, background[0], fg_mask)
|
1093 |
+
else:
|
1094 |
+
image = T.ToPILImage()(image)
|
1095 |
+
return image
|
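# A fuller usage sketch in the spirit of the `__call__` docstring example
# above. The class name, import path, and asset file names here are
# hypothetical placeholders; only the call signature follows the docstrings.
import torch
from PIL import Image
from model import StableMultiDiffusion3Pipeline   # assumed module/class name

device = torch.device('cuda:0')
smd = StableMultiDiffusion3Pipeline(device)
prompts = ['a knight in shining armor', 'a dragon perched on a rock']
masks = [Image.open('mask_knight.png'), Image.open('mask_dragon.png')]  # white bg, black fg
image = smd(
    prompts,
    masks=masks,
    background_prompt='a misty medieval valley at dawn',
    height=1024,
    width=1024,
    mask_stds=8.0,
    bootstrap_steps=2,
)
image.save('my_beautiful_creation.png')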
prompt_util.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
from typing import Dict, List, Tuple, Union
|
2 |
+
|
3 |
+
|
4 |
+
quality_prompt_list = [
|
5 |
+
{
|
6 |
+
"name": "(None)",
|
7 |
+
"prompt": "{prompt}",
|
8 |
+
"negative_prompt": "nsfw, lowres",
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"name": "Standard v3.0",
|
12 |
+
"prompt": "{prompt}, masterpiece, best quality",
|
13 |
+
"negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "Standard v3.1",
|
17 |
+
"prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
|
18 |
+
"negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"name": "Light v3.1",
|
22 |
+
"prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
|
23 |
+
"negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"name": "Heavy v3.1",
|
27 |
+
"prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
|
28 |
+
"negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
|
29 |
+
},
|
30 |
+
]
|
31 |
+
|
32 |
+
style_list = [
|
33 |
+
{
|
34 |
+
"name": "(None)",
|
35 |
+
"prompt": "{prompt}",
|
36 |
+
"negative_prompt": "",
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"name": "Cinematic",
|
40 |
+
"prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
41 |
+
"negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"name": "Photographic",
|
45 |
+
"prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
46 |
+
"negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"name": "Anime",
|
50 |
+
"prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
|
51 |
+
"negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"name": "Manga",
|
55 |
+
"prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
|
56 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"name": "Digital Art",
|
60 |
+
"prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
|
61 |
+
"negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"name": "Pixel art",
|
65 |
+
"prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
|
66 |
+
"negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"name": "Fantasy art",
|
70 |
+
"prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
71 |
+
"negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"name": "Neonpunk",
|
75 |
+
"prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
76 |
+
"negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"name": "3D Model",
|
80 |
+
"prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
|
81 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
|
82 |
+
},
|
83 |
+
]
|
84 |
+
|
85 |
+
|
86 |
+
_style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
87 |
+
_quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
|
88 |
+
|
89 |
+
|
90 |
+
def preprocess_prompt(
|
91 |
+
positive: str,
|
92 |
+
negative: str = "",
|
93 |
+
style_dict: Dict[str, dict] = _quality_dict,
|
94 |
+
style_name: str = "Standard v3.1", # "Heavy v3.1"
|
95 |
+
add_style: bool = True,
|
96 |
+
) -> Tuple[str, str]:
|
97 |
+
p, n = style_dict.get(style_name, style_dict["(None)"])
|
98 |
+
|
99 |
+
if add_style and positive.strip():
|
100 |
+
formatted_positive = p.format(prompt=positive)
|
101 |
+
else:
|
102 |
+
formatted_positive = positive
|
103 |
+
|
104 |
+
combined_negative = n
|
105 |
+
if negative.strip():
|
106 |
+
if combined_negative:
|
107 |
+
combined_negative += ", " + negative
|
108 |
+
else:
|
109 |
+
combined_negative = negative
|
110 |
+
|
111 |
+
return formatted_positive, combined_negative
|
112 |
+
|
113 |
+
|
114 |
+
def preprocess_prompts(
|
115 |
+
positives: List[str],
|
116 |
+
negatives: List[str] = None,
|
117 |
+
style_dict = _style_dict,
|
118 |
+
style_name: str = "Manga", # "(None)"
|
119 |
+
quality_dict = _quality_dict,
|
120 |
+
quality_name: str = "Standard v3.1", # "Heavy v3.1"
|
121 |
+
add_style: bool = True,
|
122 |
+
add_quality_tags = True,
|
123 |
+
) -> Tuple[List[str], List[str]]:
|
124 |
+
if negatives is None:
|
125 |
+
negatives = ['' for _ in positives]
|
126 |
+
|
127 |
+
positives_ = []
|
128 |
+
negatives_ = []
|
129 |
+
for pos, neg in zip(positives, negatives):
|
130 |
+
pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
|
131 |
+
pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
|
132 |
+
positives_.append(pos)
|
133 |
+
negatives_.append(neg)
|
134 |
+
return positives_, negatives_
|
135 |
+
|
136 |
+
|
137 |
+
def print_prompts(
|
138 |
+
positives: Union[str, List[str]],
|
139 |
+
negatives: Union[str, List[str]],
|
140 |
+
has_background: bool = False,
|
141 |
+
) -> None:
|
142 |
+
if isinstance(positives, str):
|
143 |
+
positives = [positives]
|
144 |
+
if isinstance(negatives, str):
|
145 |
+
negatives = [negatives]
|
146 |
+
|
147 |
+
for i, prompt in enumerate(positives):
|
148 |
+
prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
|
149 |
+
if has_background else f'Prompt{i + 1}')
|
150 |
+
print(prefix + ': ' + prompt)
|
151 |
+
for i, prompt in enumerate(negatives):
|
152 |
+
prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
|
153 |
+
if has_background else f'Negative Prompt{i + 1}')
|
154 |
+
print(prefix + ': ' + prompt)
|
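# A quick, self-contained check of the helpers in prompt_util.py above
# (run from the repository root so the module is importable).
from prompt_util import preprocess_prompts, print_prompts

pos, neg = preprocess_prompts(
    ['1girl holding a red umbrella', 'a city street at night'],
    style_name='Manga',
    quality_name='Standard v3.1',
)
print_prompts(pos, neg, has_background=False)
# Each positive prompt gains the Manga style and v3.1 quality tags; each
# negative prompt accumulates the corresponding negative tag lists.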
requirements.txt
ADDED
@@ -0,0 +1,16 @@
1 |
+
torch==2.0.1
|
2 |
+
torchvision
|
3 |
+
xformers==0.0.22
|
4 |
+
einops
|
5 |
+
diffusers @ git+https://github.com/initml/diffusers.git@clement/feature/flash_sd3
|
6 |
+
transformers
|
7 |
+
huggingface_hub[torch]
|
8 |
+
gradio==4.39.0
|
9 |
+
Pillow
|
10 |
+
emoji
|
11 |
+
numpy
|
12 |
+
tqdm
|
13 |
+
jupyterlab
|
14 |
+
peft>=0.10.0
|
15 |
+
sentencepiece
|
16 |
+
protobuf
|
share_btn.py
ADDED
@@ -0,0 +1,70 @@
+share_js = """async () => {
+    async function uploadFile(file) {
+        const UPLOAD_URL = 'https://huggingface.co/uploads';
+        const response = await fetch(UPLOAD_URL, {
+            method: 'POST',
+            headers: {
+                'Content-Type': file.type,
+                'X-Requested-With': 'XMLHttpRequest',
+            },
+            body: file, /// <- File inherits from Blob
+        });
+        const url = await response.text();
+        return url;
+    }
+    async function getBase64(file) {
+        var reader = new FileReader();
+        reader.readAsDataURL(file);
+        reader.onload = function () {
+            console.log(reader.result);
+        };
+        reader.onerror = function (error) {
+            console.log('Error: ', error);
+        };
+    }
+    const toDataURL = url => fetch(url)
+        .then(response => response.blob())
+        .then(blob => new Promise((resolve, reject) => {
+            const reader = new FileReader()
+            reader.onloadend = () => resolve(reader.result)
+            reader.onerror = reject
+            reader.readAsDataURL(blob)
+        }));
+    async function dataURLtoFile(dataurl, filename) {
+        var arr = dataurl.split(','), mime = arr[0].match(/:(.*?);/)[1],
+            bstr = atob(arr[1]), n = bstr.length, u8arr = new Uint8Array(n);
+        while (n--) {
+            u8arr[n] = bstr.charCodeAt(n);
+        }
+        return new File([u8arr], filename, {type:mime});
+    };
+
+    const gradioEl = document.querySelector('body > gradio-app');
+    const imgEls = gradioEl.querySelectorAll('#output-screen img');
+    if(!imgEls.length){
+        return;
+    };
+
+    const urls = await Promise.all([...imgEls].map((imgEl) => {
+        const origURL = imgEl.src;
+        const imgId = Date.now() % 200;
+        const fileName = 'semantic-palette-xl-' + imgId + '.png';
+        return toDataURL(origURL)
+            .then(dataUrl => {
+                return dataURLtoFile(dataUrl, fileName);
+            })
+    })).then(fileData => {return Promise.all([...fileData].map((file) => {
+        return uploadFile(file);
+    }))});
+
+    const htmlImgs = urls.map(url => `<img src='${url}' width='2560' height='1024'>`);
+    const descriptionMd = `<div style='display: flex; flex-wrap: wrap; column-gap: 0.75rem;'>
+${htmlImgs.join(`\n`)}
+</div>`;
+    const params = new URLSearchParams({
+        title: `My creation`,
+        description: descriptionMd,
+    });
+    const paramsStr = params.toString();
+    window.open(`https://huggingface.co/spaces/ironjr/SemanticPaletteXL/discussions/new?${paramsStr}`, '_blank');
+}"""
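For context, a hedged sketch of how `share_js` is typically wired up in the app (the component names below are assumptions, not taken from this commit). In Gradio 4.x an event listener can run client-side JavaScript through its `js=` argument, and the script above expects the output images to live under an element with id 'output-screen':

import gradio as gr
from share_btn import share_js

with gr.Blocks() as demo:
    gallery = gr.Gallery(elem_id='output-screen')   # assumed output component
    share_button = gr.Button('Share to community')
    # No Python callback; the upload-and-share flow runs entirely in the browser.
    share_button.click(fn=None, inputs=[], outputs=[], js=share_js)

demo.launch()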
util.py
ADDED
@@ -0,0 +1,315 @@
+# Copyright (c) 2024 Jaerin Lee
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import concurrent.futures
+import time
+from typing import Any, Callable, List, Literal, Tuple, Union
+
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+import torch.cuda.amp as amp
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+from diffusers import (
+    DiffusionPipeline,
+    StableDiffusionPipeline,
+    StableDiffusionXLPipeline,
+)
+
+
+def seed_everything(seed: int) -> None:
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+
+
+def load_model(
+    model_key: str,
+    sd_version: Literal['1.5', 'xl'],
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.nn.Module:
+    if model_key.endswith('.safetensors'):
+        if sd_version == '1.5':
+            pipeline = StableDiffusionPipeline
+        elif sd_version == 'xl':
+            pipeline = StableDiffusionXLPipeline
+        else:
+            raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
+        return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
+    try:
+        return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
+    except:
+        return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)
+
+
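As a usage sketch (illustrative only, reusing the imports at the top of util.py; the model key and settings below are assumptions, not something pinned by this commit): `seed_everything` fixes the RNG state, and `load_model` resolves either a local `.safetensors` checkpoint or a diffusers repo id.

# Illustrative only: any diffusers-format repo id or local .safetensors path works here.
device = torch.device('cuda')
seed_everything(2024)
pipe = load_model(
    'stabilityai/stable-diffusion-xl-base-1.0',  # assumed example key
    sd_version='xl',
    device=device,
    dtype=torch.float16,
)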
+def get_cutoff(cutoff: float = None, scale: float = None) -> float:
+    if cutoff is not None:
+        return cutoff
+
+    if scale is not None and cutoff is None:
+        return 0.5 / scale
+
+    raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
+
+
+def get_scale(cutoff: float = None, scale: float = None) -> float:
+    if scale is not None:
+        return scale
+
+    if cutoff is not None and scale is None:
+        return 0.5 / cutoff
+
+    raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
+
+
+def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+    assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
+    # assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
+
+    b, c, h, w = x.shape
+    ks = k.shape[-1]
+    k = k.view(1, 1, -1).repeat(c, 1, 1)
+
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(b * h, c, w)
+    x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
+    x = F.conv1d(x, k, groups=c)
+    x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
+    x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
+    x = F.conv1d(x, k, groups=c)
+    x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
+    return x
+
+
+def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+    assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
+
+    x = F.pad(x, (
+        k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
+        k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
+    ), mode='replicate')
+
+    b, c, _, _ = x.shape
+    if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
+        k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
+        x = F.conv2d(x, k, groups=c)
+    elif len(k.shape) == 3:
+        assert k.shape[0] == b, \
+            'The number of kernels should match the batch size.'
+
+        k = k.unsqueeze(1)
+        x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
+    return x
+
+
+@amp.autocast(False)
+def filter_by_kernel(
+    x: torch.Tensor,
+    k: torch.Tensor,
+    is_batch: bool = False,
+) -> torch.Tensor:
+    k_dim = len(k.shape)
+    if k_dim == 1 or k_dim == 2 and is_batch:
+        return filter_2d_by_kernel_1d(x, k)
+    elif k_dim == 2 or k_dim == 3 and is_batch:
+        return filter_2d_by_kernel_2d(x, k)
+    else:
+        raise ValueError('Kernel size should be one of (1, 2, 3).')
+
+
+def gen_gauss_lowpass_filter_2d(
+    std: torch.Tensor,
+    window_size: int = None,
+) -> torch.Tensor:
+    # Gaussian kernel size is odd in order to preserve the center.
+    if window_size is None:
+        window_size = (
+            2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
+
+    y = torch.arange(
+        window_size, dtype=std.dtype, device=std.device
+    ).view(-1, 1).repeat(1, window_size)
+    grid = torch.stack((y.t(), y), dim=-1)
+    grid -= 0.5 * (window_size - 1)  # (W, W)
+    var = (std * std).unsqueeze(-1).unsqueeze(-1)
+    distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
+    k = torch.exp(-0.5 * distsq / var)
+    k /= k.sum(dim=(-2, -1), keepdim=True)
+    return k
+
+
+def gaussian_lowpass(
+    x: torch.Tensor,
+    std: Union[float, Tuple[float], torch.Tensor] = None,
+    cutoff: Union[float, torch.Tensor] = None,
+    scale: Union[float, torch.Tensor] = None,
+) -> torch.Tensor:
+    if std is None:
+        cutoff = get_cutoff(cutoff, scale)
+        std = 0.5 / (np.pi * cutoff)
+    if isinstance(std, (float, int)):
+        std = (std, std)
+    if isinstance(std, torch.Tensor):
+        """Using nn.functional.conv2d with Gaussian kernels built at runtime is
+        roughly 80% faster than transforms.functional.gaussian_blur for
+        individual items on GPU; on CPU the opposite holds. But you won't run
+        this on CPU, right?
+        """
+        if len(list(s for s in std.shape if s != 1)) >= 2:
+            raise NotImplementedError(
+                'Anisotropic Gaussian filter is not currently available.')
+
+        # k.shape == (B, W, W).
+        k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
+        if k.shape[0] == 1:
+            return filter_by_kernel(x, k[0], False)
+        else:
+            return filter_by_kernel(x, k, True)
+    else:
+        # Gaussian kernel size is odd in order to preserve the center.
+        window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
+        return TF.gaussian_blur(x, window_size, std)
+
+
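To make the `cutoff`/`scale`/`std` relationship concrete: a downscale factor `scale` corresponds to a cutoff frequency of 0.5 / scale cycles per pixel, and `gaussian_lowpass` then uses a Gaussian with std = 0.5 / (pi * cutoff). A small sketch with illustrative numbers (reusing util.py's imports):

# Illustrative only: low-pass a batch of masks before 2x downsampling.
x = torch.rand(1, 1, 64, 64)            # (B, C, H, W)
cutoff = get_cutoff(scale=2.0)          # 0.5 / 2.0 = 0.25 cycles/pixel
std = 0.5 / (np.pi * cutoff)            # ~0.64 pixels
blurred = gaussian_lowpass(x, std=std)  # same shape as x; equivalent to passing scale=2.0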
+def blend(
+    fg: Union[torch.Tensor, Image.Image],
+    bg: Union[torch.Tensor, Image.Image],
+    mask: Union[torch.Tensor, Image.Image],
+    std: float = 0.0,
+) -> Image.Image:
+    if not isinstance(fg, torch.Tensor):
+        fg = T.ToTensor()(fg)
+    if not isinstance(bg, torch.Tensor):
+        bg = T.ToTensor()(bg)
+    if not isinstance(mask, torch.Tensor):
+        mask = (T.ToTensor()(mask) < 0.5).float()[:1]
+    if std > 0:
+        mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
+    return T.ToPILImage()(fg * mask + bg * (1 - mask))
+
+
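A short illustrative sketch of `blend`: it composites a foreground over a background through a mask, and `std > 0` feathers the mask edge with the Gaussian low-pass above. Note the `< 0.5` threshold, so dark mask regions keep the foreground and bright regions keep the background:

# Illustrative only (reusing util.py's imports).
fg = Image.new('RGB', (256, 256), 'red')
bg = Image.new('RGB', (256, 256), 'blue')
mask = Image.new('L', (256, 256), 0)   # all-dark mask -> keep the foreground
out = blend(fg, bg, mask, std=2.0)     # PIL.Image; soft edges when std > 0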
+def get_panorama_views(
+    panorama_height: int,
+    panorama_width: int,
+    window_size: int = 64,
+) -> tuple[List[Tuple[int]], torch.Tensor]:
+    stride = window_size // 2
+    is_horizontal = panorama_width > panorama_height
+    num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
+    num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
+    total_num_blocks = num_blocks_height * num_blocks_width
+
+    half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
+    half_rev = half_fwd.flip(0)
+    if window_size % 2 == 1:
+        half_rev = half_rev[1:]
+    c = torch.cat((half_fwd, half_rev))
+    one = torch.ones_like(c)
+    f = c.clone()
+    f[:window_size // 2] = 1
+    b = c.clone()
+    b[-(window_size // 2):] = 1
+
+    h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
+    w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
+
+    views = []
+    masks = torch.zeros(total_num_blocks, panorama_height, panorama_width)  # (n, h, w)
+    for i in range(total_num_blocks):
+        hi, wi = i // num_blocks_width, i % num_blocks_width
+        h_start = hi * stride
+        h_end = min(h_start + window_size, panorama_height)
+        w_start = wi * stride
+        w_end = min(w_start + window_size, panorama_width)
+        views.append((h_start, h_end, w_start, w_end))
+
+        h_width = h_end - h_start
+        w_width = w_end - w_start
+        masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
+
+    # Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
+    return views, masks[None]  # (1, n, h, w)
+
+
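`get_panorama_views` tiles a large (latent) canvas into half-overlapping windows and returns linear ramp weights that sum to one at every pixel (as the comment above notes), which makes them suitable for blending per-window results back together. A quick sketch with latent-space sizes; the default `window_size=64` would correspond to 512 px in image space for an 8x VAE:

# Illustrative only: a 64x256 latent canvas (roughly a 512x2048 px image).
views, masks = get_panorama_views(64, 256, window_size=64)
print(len(views))              # 7 overlapping windows, each as (h_start, h_end, w_start, w_end)
print(masks.shape)             # torch.Size([1, 7, 64, 256])
print(masks.sum(dim=1).min())  # ~1.0: the weights partition unity at every pixel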
+def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> List[int]:
+    h, w = mask.shape[-2:]
+    device = mask.device
+    mask = mask.reshape(-1, h, w)
+    # assert mask.shape[0] == im.shape[0]
+    h_occupied = mask.sum(dim=-2) > 0
+    w_occupied = mask.sum(dim=-1) > 0
+    l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
+    r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
+    t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
+    b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
+    tb = (t + b + 1) // 2
+    lr = (l + r + 1) // 2
+    shifts = (tb - (h // 2), lr - (w // 2))
+    shifts = torch.cat(shifts, dim=1)  # (p, 2)
+    if reverse:
+        shifts = shifts * -1
+    return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
+
+
+class Streamer:
+    def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
+        self.fn = fn
+        self.ema_alpha = ema_alpha
+
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        self.future = self.executor.submit(fn)
+        self.image = None
+
+        self.prev_exec_time = 0
+        self.ema_exec_time = 0
+
+    @property
+    def throughput(self) -> float:
+        return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
+
+    def timed_fn(self) -> Any:
+        start = time.time()
+        res = self.fn()
+        end = time.time()
+        self.prev_exec_time = end - start
+        self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
+        return res
+
+    def __call__(self) -> Any:
+        if self.future.done() or self.image is None:
+            # get the result (the new image) and start a new task
+            image = self.future.result()
+            self.future = self.executor.submit(self.timed_fn)
+            self.image = image
+            return image
+        else:
+            # if self.fn() is not ready yet, use the previous image
+            # NOTE: This assumes that we have access to a previously generated image here.
+            # If there's no previous image (i.e., this is the first invocation), you could fall
+            # back to some default image or handle it differently based on your requirements.
+            return self.image
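Finally, a sketch of how `Streamer` can be driven (the frame generator below is a stand-in, not the app's real draw call): each call returns the latest finished frame, falling back to the previous one while a new frame is still being computed on the background thread, and `throughput` reports an exponential moving average of the generation rate.

# Illustrative stand-in for the real draw function used in app.py (reusing util.py's imports).
def generate_frame():
    time.sleep(0.1)                      # pretend this is one generation step
    return Image.new('RGB', (512, 512))

streamer = Streamer(generate_frame)
for _ in range(10):
    frame = streamer()                   # may repeat the previous frame if the new one is not ready
    print(f'{streamer.throughput:.1f} frames/s (EMA)')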