Spaces:
Runtime error
Runtime error
Upload 8 files
Browse files- apps/app_sana.py +502 -0
- apps/app_sana_4bit.py +409 -0
- apps/app_sana_4bit_compare_bf16.py +313 -0
- apps/app_sana_controlnet_hed.py +306 -0
- apps/app_sana_multithread.py +565 -0
- apps/safety_check.py +72 -0
- apps/sana_controlnet_pipeline.py +353 -0
- apps/sana_pipeline.py +304 -0
apps/app_sana.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# SPDX-License-Identifier: Apache-2.0
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import os
|
21 |
+
import random
|
22 |
+
import socket
|
23 |
+
import sqlite3
|
24 |
+
import time
|
25 |
+
import uuid
|
26 |
+
from datetime import datetime
|
27 |
+
|
28 |
+
import gradio as gr
|
29 |
+
import numpy as np
|
30 |
+
import spaces
|
31 |
+
import torch
|
32 |
+
from PIL import Image
|
33 |
+
from torchvision.utils import make_grid, save_image
|
34 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
35 |
+
|
36 |
+
from app import safety_check
|
37 |
+
from app.sana_pipeline import SanaPipeline
|
38 |
+
|
39 |
+
MAX_SEED = np.iinfo(np.int32).max
|
40 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
41 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
42 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
43 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
44 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
45 |
+
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
46 |
+
COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
|
47 |
+
|
48 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
49 |
+
|
50 |
+
style_list = [
|
51 |
+
{
|
52 |
+
"name": "(No style)",
|
53 |
+
"prompt": "{prompt}",
|
54 |
+
"negative_prompt": "",
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"name": "Cinematic",
|
58 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
59 |
+
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
60 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"name": "Photographic",
|
64 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
65 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"name": "Anime",
|
69 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
70 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"name": "Manga",
|
74 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
75 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"name": "Digital Art",
|
79 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
80 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "Pixel art",
|
84 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
85 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"name": "Fantasy art",
|
89 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
90 |
+
"majestic, magical, fantasy art, cover art, dreamy",
|
91 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
92 |
+
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
93 |
+
"disfigured, sloppy, duplicate, mutated, black and white",
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"name": "Neonpunk",
|
97 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
98 |
+
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
99 |
+
"ultra detailed, intricate, professional",
|
100 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"name": "3D Model",
|
104 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
105 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
106 |
+
},
|
107 |
+
]
|
108 |
+
|
109 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
110 |
+
STYLE_NAMES = list(styles.keys())
|
111 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
112 |
+
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
113 |
+
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
114 |
+
NUM_IMAGES_PER_PROMPT = 1
|
115 |
+
INFER_SPEED = 0
|
116 |
+
|
117 |
+
|
118 |
+
def norm_ip(img, low, high):
|
119 |
+
img.clamp_(min=low, max=high)
|
120 |
+
img.sub_(low).div_(max(high - low, 1e-5))
|
121 |
+
return img
|
122 |
+
|
123 |
+
|
124 |
+
def open_db():
|
125 |
+
db = sqlite3.connect(COUNTER_DB)
|
126 |
+
db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
|
127 |
+
db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
|
128 |
+
return db
|
129 |
+
|
130 |
+
|
131 |
+
def read_inference_count():
|
132 |
+
with open_db() as db:
|
133 |
+
cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
|
134 |
+
db.commit()
|
135 |
+
return cur.fetchone()[0]
|
136 |
+
|
137 |
+
|
138 |
+
def write_inference_count(count):
|
139 |
+
count = max(0, int(count))
|
140 |
+
with open_db() as db:
|
141 |
+
db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
|
142 |
+
db.commit()
|
143 |
+
|
144 |
+
|
145 |
+
def run_inference(num_imgs=1):
|
146 |
+
write_inference_count(num_imgs)
|
147 |
+
count = read_inference_count()
|
148 |
+
|
149 |
+
return (
|
150 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
151 |
+
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
152 |
+
)
|
153 |
+
|
154 |
+
|
155 |
+
def update_inference_count():
|
156 |
+
count = read_inference_count()
|
157 |
+
return (
|
158 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
159 |
+
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
164 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
165 |
+
if not negative:
|
166 |
+
negative = ""
|
167 |
+
return p.replace("{prompt}", positive), n + negative
|
168 |
+
|
169 |
+
|
170 |
+
def get_args():
|
171 |
+
parser = argparse.ArgumentParser()
|
172 |
+
parser.add_argument("--config", type=str, help="config")
|
173 |
+
parser.add_argument(
|
174 |
+
"--model_path",
|
175 |
+
nargs="?",
|
176 |
+
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
177 |
+
type=str,
|
178 |
+
help="Path to the model file (positional)",
|
179 |
+
)
|
180 |
+
parser.add_argument("--output", default="./", type=str)
|
181 |
+
parser.add_argument("--bs", default=1, type=int)
|
182 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
183 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
184 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
185 |
+
parser.add_argument("--seed", default=42, type=int)
|
186 |
+
parser.add_argument("--step", default=-1, type=int)
|
187 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
188 |
+
parser.add_argument("--share", action="store_true")
|
189 |
+
parser.add_argument(
|
190 |
+
"--shield_model_path",
|
191 |
+
type=str,
|
192 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
193 |
+
default="google/shieldgemma-2b",
|
194 |
+
)
|
195 |
+
|
196 |
+
return parser.parse_known_args()[0]
|
197 |
+
|
198 |
+
|
199 |
+
args = get_args()
|
200 |
+
|
201 |
+
if torch.cuda.is_available():
|
202 |
+
model_path = args.model_path
|
203 |
+
pipe = SanaPipeline(args.config)
|
204 |
+
pipe.from_pretrained(model_path)
|
205 |
+
pipe.register_progress_bar(gr.Progress())
|
206 |
+
|
207 |
+
# safety checker
|
208 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
209 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
210 |
+
args.shield_model_path,
|
211 |
+
device_map="auto",
|
212 |
+
torch_dtype=torch.bfloat16,
|
213 |
+
).to(device)
|
214 |
+
|
215 |
+
|
216 |
+
def save_image_sana(img, seed="", save_img=False):
|
217 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
218 |
+
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
219 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
220 |
+
os.makedirs(save_path, exist_ok=True)
|
221 |
+
unique_name = os.path.join(save_path, unique_name)
|
222 |
+
if save_img:
|
223 |
+
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
224 |
+
|
225 |
+
return unique_name
|
226 |
+
|
227 |
+
|
228 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
229 |
+
if randomize_seed:
|
230 |
+
seed = random.randint(0, MAX_SEED)
|
231 |
+
return seed
|
232 |
+
|
233 |
+
|
234 |
+
@torch.no_grad()
|
235 |
+
@torch.inference_mode()
|
236 |
+
@spaces.GPU(enable_queue=True)
|
237 |
+
def generate(
|
238 |
+
prompt: str = None,
|
239 |
+
negative_prompt: str = "",
|
240 |
+
style: str = DEFAULT_STYLE_NAME,
|
241 |
+
use_negative_prompt: bool = False,
|
242 |
+
num_imgs: int = 1,
|
243 |
+
seed: int = 0,
|
244 |
+
height: int = 1024,
|
245 |
+
width: int = 1024,
|
246 |
+
flow_dpms_guidance_scale: float = 5.0,
|
247 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
248 |
+
flow_dpms_inference_steps: int = 20,
|
249 |
+
randomize_seed: bool = False,
|
250 |
+
):
|
251 |
+
global INFER_SPEED
|
252 |
+
# seed = 823753551
|
253 |
+
box = run_inference(num_imgs)
|
254 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
255 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
256 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
257 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
258 |
+
prompt = "A red heart."
|
259 |
+
|
260 |
+
print(prompt)
|
261 |
+
|
262 |
+
num_inference_steps = flow_dpms_inference_steps
|
263 |
+
guidance_scale = flow_dpms_guidance_scale
|
264 |
+
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
265 |
+
|
266 |
+
if not use_negative_prompt:
|
267 |
+
negative_prompt = None # type: ignore
|
268 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
269 |
+
|
270 |
+
pipe.progress_fn(0, desc="Sana Start")
|
271 |
+
|
272 |
+
time_start = time.time()
|
273 |
+
images = pipe(
|
274 |
+
prompt=prompt,
|
275 |
+
height=height,
|
276 |
+
width=width,
|
277 |
+
negative_prompt=negative_prompt,
|
278 |
+
guidance_scale=guidance_scale,
|
279 |
+
pag_guidance_scale=pag_guidance_scale,
|
280 |
+
num_inference_steps=num_inference_steps,
|
281 |
+
num_images_per_prompt=num_imgs,
|
282 |
+
generator=generator,
|
283 |
+
)
|
284 |
+
|
285 |
+
pipe.progress_fn(1.0, desc="Sana End")
|
286 |
+
INFER_SPEED = (time.time() - time_start) / num_imgs
|
287 |
+
|
288 |
+
save_img = False
|
289 |
+
if save_img:
|
290 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
291 |
+
print(img)
|
292 |
+
else:
|
293 |
+
img = [
|
294 |
+
Image.fromarray(
|
295 |
+
norm_ip(img, -1, 1)
|
296 |
+
.mul(255)
|
297 |
+
.add_(0.5)
|
298 |
+
.clamp_(0, 255)
|
299 |
+
.permute(1, 2, 0)
|
300 |
+
.to("cpu", torch.uint8)
|
301 |
+
.numpy()
|
302 |
+
.astype(np.uint8)
|
303 |
+
)
|
304 |
+
for img in images
|
305 |
+
]
|
306 |
+
|
307 |
+
torch.cuda.empty_cache()
|
308 |
+
|
309 |
+
return (
|
310 |
+
img,
|
311 |
+
seed,
|
312 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
|
313 |
+
box,
|
314 |
+
)
|
315 |
+
|
316 |
+
|
317 |
+
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
318 |
+
title = f"""
|
319 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
320 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
321 |
+
</div>
|
322 |
+
"""
|
323 |
+
DESCRIPTION = f"""
|
324 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
325 |
+
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
326 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
327 |
+
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
|
328 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
329 |
+
"""
|
330 |
+
if model_size == "0.6":
|
331 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
332 |
+
if not torch.cuda.is_available():
|
333 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
334 |
+
|
335 |
+
examples = [
|
336 |
+
'a cyberpunk cat with a neon sign that says "Sana"',
|
337 |
+
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
338 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
339 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
340 |
+
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
341 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
342 |
+
"👧 with 🌹 in the ❄️",
|
343 |
+
"an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
344 |
+
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
345 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed",
|
346 |
+
"a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
347 |
+
]
|
348 |
+
|
349 |
+
css = """
|
350 |
+
.gradio-container{max-width: 640px !important}
|
351 |
+
h1{text-align:center}
|
352 |
+
"""
|
353 |
+
with gr.Blocks(css=css, title="Sana") as demo:
|
354 |
+
gr.Markdown(title)
|
355 |
+
gr.HTML(DESCRIPTION)
|
356 |
+
gr.DuplicateButton(
|
357 |
+
value="Duplicate Space for private use",
|
358 |
+
elem_id="duplicate-button",
|
359 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
360 |
+
)
|
361 |
+
info_box = gr.Markdown(
|
362 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
363 |
+
)
|
364 |
+
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
365 |
+
# with gr.Row(equal_height=False):
|
366 |
+
with gr.Group():
|
367 |
+
with gr.Row():
|
368 |
+
prompt = gr.Text(
|
369 |
+
label="Prompt",
|
370 |
+
show_label=False,
|
371 |
+
max_lines=1,
|
372 |
+
placeholder="Enter your prompt",
|
373 |
+
container=False,
|
374 |
+
)
|
375 |
+
run_button = gr.Button("Run", scale=0)
|
376 |
+
result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
|
377 |
+
speed_box = gr.Markdown(
|
378 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
|
379 |
+
)
|
380 |
+
with gr.Accordion("Advanced options", open=False):
|
381 |
+
with gr.Group():
|
382 |
+
with gr.Row(visible=True):
|
383 |
+
height = gr.Slider(
|
384 |
+
label="Height",
|
385 |
+
minimum=256,
|
386 |
+
maximum=MAX_IMAGE_SIZE,
|
387 |
+
step=32,
|
388 |
+
value=args.image_size,
|
389 |
+
)
|
390 |
+
width = gr.Slider(
|
391 |
+
label="Width",
|
392 |
+
minimum=256,
|
393 |
+
maximum=MAX_IMAGE_SIZE,
|
394 |
+
step=32,
|
395 |
+
value=args.image_size,
|
396 |
+
)
|
397 |
+
with gr.Row():
|
398 |
+
flow_dpms_inference_steps = gr.Slider(
|
399 |
+
label="Sampling steps",
|
400 |
+
minimum=5,
|
401 |
+
maximum=40,
|
402 |
+
step=1,
|
403 |
+
value=20,
|
404 |
+
)
|
405 |
+
flow_dpms_guidance_scale = gr.Slider(
|
406 |
+
label="CFG Guidance scale",
|
407 |
+
minimum=1,
|
408 |
+
maximum=10,
|
409 |
+
step=0.1,
|
410 |
+
value=4.5,
|
411 |
+
)
|
412 |
+
flow_dpms_pag_guidance_scale = gr.Slider(
|
413 |
+
label="PAG Guidance scale",
|
414 |
+
minimum=1,
|
415 |
+
maximum=4,
|
416 |
+
step=0.5,
|
417 |
+
value=1.0,
|
418 |
+
)
|
419 |
+
with gr.Row():
|
420 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
421 |
+
negative_prompt = gr.Text(
|
422 |
+
label="Negative prompt",
|
423 |
+
max_lines=1,
|
424 |
+
placeholder="Enter a negative prompt",
|
425 |
+
visible=True,
|
426 |
+
)
|
427 |
+
style_selection = gr.Radio(
|
428 |
+
show_label=True,
|
429 |
+
container=True,
|
430 |
+
interactive=True,
|
431 |
+
choices=STYLE_NAMES,
|
432 |
+
value=DEFAULT_STYLE_NAME,
|
433 |
+
label="Image Style",
|
434 |
+
)
|
435 |
+
seed = gr.Slider(
|
436 |
+
label="Seed",
|
437 |
+
minimum=0,
|
438 |
+
maximum=MAX_SEED,
|
439 |
+
step=1,
|
440 |
+
value=0,
|
441 |
+
)
|
442 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
443 |
+
with gr.Row(visible=True):
|
444 |
+
schedule = gr.Radio(
|
445 |
+
show_label=True,
|
446 |
+
container=True,
|
447 |
+
interactive=True,
|
448 |
+
choices=SCHEDULE_NAME,
|
449 |
+
value=DEFAULT_SCHEDULE_NAME,
|
450 |
+
label="Sampler Schedule",
|
451 |
+
visible=True,
|
452 |
+
)
|
453 |
+
num_imgs = gr.Slider(
|
454 |
+
label="Num Images",
|
455 |
+
minimum=1,
|
456 |
+
maximum=6,
|
457 |
+
step=1,
|
458 |
+
value=1,
|
459 |
+
)
|
460 |
+
|
461 |
+
gr.Examples(
|
462 |
+
examples=examples,
|
463 |
+
inputs=prompt,
|
464 |
+
outputs=[result, seed],
|
465 |
+
fn=generate,
|
466 |
+
cache_examples=CACHE_EXAMPLES,
|
467 |
+
)
|
468 |
+
|
469 |
+
use_negative_prompt.change(
|
470 |
+
fn=lambda x: gr.update(visible=x),
|
471 |
+
inputs=use_negative_prompt,
|
472 |
+
outputs=negative_prompt,
|
473 |
+
api_name=False,
|
474 |
+
)
|
475 |
+
|
476 |
+
gr.on(
|
477 |
+
triggers=[
|
478 |
+
prompt.submit,
|
479 |
+
negative_prompt.submit,
|
480 |
+
run_button.click,
|
481 |
+
],
|
482 |
+
fn=generate,
|
483 |
+
inputs=[
|
484 |
+
prompt,
|
485 |
+
negative_prompt,
|
486 |
+
style_selection,
|
487 |
+
use_negative_prompt,
|
488 |
+
num_imgs,
|
489 |
+
seed,
|
490 |
+
height,
|
491 |
+
width,
|
492 |
+
flow_dpms_guidance_scale,
|
493 |
+
flow_dpms_pag_guidance_scale,
|
494 |
+
flow_dpms_inference_steps,
|
495 |
+
randomize_seed,
|
496 |
+
],
|
497 |
+
outputs=[result, seed, speed_box, info_box],
|
498 |
+
api_name="run",
|
499 |
+
)
|
500 |
+
|
501 |
+
if __name__ == "__main__":
|
502 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
apps/app_sana_4bit.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
#!/usr/bin/env python
|
6 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
#
|
20 |
+
# SPDX-License-Identifier: Apache-2.0
|
21 |
+
from __future__ import annotations
|
22 |
+
|
23 |
+
import argparse
|
24 |
+
import os
|
25 |
+
import random
|
26 |
+
import time
|
27 |
+
import uuid
|
28 |
+
from datetime import datetime
|
29 |
+
|
30 |
+
import gradio as gr
|
31 |
+
import numpy as np
|
32 |
+
import spaces
|
33 |
+
import torch
|
34 |
+
from diffusers import SanaPipeline
|
35 |
+
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
|
36 |
+
from torchvision.utils import save_image
|
37 |
+
|
38 |
+
MAX_SEED = np.iinfo(np.int32).max
|
39 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
40 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
41 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
42 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
43 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
44 |
+
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
45 |
+
COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
|
46 |
+
INFER_SPEED = 0
|
47 |
+
|
48 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
49 |
+
|
50 |
+
style_list = [
|
51 |
+
{
|
52 |
+
"name": "(No style)",
|
53 |
+
"prompt": "{prompt}",
|
54 |
+
"negative_prompt": "",
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"name": "Cinematic",
|
58 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
59 |
+
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
60 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"name": "Photographic",
|
64 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
65 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"name": "Anime",
|
69 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
70 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"name": "Manga",
|
74 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
75 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"name": "Digital Art",
|
79 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
80 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "Pixel art",
|
84 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
85 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"name": "Fantasy art",
|
89 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
90 |
+
"majestic, magical, fantasy art, cover art, dreamy",
|
91 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
92 |
+
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
93 |
+
"disfigured, sloppy, duplicate, mutated, black and white",
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"name": "Neonpunk",
|
97 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
98 |
+
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
99 |
+
"ultra detailed, intricate, professional",
|
100 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"name": "3D Model",
|
104 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
105 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
106 |
+
},
|
107 |
+
]
|
108 |
+
|
109 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
110 |
+
STYLE_NAMES = list(styles.keys())
|
111 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
112 |
+
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
113 |
+
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
114 |
+
NUM_IMAGES_PER_PROMPT = 1
|
115 |
+
|
116 |
+
|
117 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
118 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
119 |
+
if not negative:
|
120 |
+
negative = ""
|
121 |
+
return p.replace("{prompt}", positive), n + negative
|
122 |
+
|
123 |
+
|
124 |
+
def get_args():
|
125 |
+
parser = argparse.ArgumentParser()
|
126 |
+
parser.add_argument(
|
127 |
+
"--model_path",
|
128 |
+
nargs="?",
|
129 |
+
default="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
130 |
+
type=str,
|
131 |
+
help="Path to the model file (positional)",
|
132 |
+
)
|
133 |
+
parser.add_argument("--share", action="store_true")
|
134 |
+
|
135 |
+
return parser.parse_known_args()[0]
|
136 |
+
|
137 |
+
|
138 |
+
args = get_args()
|
139 |
+
|
140 |
+
if torch.cuda.is_available():
|
141 |
+
|
142 |
+
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
|
143 |
+
pipe = SanaPipeline.from_pretrained(
|
144 |
+
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
145 |
+
transformer=transformer,
|
146 |
+
variant="bf16",
|
147 |
+
torch_dtype=torch.bfloat16,
|
148 |
+
).to(device)
|
149 |
+
|
150 |
+
pipe.text_encoder.to(torch.bfloat16)
|
151 |
+
pipe.vae.to(torch.bfloat16)
|
152 |
+
|
153 |
+
|
154 |
+
def save_image_sana(img, seed="", save_img=False):
|
155 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
156 |
+
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
157 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
158 |
+
os.makedirs(save_path, exist_ok=True)
|
159 |
+
unique_name = os.path.join(save_path, unique_name)
|
160 |
+
if save_img:
|
161 |
+
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
162 |
+
|
163 |
+
return unique_name
|
164 |
+
|
165 |
+
|
166 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
167 |
+
if randomize_seed:
|
168 |
+
seed = random.randint(0, MAX_SEED)
|
169 |
+
return seed
|
170 |
+
|
171 |
+
|
172 |
+
@torch.no_grad()
|
173 |
+
@torch.inference_mode()
|
174 |
+
@spaces.GPU(enable_queue=True)
|
175 |
+
def generate(
|
176 |
+
prompt: str = None,
|
177 |
+
negative_prompt: str = "",
|
178 |
+
style: str = DEFAULT_STYLE_NAME,
|
179 |
+
use_negative_prompt: bool = False,
|
180 |
+
num_imgs: int = 1,
|
181 |
+
seed: int = 0,
|
182 |
+
height: int = 1024,
|
183 |
+
width: int = 1024,
|
184 |
+
flow_dpms_guidance_scale: float = 5.0,
|
185 |
+
flow_dpms_inference_steps: int = 20,
|
186 |
+
randomize_seed: bool = False,
|
187 |
+
):
|
188 |
+
global INFER_SPEED
|
189 |
+
# seed = 823753551
|
190 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
191 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
192 |
+
print(f"PORT: {DEMO_PORT}, model_path: {args.model_path}")
|
193 |
+
|
194 |
+
print(prompt)
|
195 |
+
|
196 |
+
num_inference_steps = flow_dpms_inference_steps
|
197 |
+
guidance_scale = flow_dpms_guidance_scale
|
198 |
+
|
199 |
+
if not use_negative_prompt:
|
200 |
+
negative_prompt = None # type: ignore
|
201 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
202 |
+
|
203 |
+
time_start = time.time()
|
204 |
+
images = pipe(
|
205 |
+
prompt=prompt,
|
206 |
+
height=height,
|
207 |
+
width=width,
|
208 |
+
negative_prompt=negative_prompt,
|
209 |
+
guidance_scale=guidance_scale,
|
210 |
+
num_inference_steps=num_inference_steps,
|
211 |
+
num_images_per_prompt=num_imgs,
|
212 |
+
generator=generator,
|
213 |
+
).images
|
214 |
+
INFER_SPEED = (time.time() - time_start) / num_imgs
|
215 |
+
|
216 |
+
save_img = False
|
217 |
+
if save_img:
|
218 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
219 |
+
print(img)
|
220 |
+
else:
|
221 |
+
img = images
|
222 |
+
|
223 |
+
torch.cuda.empty_cache()
|
224 |
+
|
225 |
+
return (
|
226 |
+
img,
|
227 |
+
seed,
|
228 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
233 |
+
title = f"""
|
234 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
235 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="30%" alt="logo"/>
|
236 |
+
</div>
|
237 |
+
"""
|
238 |
+
DESCRIPTION = f"""
|
239 |
+
<p style="font-size: 30px; font-weight: bold; text-align: center;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer (4bit version)</p>
|
240 |
+
"""
|
241 |
+
if model_size == "0.6":
|
242 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
243 |
+
if not torch.cuda.is_available():
|
244 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
245 |
+
|
246 |
+
examples = [
|
247 |
+
'a cyberpunk cat with a neon sign that says "Sana"',
|
248 |
+
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
249 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
250 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
251 |
+
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
252 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
253 |
+
"👧 with 🌹 in the ❄️",
|
254 |
+
"an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
255 |
+
"professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
256 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed",
|
257 |
+
"a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
258 |
+
]
|
259 |
+
|
260 |
+
css = """
|
261 |
+
.gradio-container {max-width: 850px !important; height: auto !important;}
|
262 |
+
h1 {text-align: center;}
|
263 |
+
"""
|
264 |
+
theme = gr.themes.Base()
|
265 |
+
with gr.Blocks(css=css, theme=theme, title="Sana") as demo:
|
266 |
+
gr.Markdown(title)
|
267 |
+
gr.HTML(DESCRIPTION)
|
268 |
+
gr.DuplicateButton(
|
269 |
+
value="Duplicate Space for private use",
|
270 |
+
elem_id="duplicate-button",
|
271 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
272 |
+
)
|
273 |
+
# with gr.Row(equal_height=False):
|
274 |
+
with gr.Group():
|
275 |
+
with gr.Row():
|
276 |
+
prompt = gr.Text(
|
277 |
+
label="Prompt",
|
278 |
+
show_label=False,
|
279 |
+
max_lines=1,
|
280 |
+
placeholder="Enter your prompt",
|
281 |
+
container=False,
|
282 |
+
)
|
283 |
+
run_button = gr.Button("Run", scale=0)
|
284 |
+
result = gr.Gallery(
|
285 |
+
label="Result",
|
286 |
+
show_label=False,
|
287 |
+
height=750,
|
288 |
+
columns=NUM_IMAGES_PER_PROMPT,
|
289 |
+
format="jpeg",
|
290 |
+
)
|
291 |
+
|
292 |
+
speed_box = gr.Markdown(
|
293 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
|
294 |
+
)
|
295 |
+
with gr.Accordion("Advanced options", open=False):
|
296 |
+
with gr.Group():
|
297 |
+
with gr.Row(visible=True):
|
298 |
+
height = gr.Slider(
|
299 |
+
label="Height",
|
300 |
+
minimum=256,
|
301 |
+
maximum=MAX_IMAGE_SIZE,
|
302 |
+
step=32,
|
303 |
+
value=1024,
|
304 |
+
)
|
305 |
+
width = gr.Slider(
|
306 |
+
label="Width",
|
307 |
+
minimum=256,
|
308 |
+
maximum=MAX_IMAGE_SIZE,
|
309 |
+
step=32,
|
310 |
+
value=1024,
|
311 |
+
)
|
312 |
+
with gr.Row():
|
313 |
+
flow_dpms_inference_steps = gr.Slider(
|
314 |
+
label="Sampling steps",
|
315 |
+
minimum=5,
|
316 |
+
maximum=40,
|
317 |
+
step=1,
|
318 |
+
value=20,
|
319 |
+
)
|
320 |
+
flow_dpms_guidance_scale = gr.Slider(
|
321 |
+
label="CFG Guidance scale",
|
322 |
+
minimum=1,
|
323 |
+
maximum=10,
|
324 |
+
step=0.1,
|
325 |
+
value=4.5,
|
326 |
+
)
|
327 |
+
with gr.Row():
|
328 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
329 |
+
negative_prompt = gr.Text(
|
330 |
+
label="Negative prompt",
|
331 |
+
max_lines=1,
|
332 |
+
placeholder="Enter a negative prompt",
|
333 |
+
visible=True,
|
334 |
+
)
|
335 |
+
style_selection = gr.Radio(
|
336 |
+
show_label=True,
|
337 |
+
container=True,
|
338 |
+
interactive=True,
|
339 |
+
choices=STYLE_NAMES,
|
340 |
+
value=DEFAULT_STYLE_NAME,
|
341 |
+
label="Image Style",
|
342 |
+
)
|
343 |
+
seed = gr.Slider(
|
344 |
+
label="Seed",
|
345 |
+
minimum=0,
|
346 |
+
maximum=MAX_SEED,
|
347 |
+
step=1,
|
348 |
+
value=0,
|
349 |
+
)
|
350 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
351 |
+
with gr.Row(visible=True):
|
352 |
+
schedule = gr.Radio(
|
353 |
+
show_label=True,
|
354 |
+
container=True,
|
355 |
+
interactive=True,
|
356 |
+
choices=SCHEDULE_NAME,
|
357 |
+
value=DEFAULT_SCHEDULE_NAME,
|
358 |
+
label="Sampler Schedule",
|
359 |
+
visible=True,
|
360 |
+
)
|
361 |
+
num_imgs = gr.Slider(
|
362 |
+
label="Num Images",
|
363 |
+
minimum=1,
|
364 |
+
maximum=6,
|
365 |
+
step=1,
|
366 |
+
value=1,
|
367 |
+
)
|
368 |
+
|
369 |
+
gr.Examples(
|
370 |
+
examples=examples,
|
371 |
+
inputs=prompt,
|
372 |
+
outputs=[result, seed],
|
373 |
+
fn=generate,
|
374 |
+
cache_examples=CACHE_EXAMPLES,
|
375 |
+
)
|
376 |
+
|
377 |
+
use_negative_prompt.change(
|
378 |
+
fn=lambda x: gr.update(visible=x),
|
379 |
+
inputs=use_negative_prompt,
|
380 |
+
outputs=negative_prompt,
|
381 |
+
api_name=False,
|
382 |
+
)
|
383 |
+
|
384 |
+
gr.on(
|
385 |
+
triggers=[
|
386 |
+
prompt.submit,
|
387 |
+
negative_prompt.submit,
|
388 |
+
run_button.click,
|
389 |
+
],
|
390 |
+
fn=generate,
|
391 |
+
inputs=[
|
392 |
+
prompt,
|
393 |
+
negative_prompt,
|
394 |
+
style_selection,
|
395 |
+
use_negative_prompt,
|
396 |
+
num_imgs,
|
397 |
+
seed,
|
398 |
+
height,
|
399 |
+
width,
|
400 |
+
flow_dpms_guidance_scale,
|
401 |
+
flow_dpms_inference_steps,
|
402 |
+
randomize_seed,
|
403 |
+
],
|
404 |
+
outputs=[result, seed, speed_box],
|
405 |
+
api_name="run",
|
406 |
+
)
|
407 |
+
|
408 |
+
if __name__ == "__main__":
|
409 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
apps/app_sana_4bit_compare_bf16.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import time
|
6 |
+
from datetime import datetime
|
7 |
+
|
8 |
+
import GPUtil
|
9 |
+
|
10 |
+
# import gradio last to avoid conflicts with other imports
|
11 |
+
import gradio as gr
|
12 |
+
import safety_check
|
13 |
+
import spaces
|
14 |
+
import torch
|
15 |
+
from diffusers import SanaPipeline
|
16 |
+
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
|
17 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
18 |
+
|
19 |
+
MAX_IMAGE_SIZE = 2048
|
20 |
+
MAX_SEED = 1000000000
|
21 |
+
|
22 |
+
DEFAULT_HEIGHT = 1024
|
23 |
+
DEFAULT_WIDTH = 1024
|
24 |
+
|
25 |
+
# num_inference_steps, guidance_scale, seed
|
26 |
+
EXAMPLES = [
|
27 |
+
[
|
28 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
29 |
+
1024,
|
30 |
+
1024,
|
31 |
+
20,
|
32 |
+
5,
|
33 |
+
2,
|
34 |
+
],
|
35 |
+
[
|
36 |
+
"大漠孤烟直, 长河落日圆",
|
37 |
+
1024,
|
38 |
+
1024,
|
39 |
+
20,
|
40 |
+
5,
|
41 |
+
23,
|
42 |
+
],
|
43 |
+
[
|
44 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, "
|
45 |
+
"volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, "
|
46 |
+
"art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
47 |
+
1024,
|
48 |
+
1024,
|
49 |
+
20,
|
50 |
+
5,
|
51 |
+
233,
|
52 |
+
],
|
53 |
+
[
|
54 |
+
"A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be "
|
55 |
+
"sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic "
|
56 |
+
"lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field "
|
57 |
+
"for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, "
|
58 |
+
"cinematic lighting, ultra-HD.",
|
59 |
+
1024,
|
60 |
+
1024,
|
61 |
+
20,
|
62 |
+
5,
|
63 |
+
2333,
|
64 |
+
],
|
65 |
+
[
|
66 |
+
"A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. "
|
67 |
+
"She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. "
|
68 |
+
"She wears sunglasses and red lipstick. She walks confidently and casually. "
|
69 |
+
"The street is damp and reflective, creating a mirror effect of the colorful lights. "
|
70 |
+
"Many pedestrians walk about.",
|
71 |
+
1024,
|
72 |
+
1024,
|
73 |
+
20,
|
74 |
+
5,
|
75 |
+
23333,
|
76 |
+
],
|
77 |
+
[
|
78 |
+
"Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, "
|
79 |
+
"opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, "
|
80 |
+
"and a desk. The atmosphere is serene and natural. 8K resolution, highly detailed, photorealistic, "
|
81 |
+
"cinematic lighting, ultra-HD.",
|
82 |
+
1024,
|
83 |
+
1024,
|
84 |
+
20,
|
85 |
+
5,
|
86 |
+
233333,
|
87 |
+
],
|
88 |
+
]
|
89 |
+
|
90 |
+
|
91 |
+
def hash_str_to_int(s: str) -> int:
|
92 |
+
"""Hash a string to an integer."""
|
93 |
+
modulus = 10**9 + 7 # Large prime modulus
|
94 |
+
hash_int = 0
|
95 |
+
for char in s:
|
96 |
+
hash_int = (hash_int * 31 + ord(char)) % modulus
|
97 |
+
return hash_int
|
98 |
+
|
99 |
+
|
100 |
+
def get_pipeline(
|
101 |
+
precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {}
|
102 |
+
) -> SanaPipeline:
|
103 |
+
if precision == "int4":
|
104 |
+
assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
|
105 |
+
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
|
106 |
+
|
107 |
+
pipeline_init_kwargs["transformer"] = transformer
|
108 |
+
if use_qencoder:
|
109 |
+
raise NotImplementedError("Quantized encoder not supported for Sana for now")
|
110 |
+
else:
|
111 |
+
assert precision == "bf16"
|
112 |
+
pipeline = SanaPipeline.from_pretrained(
|
113 |
+
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
114 |
+
variant="bf16",
|
115 |
+
torch_dtype=torch.bfloat16,
|
116 |
+
**pipeline_init_kwargs,
|
117 |
+
)
|
118 |
+
|
119 |
+
pipeline = pipeline.to(device)
|
120 |
+
return pipeline
|
121 |
+
|
122 |
+
|
123 |
+
def get_args() -> argparse.Namespace:
|
124 |
+
parser = argparse.ArgumentParser()
|
125 |
+
parser.add_argument(
|
126 |
+
"-p",
|
127 |
+
"--precisions",
|
128 |
+
type=str,
|
129 |
+
default=["int4"],
|
130 |
+
nargs="*",
|
131 |
+
choices=["int4", "bf16"],
|
132 |
+
help="Which precisions to use",
|
133 |
+
)
|
134 |
+
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
|
135 |
+
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
|
136 |
+
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
|
137 |
+
return parser.parse_args()
|
138 |
+
|
139 |
+
|
140 |
+
args = get_args()
|
141 |
+
|
142 |
+
|
143 |
+
pipelines = []
|
144 |
+
pipeline_init_kwargs = {}
|
145 |
+
for i, precision in enumerate(args.precisions):
|
146 |
+
|
147 |
+
pipeline = get_pipeline(
|
148 |
+
precision=precision,
|
149 |
+
use_qencoder=args.use_qencoder,
|
150 |
+
device="cuda",
|
151 |
+
pipeline_init_kwargs={**pipeline_init_kwargs},
|
152 |
+
)
|
153 |
+
pipelines.append(pipeline)
|
154 |
+
if i == 0:
|
155 |
+
pipeline_init_kwargs["vae"] = pipeline.vae
|
156 |
+
pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder
|
157 |
+
|
158 |
+
# safety checker
|
159 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
160 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
161 |
+
args.shield_model_path,
|
162 |
+
device_map="auto",
|
163 |
+
torch_dtype=torch.bfloat16,
|
164 |
+
).to(pipeline.device)
|
165 |
+
|
166 |
+
|
167 |
+
@spaces.GPU(enable_queue=True)
|
168 |
+
def generate(
|
169 |
+
prompt: str = None,
|
170 |
+
height: int = 1024,
|
171 |
+
width: int = 1024,
|
172 |
+
num_inference_steps: int = 4,
|
173 |
+
guidance_scale: float = 0,
|
174 |
+
seed: int = 0,
|
175 |
+
):
|
176 |
+
print(f"Prompt: {prompt}")
|
177 |
+
is_unsafe_prompt = False
|
178 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
179 |
+
prompt = "A peaceful world."
|
180 |
+
images, latency_strs = [], []
|
181 |
+
for i, pipeline in enumerate(pipelines):
|
182 |
+
progress = gr.Progress(track_tqdm=True)
|
183 |
+
start_time = time.time()
|
184 |
+
image = pipeline(
|
185 |
+
prompt=prompt,
|
186 |
+
height=height,
|
187 |
+
width=width,
|
188 |
+
guidance_scale=guidance_scale,
|
189 |
+
num_inference_steps=num_inference_steps,
|
190 |
+
generator=torch.Generator().manual_seed(seed),
|
191 |
+
).images[0]
|
192 |
+
end_time = time.time()
|
193 |
+
latency = end_time - start_time
|
194 |
+
if latency < 1:
|
195 |
+
latency = latency * 1000
|
196 |
+
latency_str = f"{latency:.2f}ms"
|
197 |
+
else:
|
198 |
+
latency_str = f"{latency:.2f}s"
|
199 |
+
images.append(image)
|
200 |
+
latency_strs.append(latency_str)
|
201 |
+
if is_unsafe_prompt:
|
202 |
+
for i in range(len(latency_strs)):
|
203 |
+
latency_strs[i] += " (Unsafe prompt detected)"
|
204 |
+
torch.cuda.empty_cache()
|
205 |
+
|
206 |
+
if args.count_use:
|
207 |
+
if os.path.exists("use_count.txt"):
|
208 |
+
with open("use_count.txt") as f:
|
209 |
+
count = int(f.read())
|
210 |
+
else:
|
211 |
+
count = 0
|
212 |
+
count += 1
|
213 |
+
current_time = datetime.now()
|
214 |
+
print(f"{current_time}: {count}")
|
215 |
+
with open("use_count.txt", "w") as f:
|
216 |
+
f.write(str(count))
|
217 |
+
with open("use_record.txt", "a") as f:
|
218 |
+
f.write(f"{current_time}: {count}\n")
|
219 |
+
|
220 |
+
return *images, *latency_strs
|
221 |
+
|
222 |
+
|
223 |
+
with open("./assets/description.html") as f:
|
224 |
+
DESCRIPTION = f.read()
|
225 |
+
gpus = GPUtil.getGPUs()
|
226 |
+
if len(gpus) > 0:
|
227 |
+
gpu = gpus[0]
|
228 |
+
memory = gpu.memoryTotal / 1024
|
229 |
+
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
|
230 |
+
else:
|
231 |
+
device_info = "Running on CPU 🥶 This demo does not work on CPU."
|
232 |
+
notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
|
233 |
+
|
234 |
+
with gr.Blocks(
|
235 |
+
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
|
236 |
+
title=f"SVDQuant SANA-1600M Demo",
|
237 |
+
) as demo:
|
238 |
+
|
239 |
+
def get_header_str():
|
240 |
+
|
241 |
+
if args.count_use:
|
242 |
+
if os.path.exists("use_count.txt"):
|
243 |
+
with open("use_count.txt") as f:
|
244 |
+
count = int(f.read())
|
245 |
+
else:
|
246 |
+
count = 0
|
247 |
+
count_info = (
|
248 |
+
f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
|
249 |
+
f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
|
250 |
+
f"<span style='font-size: 18px; color:red; font-weight: bold;'> {count}</span></div>"
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
count_info = ""
|
254 |
+
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
|
255 |
+
return header_str
|
256 |
+
|
257 |
+
header = gr.HTML(get_header_str())
|
258 |
+
demo.load(fn=get_header_str, outputs=header)
|
259 |
+
|
260 |
+
with gr.Row():
|
261 |
+
image_results, latency_results = [], []
|
262 |
+
for i, precision in enumerate(args.precisions):
|
263 |
+
with gr.Column():
|
264 |
+
gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
|
265 |
+
with gr.Group():
|
266 |
+
image_result = gr.Image(
|
267 |
+
format="png",
|
268 |
+
image_mode="RGB",
|
269 |
+
label="Result",
|
270 |
+
show_label=False,
|
271 |
+
show_download_button=True,
|
272 |
+
interactive=False,
|
273 |
+
)
|
274 |
+
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
275 |
+
image_results.append(image_result)
|
276 |
+
latency_results.append(latency_result)
|
277 |
+
with gr.Row():
|
278 |
+
prompt = gr.Text(
|
279 |
+
label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
|
280 |
+
)
|
281 |
+
run_button = gr.Button("Run", scale=1)
|
282 |
+
|
283 |
+
with gr.Row():
|
284 |
+
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
285 |
+
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
286 |
+
with gr.Accordion("Advanced options", open=False):
|
287 |
+
with gr.Group():
|
288 |
+
height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
|
289 |
+
width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
|
290 |
+
with gr.Group():
|
291 |
+
num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
|
292 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)
|
293 |
+
|
294 |
+
input_args = [prompt, height, width, num_inference_steps, guidance_scale, seed]
|
295 |
+
|
296 |
+
gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)
|
297 |
+
|
298 |
+
gr.on(
|
299 |
+
triggers=[prompt.submit, run_button.click],
|
300 |
+
fn=generate,
|
301 |
+
inputs=input_args,
|
302 |
+
outputs=[*image_results, *latency_results],
|
303 |
+
api_name="run",
|
304 |
+
)
|
305 |
+
randomize_seed.click(
|
306 |
+
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
|
307 |
+
).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)
|
308 |
+
|
309 |
+
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
|
310 |
+
|
311 |
+
|
312 |
+
if __name__ == "__main__":
|
313 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
|
apps/app_sana_controlnet_hed.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import socket
|
6 |
+
import tempfile
|
7 |
+
import time
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from PIL import Image
|
13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
14 |
+
|
15 |
+
from app import safety_check
|
16 |
+
from app.sana_controlnet_pipeline import SanaControlNetPipeline
|
17 |
+
|
18 |
+
STYLES = {
|
19 |
+
"None": "{prompt}",
|
20 |
+
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
21 |
+
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
|
22 |
+
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
|
23 |
+
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
|
24 |
+
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
25 |
+
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
|
26 |
+
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
27 |
+
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
28 |
+
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
|
29 |
+
}
|
30 |
+
DEFAULT_STYLE_NAME = "None"
|
31 |
+
STYLE_NAMES = list(STYLES.keys())
|
32 |
+
|
33 |
+
MAX_SEED = 1000000000
|
34 |
+
DEFAULT_SKETCH_GUIDANCE = 0.28
|
35 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
36 |
+
|
37 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
38 |
+
|
39 |
+
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
|
40 |
+
|
41 |
+
|
42 |
+
def get_args():
|
43 |
+
parser = argparse.ArgumentParser()
|
44 |
+
parser.add_argument("--config", type=str, help="config")
|
45 |
+
parser.add_argument(
|
46 |
+
"--model_path",
|
47 |
+
nargs="?",
|
48 |
+
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
49 |
+
type=str,
|
50 |
+
help="Path to the model file (positional)",
|
51 |
+
)
|
52 |
+
parser.add_argument("--output", default="./", type=str)
|
53 |
+
parser.add_argument("--bs", default=1, type=int)
|
54 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
55 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
56 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
57 |
+
parser.add_argument("--seed", default=42, type=int)
|
58 |
+
parser.add_argument("--step", default=-1, type=int)
|
59 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
60 |
+
parser.add_argument("--share", action="store_true")
|
61 |
+
parser.add_argument(
|
62 |
+
"--shield_model_path",
|
63 |
+
type=str,
|
64 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
65 |
+
default="google/shieldgemma-2b",
|
66 |
+
)
|
67 |
+
|
68 |
+
return parser.parse_known_args()[0]
|
69 |
+
|
70 |
+
|
71 |
+
args = get_args()
|
72 |
+
|
73 |
+
if torch.cuda.is_available():
|
74 |
+
model_path = args.model_path
|
75 |
+
pipe = SanaControlNetPipeline(args.config)
|
76 |
+
pipe.from_pretrained(model_path)
|
77 |
+
pipe.register_progress_bar(gr.Progress())
|
78 |
+
|
79 |
+
# safety checker
|
80 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
81 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
82 |
+
args.shield_model_path,
|
83 |
+
device_map="auto",
|
84 |
+
torch_dtype=torch.bfloat16,
|
85 |
+
).to(device)
|
86 |
+
|
87 |
+
|
88 |
+
def save_image(img):
|
89 |
+
if isinstance(img, dict):
|
90 |
+
img = img["composite"]
|
91 |
+
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
92 |
+
img.save(temp_file.name)
|
93 |
+
return temp_file.name
|
94 |
+
|
95 |
+
|
96 |
+
def norm_ip(img, low, high):
|
97 |
+
img.clamp_(min=low, max=high)
|
98 |
+
img.sub_(low).div_(max(high - low, 1e-5))
|
99 |
+
return img
|
100 |
+
|
101 |
+
|
102 |
+
@torch.no_grad()
|
103 |
+
@torch.inference_mode()
|
104 |
+
def run(
|
105 |
+
image,
|
106 |
+
prompt: str,
|
107 |
+
prompt_template: str,
|
108 |
+
sketch_thickness: int,
|
109 |
+
guidance_scale: float,
|
110 |
+
inference_steps: int,
|
111 |
+
seed: int,
|
112 |
+
blend_alpha: float,
|
113 |
+
) -> tuple[Image, str]:
|
114 |
+
|
115 |
+
print(f"Prompt: {prompt}")
|
116 |
+
image_numpy = np.array(image["composite"].convert("RGB"))
|
117 |
+
|
118 |
+
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
|
119 |
+
return blank_image, "Please input the prompt or draw something."
|
120 |
+
|
121 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
122 |
+
prompt = "A red heart."
|
123 |
+
|
124 |
+
prompt = prompt_template.format(prompt=prompt)
|
125 |
+
pipe.set_blend_alpha(blend_alpha)
|
126 |
+
start_time = time.time()
|
127 |
+
images = pipe(
|
128 |
+
prompt=prompt,
|
129 |
+
ref_image=image["composite"],
|
130 |
+
guidance_scale=guidance_scale,
|
131 |
+
num_inference_steps=inference_steps,
|
132 |
+
num_images_per_prompt=1,
|
133 |
+
sketch_thickness=sketch_thickness,
|
134 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
135 |
+
)
|
136 |
+
|
137 |
+
latency = time.time() - start_time
|
138 |
+
|
139 |
+
if latency < 1:
|
140 |
+
latency = latency * 1000
|
141 |
+
latency_str = f"{latency:.2f}ms"
|
142 |
+
else:
|
143 |
+
latency_str = f"{latency:.2f}s"
|
144 |
+
torch.cuda.empty_cache()
|
145 |
+
|
146 |
+
img = [
|
147 |
+
Image.fromarray(
|
148 |
+
norm_ip(img, -1, 1)
|
149 |
+
.mul(255)
|
150 |
+
.add_(0.5)
|
151 |
+
.clamp_(0, 255)
|
152 |
+
.permute(1, 2, 0)
|
153 |
+
.to("cpu", torch.uint8)
|
154 |
+
.numpy()
|
155 |
+
.astype(np.uint8)
|
156 |
+
)
|
157 |
+
for img in images
|
158 |
+
]
|
159 |
+
img = img[0]
|
160 |
+
return img, latency_str
|
161 |
+
|
162 |
+
|
163 |
+
model_size = "1.6" if "1600M" in args.model_path else "0.6"
|
164 |
+
title = f"""
|
165 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
166 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
167 |
+
</div>
|
168 |
+
"""
|
169 |
+
DESCRIPTION = f"""
|
170 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
171 |
+
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
172 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
173 |
+
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
|
174 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
175 |
+
"""
|
176 |
+
if model_size == "0.6":
|
177 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
178 |
+
if not torch.cuda.is_available():
|
179 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
180 |
+
|
181 |
+
|
182 |
+
with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo:
|
183 |
+
gr.Markdown(title)
|
184 |
+
gr.HTML(DESCRIPTION)
|
185 |
+
|
186 |
+
with gr.Row(elem_id="main_row"):
|
187 |
+
with gr.Column(elem_id="column_input"):
|
188 |
+
gr.Markdown("## INPUT", elem_id="input_header")
|
189 |
+
with gr.Group():
|
190 |
+
canvas = gr.Sketchpad(
|
191 |
+
value=blank_image,
|
192 |
+
height=640,
|
193 |
+
image_mode="RGB",
|
194 |
+
sources=["upload", "clipboard"],
|
195 |
+
type="pil",
|
196 |
+
label="Sketch",
|
197 |
+
show_label=False,
|
198 |
+
show_download_button=True,
|
199 |
+
interactive=True,
|
200 |
+
transforms=[],
|
201 |
+
canvas_size=(1024, 1024),
|
202 |
+
scale=1,
|
203 |
+
brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
|
204 |
+
format="png",
|
205 |
+
layers=False,
|
206 |
+
)
|
207 |
+
with gr.Row():
|
208 |
+
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
|
209 |
+
run_button = gr.Button("Run", scale=1, elem_id="run_button")
|
210 |
+
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
|
211 |
+
with gr.Row():
|
212 |
+
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
|
213 |
+
prompt_template = gr.Textbox(
|
214 |
+
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
|
215 |
+
)
|
216 |
+
|
217 |
+
with gr.Row():
|
218 |
+
sketch_thickness = gr.Slider(
|
219 |
+
label="Sketch Thickness",
|
220 |
+
minimum=1,
|
221 |
+
maximum=4,
|
222 |
+
step=1,
|
223 |
+
value=2,
|
224 |
+
)
|
225 |
+
with gr.Row():
|
226 |
+
inference_steps = gr.Slider(
|
227 |
+
label="Sampling steps",
|
228 |
+
minimum=5,
|
229 |
+
maximum=40,
|
230 |
+
step=1,
|
231 |
+
value=20,
|
232 |
+
)
|
233 |
+
guidance_scale = gr.Slider(
|
234 |
+
label="CFG Guidance scale",
|
235 |
+
minimum=1,
|
236 |
+
maximum=10,
|
237 |
+
step=0.1,
|
238 |
+
value=4.5,
|
239 |
+
)
|
240 |
+
blend_alpha = gr.Slider(
|
241 |
+
label="Blend Alpha",
|
242 |
+
minimum=0,
|
243 |
+
maximum=1,
|
244 |
+
step=0.1,
|
245 |
+
value=0,
|
246 |
+
)
|
247 |
+
with gr.Row():
|
248 |
+
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
|
249 |
+
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
|
250 |
+
|
251 |
+
with gr.Column(elem_id="column_output"):
|
252 |
+
gr.Markdown("## OUTPUT", elem_id="output_header")
|
253 |
+
with gr.Group():
|
254 |
+
result = gr.Image(
|
255 |
+
format="png",
|
256 |
+
height=640,
|
257 |
+
image_mode="RGB",
|
258 |
+
type="pil",
|
259 |
+
label="Result",
|
260 |
+
show_label=False,
|
261 |
+
show_download_button=True,
|
262 |
+
interactive=False,
|
263 |
+
elem_id="output_image",
|
264 |
+
)
|
265 |
+
latency_result = gr.Text(label="Inference Latency", show_label=True)
|
266 |
+
|
267 |
+
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
|
268 |
+
gr.Markdown("### Instructions")
|
269 |
+
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
|
270 |
+
gr.Markdown("**2**. Start sketching or upload a reference image")
|
271 |
+
gr.Markdown("**3**. Change the image style using a style template")
|
272 |
+
gr.Markdown("**4**. Try different seeds to generate different results")
|
273 |
+
|
274 |
+
run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
|
275 |
+
run_outputs = [result, latency_result]
|
276 |
+
|
277 |
+
randomize_seed.click(
|
278 |
+
lambda: random.randint(0, MAX_SEED),
|
279 |
+
inputs=[],
|
280 |
+
outputs=seed,
|
281 |
+
api_name=False,
|
282 |
+
queue=False,
|
283 |
+
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
284 |
+
|
285 |
+
style.change(
|
286 |
+
lambda x: STYLES[x],
|
287 |
+
inputs=[style],
|
288 |
+
outputs=[prompt_template],
|
289 |
+
api_name=False,
|
290 |
+
queue=False,
|
291 |
+
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
|
292 |
+
gr.on(
|
293 |
+
triggers=[prompt.submit, run_button.click, canvas.change],
|
294 |
+
fn=run,
|
295 |
+
inputs=run_inputs,
|
296 |
+
outputs=run_outputs,
|
297 |
+
api_name=False,
|
298 |
+
)
|
299 |
+
|
300 |
+
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
|
301 |
+
download_result.click(fn=save_image, inputs=result, outputs=download_result)
|
302 |
+
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == "__main__":
|
306 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
|
apps/app_sana_multithread.py
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# SPDX-License-Identifier: Apache-2.0
|
17 |
+
from __future__ import annotations
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import os
|
21 |
+
import random
|
22 |
+
import uuid
|
23 |
+
from datetime import datetime
|
24 |
+
|
25 |
+
import gradio as gr
|
26 |
+
import numpy as np
|
27 |
+
import spaces
|
28 |
+
import torch
|
29 |
+
from diffusers import FluxPipeline
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision.utils import make_grid, save_image
|
32 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
33 |
+
|
34 |
+
from app import safety_check
|
35 |
+
from app.sana_pipeline import SanaPipeline
|
36 |
+
|
37 |
+
MAX_SEED = np.iinfo(np.int32).max
|
38 |
+
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
39 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
40 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
41 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
42 |
+
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
|
43 |
+
os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
|
44 |
+
|
45 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
46 |
+
|
47 |
+
style_list = [
|
48 |
+
{
|
49 |
+
"name": "(No style)",
|
50 |
+
"prompt": "{prompt}",
|
51 |
+
"negative_prompt": "",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"name": "Cinematic",
|
55 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
|
56 |
+
"cinemascope, moody, epic, gorgeous, film grain, grainy",
|
57 |
+
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"name": "Photographic",
|
61 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
62 |
+
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"name": "Anime",
|
66 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
67 |
+
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"name": "Manga",
|
71 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
72 |
+
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"name": "Digital Art",
|
76 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
77 |
+
"negative_prompt": "photo, photorealistic, realism, ugly",
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"name": "Pixel art",
|
81 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
82 |
+
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"name": "Fantasy art",
|
86 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
|
87 |
+
"majestic, magical, fantasy art, cover art, dreamy",
|
88 |
+
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
|
89 |
+
"glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
|
90 |
+
"disfigured, sloppy, duplicate, mutated, black and white",
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"name": "Neonpunk",
|
94 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
|
95 |
+
"detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
|
96 |
+
"ultra detailed, intricate, professional",
|
97 |
+
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"name": "3D Model",
|
101 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
102 |
+
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
|
103 |
+
},
|
104 |
+
]
|
105 |
+
|
106 |
+
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
107 |
+
STYLE_NAMES = list(styles.keys())
|
108 |
+
DEFAULT_STYLE_NAME = "(No style)"
|
109 |
+
SCHEDULE_NAME = ["Flow_DPM_Solver"]
|
110 |
+
DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
|
111 |
+
NUM_IMAGES_PER_PROMPT = 1
|
112 |
+
TEST_TIMES = 0
|
113 |
+
FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
|
114 |
+
|
115 |
+
|
116 |
+
def set_env(seed=0):
|
117 |
+
torch.manual_seed(seed)
|
118 |
+
torch.set_grad_enabled(False)
|
119 |
+
|
120 |
+
|
121 |
+
def read_inference_count():
|
122 |
+
global TEST_TIMES
|
123 |
+
try:
|
124 |
+
with open(FILENAME) as f:
|
125 |
+
count = int(f.read().strip())
|
126 |
+
except FileNotFoundError:
|
127 |
+
count = 0
|
128 |
+
TEST_TIMES = count
|
129 |
+
|
130 |
+
return count
|
131 |
+
|
132 |
+
|
133 |
+
def write_inference_count(count):
|
134 |
+
with open(FILENAME, "w") as f:
|
135 |
+
f.write(str(count))
|
136 |
+
|
137 |
+
|
138 |
+
def run_inference(num_imgs=1):
|
139 |
+
TEST_TIMES = read_inference_count()
|
140 |
+
TEST_TIMES += int(num_imgs)
|
141 |
+
write_inference_count(TEST_TIMES)
|
142 |
+
|
143 |
+
return (
|
144 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
145 |
+
f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
|
146 |
+
)
|
147 |
+
|
148 |
+
|
149 |
+
def update_inference_count():
|
150 |
+
count = read_inference_count()
|
151 |
+
return (
|
152 |
+
f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
|
153 |
+
f"16px; color:red; font-weight: bold;'>{count}</span>"
|
154 |
+
)
|
155 |
+
|
156 |
+
|
157 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
|
158 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
159 |
+
if not negative:
|
160 |
+
negative = ""
|
161 |
+
return p.replace("{prompt}", positive), n + negative
|
162 |
+
|
163 |
+
|
164 |
+
def get_args():
|
165 |
+
parser = argparse.ArgumentParser()
|
166 |
+
parser.add_argument("--config", type=str, help="config")
|
167 |
+
parser.add_argument(
|
168 |
+
"--model_path",
|
169 |
+
nargs="?",
|
170 |
+
default="output/Sana_D20/SANA.pth",
|
171 |
+
type=str,
|
172 |
+
help="Path to the model file (positional)",
|
173 |
+
)
|
174 |
+
parser.add_argument("--output", default="./", type=str)
|
175 |
+
parser.add_argument("--bs", default=1, type=int)
|
176 |
+
parser.add_argument("--image_size", default=1024, type=int)
|
177 |
+
parser.add_argument("--cfg_scale", default=5.0, type=float)
|
178 |
+
parser.add_argument("--pag_scale", default=2.0, type=float)
|
179 |
+
parser.add_argument("--seed", default=42, type=int)
|
180 |
+
parser.add_argument("--step", default=-1, type=int)
|
181 |
+
parser.add_argument("--custom_image_size", default=None, type=int)
|
182 |
+
parser.add_argument(
|
183 |
+
"--shield_model_path",
|
184 |
+
type=str,
|
185 |
+
help="The path to shield model, we employ ShieldGemma-2B by default.",
|
186 |
+
default="google/shieldgemma-2b",
|
187 |
+
)
|
188 |
+
|
189 |
+
return parser.parse_args()
|
190 |
+
|
191 |
+
|
192 |
+
args = get_args()
|
193 |
+
|
194 |
+
if torch.cuda.is_available():
|
195 |
+
weight_dtype = torch.float16
|
196 |
+
model_path = args.model_path
|
197 |
+
pipe = SanaPipeline(args.config)
|
198 |
+
pipe.from_pretrained(model_path)
|
199 |
+
pipe.register_progress_bar(gr.Progress())
|
200 |
+
|
201 |
+
repo_name = "black-forest-labs/FLUX.1-dev"
|
202 |
+
pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
|
203 |
+
|
204 |
+
# safety checker
|
205 |
+
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
206 |
+
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
207 |
+
args.shield_model_path,
|
208 |
+
device_map="auto",
|
209 |
+
torch_dtype=torch.bfloat16,
|
210 |
+
).to(device)
|
211 |
+
|
212 |
+
set_env(42)
|
213 |
+
|
214 |
+
|
215 |
+
def save_image_sana(img, seed="", save_img=False):
|
216 |
+
unique_name = f"{str(uuid.uuid4())}_{seed}.png"
|
217 |
+
save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
|
218 |
+
os.umask(0o000) # file permission: 666; dir permission: 777
|
219 |
+
os.makedirs(save_path, exist_ok=True)
|
220 |
+
unique_name = os.path.join(save_path, unique_name)
|
221 |
+
if save_img:
|
222 |
+
save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
|
223 |
+
|
224 |
+
return unique_name
|
225 |
+
|
226 |
+
|
227 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
228 |
+
if randomize_seed:
|
229 |
+
seed = random.randint(0, MAX_SEED)
|
230 |
+
return seed
|
231 |
+
|
232 |
+
|
233 |
+
@spaces.GPU(enable_queue=True)
|
234 |
+
async def generate_2(
|
235 |
+
prompt: str = None,
|
236 |
+
negative_prompt: str = "",
|
237 |
+
style: str = DEFAULT_STYLE_NAME,
|
238 |
+
use_negative_prompt: bool = False,
|
239 |
+
num_imgs: int = 1,
|
240 |
+
seed: int = 0,
|
241 |
+
height: int = 1024,
|
242 |
+
width: int = 1024,
|
243 |
+
flow_dpms_guidance_scale: float = 5.0,
|
244 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
245 |
+
flow_dpms_inference_steps: int = 20,
|
246 |
+
randomize_seed: bool = False,
|
247 |
+
):
|
248 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
249 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
250 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
251 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
252 |
+
prompt = "A red heart."
|
253 |
+
|
254 |
+
print(prompt)
|
255 |
+
|
256 |
+
if not use_negative_prompt:
|
257 |
+
negative_prompt = None # type: ignore
|
258 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
259 |
+
|
260 |
+
with torch.no_grad():
|
261 |
+
images = pipe2(
|
262 |
+
prompt=prompt,
|
263 |
+
height=height,
|
264 |
+
width=width,
|
265 |
+
guidance_scale=3.5,
|
266 |
+
num_inference_steps=50,
|
267 |
+
num_images_per_prompt=num_imgs,
|
268 |
+
max_sequence_length=256,
|
269 |
+
generator=generator,
|
270 |
+
).images
|
271 |
+
|
272 |
+
save_img = False
|
273 |
+
img = images
|
274 |
+
if save_img:
|
275 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
276 |
+
print(img)
|
277 |
+
torch.cuda.empty_cache()
|
278 |
+
|
279 |
+
return img
|
280 |
+
|
281 |
+
|
282 |
+
@spaces.GPU(enable_queue=True)
|
283 |
+
async def generate(
|
284 |
+
prompt: str = None,
|
285 |
+
negative_prompt: str = "",
|
286 |
+
style: str = DEFAULT_STYLE_NAME,
|
287 |
+
use_negative_prompt: bool = False,
|
288 |
+
num_imgs: int = 1,
|
289 |
+
seed: int = 0,
|
290 |
+
height: int = 1024,
|
291 |
+
width: int = 1024,
|
292 |
+
flow_dpms_guidance_scale: float = 5.0,
|
293 |
+
flow_dpms_pag_guidance_scale: float = 2.0,
|
294 |
+
flow_dpms_inference_steps: int = 20,
|
295 |
+
randomize_seed: bool = False,
|
296 |
+
):
|
297 |
+
global TEST_TIMES
|
298 |
+
# seed = 823753551
|
299 |
+
seed = int(randomize_seed_fn(seed, randomize_seed))
|
300 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
301 |
+
print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
|
302 |
+
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
|
303 |
+
prompt = "A red heart."
|
304 |
+
|
305 |
+
print(prompt)
|
306 |
+
|
307 |
+
num_inference_steps = flow_dpms_inference_steps
|
308 |
+
guidance_scale = flow_dpms_guidance_scale
|
309 |
+
pag_guidance_scale = flow_dpms_pag_guidance_scale
|
310 |
+
|
311 |
+
if not use_negative_prompt:
|
312 |
+
negative_prompt = None # type: ignore
|
313 |
+
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
314 |
+
|
315 |
+
pipe.progress_fn(0, desc="Sana Start")
|
316 |
+
|
317 |
+
with torch.no_grad():
|
318 |
+
images = pipe(
|
319 |
+
prompt=prompt,
|
320 |
+
height=height,
|
321 |
+
width=width,
|
322 |
+
negative_prompt=negative_prompt,
|
323 |
+
guidance_scale=guidance_scale,
|
324 |
+
pag_guidance_scale=pag_guidance_scale,
|
325 |
+
num_inference_steps=num_inference_steps,
|
326 |
+
num_images_per_prompt=num_imgs,
|
327 |
+
generator=generator,
|
328 |
+
)
|
329 |
+
|
330 |
+
pipe.progress_fn(1.0, desc="Sana End")
|
331 |
+
|
332 |
+
save_img = False
|
333 |
+
if save_img:
|
334 |
+
img = [save_image_sana(img, seed, save_img=save_image) for img in images]
|
335 |
+
print(img)
|
336 |
+
else:
|
337 |
+
if num_imgs > 1:
|
338 |
+
nrow = 2
|
339 |
+
else:
|
340 |
+
nrow = 1
|
341 |
+
img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
|
342 |
+
img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
343 |
+
img = [Image.fromarray(img.astype(np.uint8))]
|
344 |
+
|
345 |
+
torch.cuda.empty_cache()
|
346 |
+
|
347 |
+
return img
|
348 |
+
|
349 |
+
|
350 |
+
TEST_TIMES = read_inference_count()
|
351 |
+
model_size = "1.6" if "D20" in args.model_path else "0.6"
|
352 |
+
title = f"""
|
353 |
+
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
|
354 |
+
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
|
355 |
+
</div>
|
356 |
+
"""
|
357 |
+
DESCRIPTION = f"""
|
358 |
+
<p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
|
359 |
+
<p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
|
360 |
+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
|
361 |
+
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
|
362 |
+
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
|
363 |
+
"""
|
364 |
+
if model_size == "0.6":
|
365 |
+
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
|
366 |
+
if not torch.cuda.is_available():
|
367 |
+
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
368 |
+
|
369 |
+
examples = [
|
370 |
+
'a cyberpunk cat with a neon sign that says "Sana"',
|
371 |
+
"A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
|
372 |
+
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
|
373 |
+
"portrait photo of a girl, photograph, highly detailed face, depth of field",
|
374 |
+
'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
|
375 |
+
"🐶 Wearing 🕶 flying on the 🌈",
|
376 |
+
# "👧 with 🌹 in the ❄️",
|
377 |
+
# "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
|
378 |
+
# "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
|
379 |
+
# "Astronaut in a jungle, cold color palette, muted colors, detailed",
|
380 |
+
# "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
|
381 |
+
]
|
382 |
+
|
383 |
+
css = """
|
384 |
+
.gradio-container{max-width: 1024px !important}
|
385 |
+
h1{text-align:center}
|
386 |
+
"""
|
387 |
+
with gr.Blocks(css=css) as demo:
|
388 |
+
gr.Markdown(title)
|
389 |
+
gr.Markdown(DESCRIPTION)
|
390 |
+
gr.DuplicateButton(
|
391 |
+
value="Duplicate Space for private use",
|
392 |
+
elem_id="duplicate-button",
|
393 |
+
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
|
394 |
+
)
|
395 |
+
info_box = gr.Markdown(
|
396 |
+
value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
|
397 |
+
)
|
398 |
+
demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
|
399 |
+
# with gr.Row(equal_height=False):
|
400 |
+
with gr.Group():
|
401 |
+
with gr.Row():
|
402 |
+
prompt = gr.Text(
|
403 |
+
label="Prompt",
|
404 |
+
show_label=False,
|
405 |
+
max_lines=1,
|
406 |
+
placeholder="Enter your prompt",
|
407 |
+
container=False,
|
408 |
+
)
|
409 |
+
run_button = gr.Button("Run-sana", scale=0)
|
410 |
+
run_button2 = gr.Button("Run-flux", scale=0)
|
411 |
+
|
412 |
+
with gr.Row():
|
413 |
+
result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
|
414 |
+
result_2 = gr.Gallery(
|
415 |
+
label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
|
416 |
+
)
|
417 |
+
|
418 |
+
with gr.Accordion("Advanced options", open=False):
|
419 |
+
with gr.Group():
|
420 |
+
with gr.Row(visible=True):
|
421 |
+
height = gr.Slider(
|
422 |
+
label="Height",
|
423 |
+
minimum=256,
|
424 |
+
maximum=MAX_IMAGE_SIZE,
|
425 |
+
step=32,
|
426 |
+
value=1024,
|
427 |
+
)
|
428 |
+
width = gr.Slider(
|
429 |
+
label="Width",
|
430 |
+
minimum=256,
|
431 |
+
maximum=MAX_IMAGE_SIZE,
|
432 |
+
step=32,
|
433 |
+
value=1024,
|
434 |
+
)
|
435 |
+
with gr.Row():
|
436 |
+
flow_dpms_inference_steps = gr.Slider(
|
437 |
+
label="Sampling steps",
|
438 |
+
minimum=5,
|
439 |
+
maximum=40,
|
440 |
+
step=1,
|
441 |
+
value=18,
|
442 |
+
)
|
443 |
+
flow_dpms_guidance_scale = gr.Slider(
|
444 |
+
label="CFG Guidance scale",
|
445 |
+
minimum=1,
|
446 |
+
maximum=10,
|
447 |
+
step=0.1,
|
448 |
+
value=5.0,
|
449 |
+
)
|
450 |
+
flow_dpms_pag_guidance_scale = gr.Slider(
|
451 |
+
label="PAG Guidance scale",
|
452 |
+
minimum=1,
|
453 |
+
maximum=4,
|
454 |
+
step=0.5,
|
455 |
+
value=2.0,
|
456 |
+
)
|
457 |
+
with gr.Row():
|
458 |
+
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
|
459 |
+
negative_prompt = gr.Text(
|
460 |
+
label="Negative prompt",
|
461 |
+
max_lines=1,
|
462 |
+
placeholder="Enter a negative prompt",
|
463 |
+
visible=True,
|
464 |
+
)
|
465 |
+
style_selection = gr.Radio(
|
466 |
+
show_label=True,
|
467 |
+
container=True,
|
468 |
+
interactive=True,
|
469 |
+
choices=STYLE_NAMES,
|
470 |
+
value=DEFAULT_STYLE_NAME,
|
471 |
+
label="Image Style",
|
472 |
+
)
|
473 |
+
seed = gr.Slider(
|
474 |
+
label="Seed",
|
475 |
+
minimum=0,
|
476 |
+
maximum=MAX_SEED,
|
477 |
+
step=1,
|
478 |
+
value=0,
|
479 |
+
)
|
480 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
481 |
+
with gr.Row(visible=True):
|
482 |
+
schedule = gr.Radio(
|
483 |
+
show_label=True,
|
484 |
+
container=True,
|
485 |
+
interactive=True,
|
486 |
+
choices=SCHEDULE_NAME,
|
487 |
+
value=DEFAULT_SCHEDULE_NAME,
|
488 |
+
label="Sampler Schedule",
|
489 |
+
visible=True,
|
490 |
+
)
|
491 |
+
num_imgs = gr.Slider(
|
492 |
+
label="Num Images",
|
493 |
+
minimum=1,
|
494 |
+
maximum=6,
|
495 |
+
step=1,
|
496 |
+
value=1,
|
497 |
+
)
|
498 |
+
|
499 |
+
run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
|
500 |
+
|
501 |
+
gr.Examples(
|
502 |
+
examples=examples,
|
503 |
+
inputs=prompt,
|
504 |
+
outputs=[result],
|
505 |
+
fn=generate,
|
506 |
+
cache_examples=CACHE_EXAMPLES,
|
507 |
+
)
|
508 |
+
gr.Examples(
|
509 |
+
examples=examples,
|
510 |
+
inputs=prompt,
|
511 |
+
outputs=[result_2],
|
512 |
+
fn=generate_2,
|
513 |
+
cache_examples=CACHE_EXAMPLES,
|
514 |
+
)
|
515 |
+
|
516 |
+
use_negative_prompt.change(
|
517 |
+
fn=lambda x: gr.update(visible=x),
|
518 |
+
inputs=use_negative_prompt,
|
519 |
+
outputs=negative_prompt,
|
520 |
+
api_name=False,
|
521 |
+
)
|
522 |
+
|
523 |
+
run_button.click(
|
524 |
+
fn=generate,
|
525 |
+
inputs=[
|
526 |
+
prompt,
|
527 |
+
negative_prompt,
|
528 |
+
style_selection,
|
529 |
+
use_negative_prompt,
|
530 |
+
num_imgs,
|
531 |
+
seed,
|
532 |
+
height,
|
533 |
+
width,
|
534 |
+
flow_dpms_guidance_scale,
|
535 |
+
flow_dpms_pag_guidance_scale,
|
536 |
+
flow_dpms_inference_steps,
|
537 |
+
randomize_seed,
|
538 |
+
],
|
539 |
+
outputs=[result],
|
540 |
+
queue=True,
|
541 |
+
)
|
542 |
+
|
543 |
+
run_button2.click(
|
544 |
+
fn=generate_2,
|
545 |
+
inputs=[
|
546 |
+
prompt,
|
547 |
+
negative_prompt,
|
548 |
+
style_selection,
|
549 |
+
use_negative_prompt,
|
550 |
+
num_imgs,
|
551 |
+
seed,
|
552 |
+
height,
|
553 |
+
width,
|
554 |
+
flow_dpms_guidance_scale,
|
555 |
+
flow_dpms_pag_guidance_scale,
|
556 |
+
flow_dpms_inference_steps,
|
557 |
+
randomize_seed,
|
558 |
+
],
|
559 |
+
outputs=[result_2],
|
560 |
+
queue=True,
|
561 |
+
)
|
562 |
+
|
563 |
+
|
564 |
+
if __name__ == "__main__":
|
565 |
+
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
|
apps/safety_check.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
# This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
|
17 |
+
import json
|
18 |
+
from functools import lru_cache
|
19 |
+
from typing import Dict, Sequence
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import transformers
|
23 |
+
from torch.nn.functional import softmax
|
24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
25 |
+
|
26 |
+
__all__ = ["is_dangerous"]
|
27 |
+
|
28 |
+
__rank, __local_rank, __world_size, __device = (
|
29 |
+
0,
|
30 |
+
0,
|
31 |
+
1,
|
32 |
+
"cuda" if torch.cuda.is_available() else "cpu",
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
safety_policy = """
|
37 |
+
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
|
42 |
+
prompt = f"""You are a policy expert trying to help determine whether a user
|
43 |
+
prompt is in violation of the defined safety policies.
|
44 |
+
|
45 |
+
<start_of_turn>
|
46 |
+
Human Question: {user_prompt.strip()}
|
47 |
+
<end_of_turn>
|
48 |
+
|
49 |
+
Our safety principle is defined in the below:
|
50 |
+
|
51 |
+
{safety_policy.strip()}
|
52 |
+
|
53 |
+
Does the human question violate the above principle? Your answer must start
|
54 |
+
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
|
55 |
+
correctly.
|
56 |
+
"""
|
57 |
+
|
58 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
59 |
+
with torch.no_grad():
|
60 |
+
logits = model(**inputs).logits
|
61 |
+
|
62 |
+
# Extract the logits for the Yes and No tokens
|
63 |
+
vocab = tokenizer.get_vocab()
|
64 |
+
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
|
65 |
+
|
66 |
+
# Convert these logits to a probability with softmax
|
67 |
+
probabilities = softmax(selected_logits, dim=0)
|
68 |
+
|
69 |
+
# Return probability of 'Yes'
|
70 |
+
score = probabilities[0].item()
|
71 |
+
|
72 |
+
return score > threshold
|
apps/sana_controlnet_pipeline.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
import warnings
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from typing import Optional, Tuple
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import pyrallis
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
from PIL import Image
|
25 |
+
|
26 |
+
warnings.filterwarnings("ignore") # ignore warning
|
27 |
+
|
28 |
+
|
29 |
+
from diffusion import DPMS, FlowEuler
|
30 |
+
from diffusion.data.datasets.utils import (
|
31 |
+
ASPECT_RATIO_512_TEST,
|
32 |
+
ASPECT_RATIO_1024_TEST,
|
33 |
+
ASPECT_RATIO_2048_TEST,
|
34 |
+
ASPECT_RATIO_4096_TEST,
|
35 |
+
)
|
36 |
+
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode
|
37 |
+
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
|
38 |
+
from diffusion.utils.config import SanaConfig, model_init_config
|
39 |
+
from diffusion.utils.logger import get_root_logger
|
40 |
+
from tools.controlnet.utils import get_scribble_map, transform_control_signal
|
41 |
+
from tools.download import find_model
|
42 |
+
|
43 |
+
|
44 |
+
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
|
45 |
+
guidance_type = default_guidance_type
|
46 |
+
if not (pag_scale > 1.0 and attn_type == "linear"):
|
47 |
+
guidance_type = "classifier-free"
|
48 |
+
elif pag_scale > 1.0 and attn_type == "linear":
|
49 |
+
guidance_type = "classifier-free_PAG"
|
50 |
+
return guidance_type
|
51 |
+
|
52 |
+
|
53 |
+
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
54 |
+
"""Returns binned height and width."""
|
55 |
+
ar = float(height / width)
|
56 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
57 |
+
default_hw = ratios[closest_ratio]
|
58 |
+
return int(default_hw[0]), int(default_hw[1])
|
59 |
+
|
60 |
+
|
61 |
+
def get_ar_from_ref_image(ref_image):
|
62 |
+
def reduce_ratio(h, w):
|
63 |
+
def gcd(a, b):
|
64 |
+
while b:
|
65 |
+
a, b = b, a % b
|
66 |
+
return a
|
67 |
+
|
68 |
+
divisor = gcd(h, w)
|
69 |
+
return f"{h // divisor}:{w // divisor}"
|
70 |
+
|
71 |
+
if isinstance(ref_image, str):
|
72 |
+
ref_image = Image.open(ref_image)
|
73 |
+
w, h = ref_image.size
|
74 |
+
return reduce_ratio(h, w)
|
75 |
+
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class SanaControlNetInference(SanaConfig):
|
79 |
+
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
|
80 |
+
model_path: str = field(
|
81 |
+
default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
|
82 |
+
)
|
83 |
+
output: str = "./output"
|
84 |
+
bs: int = 1
|
85 |
+
image_size: int = 1024
|
86 |
+
cfg_scale: float = 5.0
|
87 |
+
pag_scale: float = 2.0
|
88 |
+
seed: int = 42
|
89 |
+
step: int = -1
|
90 |
+
custom_image_size: Optional[int] = None
|
91 |
+
shield_model_path: str = field(
|
92 |
+
default="google/shieldgemma-2b",
|
93 |
+
metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
class SanaControlNetPipeline(nn.Module):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
|
101 |
+
):
|
102 |
+
super().__init__()
|
103 |
+
config = pyrallis.load(SanaControlNetInference, open(config))
|
104 |
+
self.args = self.config = config
|
105 |
+
|
106 |
+
# set some hyper-parameters
|
107 |
+
self.image_size = self.config.model.image_size
|
108 |
+
|
109 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
110 |
+
logger = get_root_logger()
|
111 |
+
self.logger = logger
|
112 |
+
self.progress_fn = lambda progress, desc: None
|
113 |
+
self.thickness = 2
|
114 |
+
self.blend_alpha = 0.0
|
115 |
+
|
116 |
+
self.latent_size = self.image_size // config.vae.vae_downsample_rate
|
117 |
+
self.max_sequence_length = config.text_encoder.model_max_length
|
118 |
+
self.flow_shift = config.scheduler.flow_shift
|
119 |
+
guidance_type = "classifier-free_PAG"
|
120 |
+
|
121 |
+
weight_dtype = get_weight_dtype(config.model.mixed_precision)
|
122 |
+
self.weight_dtype = weight_dtype
|
123 |
+
self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
|
124 |
+
|
125 |
+
self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
|
126 |
+
self.vis_sampler = self.config.scheduler.vis_sampler
|
127 |
+
logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
|
128 |
+
self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
|
129 |
+
logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
|
130 |
+
|
131 |
+
# 1. build vae and text encoder
|
132 |
+
self.vae = self.build_vae(config.vae)
|
133 |
+
self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
|
134 |
+
|
135 |
+
# 2. build Sana model
|
136 |
+
self.model = self.build_sana_model(config).to(self.device)
|
137 |
+
|
138 |
+
# 3. pre-compute null embedding
|
139 |
+
with torch.no_grad():
|
140 |
+
null_caption_token = self.tokenizer(
|
141 |
+
"", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
|
142 |
+
).to(self.device)
|
143 |
+
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
144 |
+
0
|
145 |
+
]
|
146 |
+
|
147 |
+
def build_vae(self, config):
|
148 |
+
vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
|
149 |
+
return vae
|
150 |
+
|
151 |
+
def build_text_encoder(self, config):
|
152 |
+
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
|
153 |
+
return tokenizer, text_encoder
|
154 |
+
|
155 |
+
def build_sana_model(self, config):
|
156 |
+
# model setting
|
157 |
+
model_kwargs = model_init_config(config, latent_size=self.latent_size)
|
158 |
+
model = build_model(
|
159 |
+
config.model.model,
|
160 |
+
use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
|
161 |
+
**model_kwargs,
|
162 |
+
)
|
163 |
+
self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
|
164 |
+
self.logger.info(
|
165 |
+
f"{model.__class__.__name__}:{config.model.model},"
|
166 |
+
f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
|
167 |
+
)
|
168 |
+
return model
|
169 |
+
|
170 |
+
def from_pretrained(self, model_path):
|
171 |
+
state_dict = find_model(model_path)
|
172 |
+
state_dict = state_dict.get("state_dict", state_dict)
|
173 |
+
if "pos_embed" in state_dict:
|
174 |
+
del state_dict["pos_embed"]
|
175 |
+
missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
|
176 |
+
self.model.eval().to(self.weight_dtype)
|
177 |
+
|
178 |
+
self.logger.info("Generating sample from ckpt: %s" % model_path)
|
179 |
+
self.logger.warning(f"Missing keys: {missing}")
|
180 |
+
self.logger.warning(f"Unexpected keys: {unexpected}")
|
181 |
+
|
182 |
+
def register_progress_bar(self, progress_fn=None):
|
183 |
+
self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
|
184 |
+
|
185 |
+
def set_blend_alpha(self, blend_alpha):
|
186 |
+
self.blend_alpha = blend_alpha
|
187 |
+
|
188 |
+
@torch.inference_mode()
|
189 |
+
def forward(
|
190 |
+
self,
|
191 |
+
prompt=None,
|
192 |
+
ref_image=None,
|
193 |
+
negative_prompt="",
|
194 |
+
num_inference_steps=20,
|
195 |
+
guidance_scale=5,
|
196 |
+
pag_guidance_scale=2.5,
|
197 |
+
num_images_per_prompt=1,
|
198 |
+
sketch_thickness=2,
|
199 |
+
generator=torch.Generator().manual_seed(42),
|
200 |
+
latents=None,
|
201 |
+
):
|
202 |
+
self.ori_height, self.ori_width = ref_image.height, ref_image.width
|
203 |
+
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
|
204 |
+
|
205 |
+
# 1. pre-compute negative embedding
|
206 |
+
if negative_prompt != "":
|
207 |
+
null_caption_token = self.tokenizer(
|
208 |
+
negative_prompt,
|
209 |
+
max_length=self.max_sequence_length,
|
210 |
+
padding="max_length",
|
211 |
+
truncation=True,
|
212 |
+
return_tensors="pt",
|
213 |
+
).to(self.device)
|
214 |
+
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
215 |
+
0
|
216 |
+
]
|
217 |
+
|
218 |
+
if prompt is None:
|
219 |
+
prompt = [""]
|
220 |
+
prompts = prompt if isinstance(prompt, list) else [prompt]
|
221 |
+
samples = []
|
222 |
+
|
223 |
+
for prompt in prompts:
|
224 |
+
# data prepare
|
225 |
+
prompts, hw, ar = (
|
226 |
+
[],
|
227 |
+
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
|
228 |
+
num_images_per_prompt, 1
|
229 |
+
),
|
230 |
+
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
|
231 |
+
)
|
232 |
+
|
233 |
+
ar = get_ar_from_ref_image(ref_image)
|
234 |
+
prompt += f" --ar {ar}"
|
235 |
+
for _ in range(num_images_per_prompt):
|
236 |
+
prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(
|
237 |
+
prompt, self.base_ratios, device=self.device, show=False
|
238 |
+
)
|
239 |
+
prompts.append(prompt_clean.strip())
|
240 |
+
|
241 |
+
self.latent_size_h, self.latent_size_w = (
|
242 |
+
int(hw[0, 0] // self.config.vae.vae_downsample_rate),
|
243 |
+
int(hw[0, 1] // self.config.vae.vae_downsample_rate),
|
244 |
+
)
|
245 |
+
|
246 |
+
with torch.no_grad():
|
247 |
+
# prepare text feature
|
248 |
+
if not self.config.text_encoder.chi_prompt:
|
249 |
+
max_length_all = self.config.text_encoder.model_max_length
|
250 |
+
prompts_all = prompts
|
251 |
+
else:
|
252 |
+
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
|
253 |
+
prompts_all = [chi_prompt + prompt for prompt in prompts]
|
254 |
+
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
255 |
+
max_length_all = (
|
256 |
+
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
|
257 |
+
) # magic number 2: [bos], [_]
|
258 |
+
|
259 |
+
caption_token = self.tokenizer(
|
260 |
+
prompts_all,
|
261 |
+
max_length=max_length_all,
|
262 |
+
padding="max_length",
|
263 |
+
truncation=True,
|
264 |
+
return_tensors="pt",
|
265 |
+
).to(device=self.device)
|
266 |
+
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
|
267 |
+
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
|
268 |
+
:, :, select_index
|
269 |
+
].to(self.weight_dtype)
|
270 |
+
emb_masks = caption_token.attention_mask[:, select_index]
|
271 |
+
null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
|
272 |
+
|
273 |
+
n = len(prompts)
|
274 |
+
if latents is None:
|
275 |
+
z = torch.randn(
|
276 |
+
n,
|
277 |
+
self.config.vae.vae_latent_dim,
|
278 |
+
self.latent_size_h,
|
279 |
+
self.latent_size_w,
|
280 |
+
generator=generator,
|
281 |
+
device=self.device,
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
z = latents.to(self.device)
|
285 |
+
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
|
286 |
+
|
287 |
+
# control signal
|
288 |
+
if isinstance(ref_image, str):
|
289 |
+
ref_image = cv2.imread(ref_image)
|
290 |
+
elif isinstance(ref_image, Image.Image):
|
291 |
+
ref_image = np.array(ref_image)
|
292 |
+
control_signal = get_scribble_map(
|
293 |
+
input_image=ref_image,
|
294 |
+
det="Scribble_HED",
|
295 |
+
detect_resolution=int(hw.min()),
|
296 |
+
thickness=sketch_thickness,
|
297 |
+
)
|
298 |
+
|
299 |
+
control_signal = transform_control_signal(control_signal, hw).to(self.device).to(self.weight_dtype)
|
300 |
+
|
301 |
+
control_signal_latent = vae_encode(
|
302 |
+
self.config.vae.vae_type, self.vae, control_signal, self.config.vae.sample_posterior, self.device
|
303 |
+
)
|
304 |
+
|
305 |
+
model_kwargs["control_signal"] = control_signal_latent
|
306 |
+
|
307 |
+
if self.vis_sampler == "flow_euler":
|
308 |
+
flow_solver = FlowEuler(
|
309 |
+
self.model,
|
310 |
+
condition=caption_embs,
|
311 |
+
uncondition=null_y,
|
312 |
+
cfg_scale=guidance_scale,
|
313 |
+
model_kwargs=model_kwargs,
|
314 |
+
)
|
315 |
+
sample = flow_solver.sample(
|
316 |
+
z,
|
317 |
+
steps=num_inference_steps,
|
318 |
+
)
|
319 |
+
elif self.vis_sampler == "flow_dpm-solver":
|
320 |
+
scheduler = DPMS(
|
321 |
+
self.model.forward_with_dpmsolver,
|
322 |
+
condition=caption_embs,
|
323 |
+
uncondition=null_y,
|
324 |
+
guidance_type=self.guidance_type,
|
325 |
+
cfg_scale=guidance_scale,
|
326 |
+
model_type="flow",
|
327 |
+
model_kwargs=model_kwargs,
|
328 |
+
schedule="FLOW",
|
329 |
+
)
|
330 |
+
scheduler.register_progress_bar(self.progress_fn)
|
331 |
+
sample = scheduler.sample(
|
332 |
+
z,
|
333 |
+
steps=num_inference_steps,
|
334 |
+
order=2,
|
335 |
+
skip_type="time_uniform_flow",
|
336 |
+
method="multistep",
|
337 |
+
flow_shift=self.flow_shift,
|
338 |
+
)
|
339 |
+
|
340 |
+
sample = sample.to(self.vae_dtype)
|
341 |
+
with torch.no_grad():
|
342 |
+
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
|
343 |
+
|
344 |
+
if self.blend_alpha > 0:
|
345 |
+
print(f"blend image and mask with alpha: {self.blend_alpha}")
|
346 |
+
sample = sample * (1 - self.blend_alpha) + control_signal * self.blend_alpha
|
347 |
+
|
348 |
+
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
|
349 |
+
samples.append(sample)
|
350 |
+
|
351 |
+
return sample
|
352 |
+
|
353 |
+
return samples
|
apps/sana_pipeline.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
import argparse
|
17 |
+
import warnings
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import Optional, Tuple
|
20 |
+
|
21 |
+
import pyrallis
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
warnings.filterwarnings("ignore") # ignore warning
|
26 |
+
|
27 |
+
|
28 |
+
from diffusion import DPMS, FlowEuler
|
29 |
+
from diffusion.data.datasets.utils import (
|
30 |
+
ASPECT_RATIO_512_TEST,
|
31 |
+
ASPECT_RATIO_1024_TEST,
|
32 |
+
ASPECT_RATIO_2048_TEST,
|
33 |
+
ASPECT_RATIO_4096_TEST,
|
34 |
+
)
|
35 |
+
from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
|
36 |
+
from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
|
37 |
+
from diffusion.utils.config import SanaConfig, model_init_config
|
38 |
+
from diffusion.utils.logger import get_root_logger
|
39 |
+
|
40 |
+
# from diffusion.utils.misc import read_config
|
41 |
+
from tools.download import find_model
|
42 |
+
|
43 |
+
|
44 |
+
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
|
45 |
+
guidance_type = default_guidance_type
|
46 |
+
if not (pag_scale > 1.0 and attn_type == "linear"):
|
47 |
+
guidance_type = "classifier-free"
|
48 |
+
elif pag_scale > 1.0 and attn_type == "linear":
|
49 |
+
guidance_type = "classifier-free_PAG"
|
50 |
+
return guidance_type
|
51 |
+
|
52 |
+
|
53 |
+
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
54 |
+
"""Returns binned height and width."""
|
55 |
+
ar = float(height / width)
|
56 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
57 |
+
default_hw = ratios[closest_ratio]
|
58 |
+
return int(default_hw[0]), int(default_hw[1])
|
59 |
+
|
60 |
+
|
61 |
+
@dataclass
|
62 |
+
class SanaInference(SanaConfig):
|
63 |
+
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
|
64 |
+
model_path: str = field(
|
65 |
+
default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
|
66 |
+
)
|
67 |
+
output: str = "./output"
|
68 |
+
bs: int = 1
|
69 |
+
image_size: int = 1024
|
70 |
+
cfg_scale: float = 5.0
|
71 |
+
pag_scale: float = 2.0
|
72 |
+
seed: int = 42
|
73 |
+
step: int = -1
|
74 |
+
custom_image_size: Optional[int] = None
|
75 |
+
shield_model_path: str = field(
|
76 |
+
default="google/shieldgemma-2b",
|
77 |
+
metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
class SanaPipeline(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
|
85 |
+
):
|
86 |
+
super().__init__()
|
87 |
+
config = pyrallis.load(SanaInference, open(config))
|
88 |
+
self.args = self.config = config
|
89 |
+
|
90 |
+
# set some hyper-parameters
|
91 |
+
self.image_size = self.config.model.image_size
|
92 |
+
|
93 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
94 |
+
logger = get_root_logger()
|
95 |
+
self.logger = logger
|
96 |
+
self.progress_fn = lambda progress, desc: None
|
97 |
+
|
98 |
+
self.latent_size = self.image_size // config.vae.vae_downsample_rate
|
99 |
+
self.max_sequence_length = config.text_encoder.model_max_length
|
100 |
+
self.flow_shift = config.scheduler.flow_shift
|
101 |
+
guidance_type = "classifier-free_PAG"
|
102 |
+
|
103 |
+
weight_dtype = get_weight_dtype(config.model.mixed_precision)
|
104 |
+
self.weight_dtype = weight_dtype
|
105 |
+
self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
|
106 |
+
|
107 |
+
self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
|
108 |
+
self.vis_sampler = self.config.scheduler.vis_sampler
|
109 |
+
logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
|
110 |
+
self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
|
111 |
+
logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
|
112 |
+
|
113 |
+
# 1. build vae and text encoder
|
114 |
+
self.vae = self.build_vae(config.vae)
|
115 |
+
self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
|
116 |
+
|
117 |
+
# 2. build Sana model
|
118 |
+
self.model = self.build_sana_model(config).to(self.device)
|
119 |
+
|
120 |
+
# 3. pre-compute null embedding
|
121 |
+
with torch.no_grad():
|
122 |
+
null_caption_token = self.tokenizer(
|
123 |
+
"", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
|
124 |
+
).to(self.device)
|
125 |
+
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
126 |
+
0
|
127 |
+
]
|
128 |
+
|
129 |
+
def build_vae(self, config):
|
130 |
+
vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
|
131 |
+
return vae
|
132 |
+
|
133 |
+
def build_text_encoder(self, config):
|
134 |
+
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
|
135 |
+
return tokenizer, text_encoder
|
136 |
+
|
137 |
+
def build_sana_model(self, config):
|
138 |
+
# model setting
|
139 |
+
model_kwargs = model_init_config(config, latent_size=self.latent_size)
|
140 |
+
model = build_model(
|
141 |
+
config.model.model,
|
142 |
+
use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
|
143 |
+
**model_kwargs,
|
144 |
+
)
|
145 |
+
self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
|
146 |
+
self.logger.info(
|
147 |
+
f"{model.__class__.__name__}:{config.model.model},"
|
148 |
+
f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
|
149 |
+
)
|
150 |
+
return model
|
151 |
+
|
152 |
+
def from_pretrained(self, model_path):
|
153 |
+
state_dict = find_model(model_path)
|
154 |
+
state_dict = state_dict.get("state_dict", state_dict)
|
155 |
+
if "pos_embed" in state_dict:
|
156 |
+
del state_dict["pos_embed"]
|
157 |
+
missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
|
158 |
+
self.model.eval().to(self.weight_dtype)
|
159 |
+
|
160 |
+
self.logger.info("Generating sample from ckpt: %s" % model_path)
|
161 |
+
self.logger.warning(f"Missing keys: {missing}")
|
162 |
+
self.logger.warning(f"Unexpected keys: {unexpected}")
|
163 |
+
|
164 |
+
def register_progress_bar(self, progress_fn=None):
|
165 |
+
self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
|
166 |
+
|
167 |
+
@torch.inference_mode()
|
168 |
+
def forward(
|
169 |
+
self,
|
170 |
+
prompt=None,
|
171 |
+
height=1024,
|
172 |
+
width=1024,
|
173 |
+
negative_prompt="",
|
174 |
+
num_inference_steps=20,
|
175 |
+
guidance_scale=5,
|
176 |
+
pag_guidance_scale=2.5,
|
177 |
+
num_images_per_prompt=1,
|
178 |
+
generator=torch.Generator().manual_seed(42),
|
179 |
+
latents=None,
|
180 |
+
):
|
181 |
+
self.ori_height, self.ori_width = height, width
|
182 |
+
self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
|
183 |
+
self.latent_size_h, self.latent_size_w = (
|
184 |
+
self.height // self.config.vae.vae_downsample_rate,
|
185 |
+
self.width // self.config.vae.vae_downsample_rate,
|
186 |
+
)
|
187 |
+
self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
|
188 |
+
|
189 |
+
# 1. pre-compute negative embedding
|
190 |
+
if negative_prompt != "":
|
191 |
+
null_caption_token = self.tokenizer(
|
192 |
+
negative_prompt,
|
193 |
+
max_length=self.max_sequence_length,
|
194 |
+
padding="max_length",
|
195 |
+
truncation=True,
|
196 |
+
return_tensors="pt",
|
197 |
+
).to(self.device)
|
198 |
+
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
|
199 |
+
0
|
200 |
+
]
|
201 |
+
|
202 |
+
if prompt is None:
|
203 |
+
prompt = [""]
|
204 |
+
prompts = prompt if isinstance(prompt, list) else [prompt]
|
205 |
+
samples = []
|
206 |
+
|
207 |
+
for prompt in prompts:
|
208 |
+
# data prepare
|
209 |
+
prompts, hw, ar = (
|
210 |
+
[],
|
211 |
+
torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
|
212 |
+
num_images_per_prompt, 1
|
213 |
+
),
|
214 |
+
torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
|
215 |
+
)
|
216 |
+
|
217 |
+
for _ in range(num_images_per_prompt):
|
218 |
+
prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
|
219 |
+
|
220 |
+
with torch.no_grad():
|
221 |
+
# prepare text feature
|
222 |
+
if not self.config.text_encoder.chi_prompt:
|
223 |
+
max_length_all = self.config.text_encoder.model_max_length
|
224 |
+
prompts_all = prompts
|
225 |
+
else:
|
226 |
+
chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
|
227 |
+
prompts_all = [chi_prompt + prompt for prompt in prompts]
|
228 |
+
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
229 |
+
max_length_all = (
|
230 |
+
num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
|
231 |
+
) # magic number 2: [bos], [_]
|
232 |
+
|
233 |
+
caption_token = self.tokenizer(
|
234 |
+
prompts_all,
|
235 |
+
max_length=max_length_all,
|
236 |
+
padding="max_length",
|
237 |
+
truncation=True,
|
238 |
+
return_tensors="pt",
|
239 |
+
).to(device=self.device)
|
240 |
+
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
|
241 |
+
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
|
242 |
+
:, :, select_index
|
243 |
+
].to(self.weight_dtype)
|
244 |
+
emb_masks = caption_token.attention_mask[:, select_index]
|
245 |
+
null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
|
246 |
+
|
247 |
+
n = len(prompts)
|
248 |
+
if latents is None:
|
249 |
+
z = torch.randn(
|
250 |
+
n,
|
251 |
+
self.config.vae.vae_latent_dim,
|
252 |
+
self.latent_size_h,
|
253 |
+
self.latent_size_w,
|
254 |
+
generator=generator,
|
255 |
+
device=self.device,
|
256 |
+
)
|
257 |
+
else:
|
258 |
+
z = latents.to(self.device)
|
259 |
+
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
|
260 |
+
if self.vis_sampler == "flow_euler":
|
261 |
+
flow_solver = FlowEuler(
|
262 |
+
self.model,
|
263 |
+
condition=caption_embs,
|
264 |
+
uncondition=null_y,
|
265 |
+
cfg_scale=guidance_scale,
|
266 |
+
model_kwargs=model_kwargs,
|
267 |
+
)
|
268 |
+
sample = flow_solver.sample(
|
269 |
+
z,
|
270 |
+
steps=num_inference_steps,
|
271 |
+
)
|
272 |
+
elif self.vis_sampler == "flow_dpm-solver":
|
273 |
+
scheduler = DPMS(
|
274 |
+
self.model,
|
275 |
+
condition=caption_embs,
|
276 |
+
uncondition=null_y,
|
277 |
+
guidance_type=self.guidance_type,
|
278 |
+
cfg_scale=guidance_scale,
|
279 |
+
pag_scale=pag_guidance_scale,
|
280 |
+
pag_applied_layers=self.config.model.pag_applied_layers,
|
281 |
+
model_type="flow",
|
282 |
+
model_kwargs=model_kwargs,
|
283 |
+
schedule="FLOW",
|
284 |
+
)
|
285 |
+
scheduler.register_progress_bar(self.progress_fn)
|
286 |
+
sample = scheduler.sample(
|
287 |
+
z,
|
288 |
+
steps=num_inference_steps,
|
289 |
+
order=2,
|
290 |
+
skip_type="time_uniform_flow",
|
291 |
+
method="multistep",
|
292 |
+
flow_shift=self.flow_shift,
|
293 |
+
)
|
294 |
+
|
295 |
+
sample = sample.to(self.vae_dtype)
|
296 |
+
with torch.no_grad():
|
297 |
+
sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
|
298 |
+
|
299 |
+
sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
|
300 |
+
samples.append(sample)
|
301 |
+
|
302 |
+
return sample
|
303 |
+
|
304 |
+
return samples
|