zzc0208 committed on
Commit 18ce78c · verified · 1 Parent(s): 0bc6b41

Upload 8 files
apps/app_sana.py ADDED
@@ -0,0 +1,502 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # SPDX-License-Identifier: Apache-2.0
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import os
21
+ import random
22
+ import socket
23
+ import sqlite3
24
+ import time
25
+ import uuid
26
+ from datetime import datetime
27
+
28
+ import gradio as gr
29
+ import numpy as np
30
+ import spaces
31
+ import torch
32
+ from PIL import Image
33
+ from torchvision.utils import make_grid, save_image
34
+ from transformers import AutoModelForCausalLM, AutoTokenizer
35
+
36
+ from app import safety_check
37
+ from app.sana_pipeline import SanaPipeline
38
+
39
+ MAX_SEED = np.iinfo(np.int32).max
40
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
41
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
42
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
43
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
44
+ DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
45
+ os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
46
+ COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
47
+
48
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
+
50
+ style_list = [
51
+ {
52
+ "name": "(No style)",
53
+ "prompt": "{prompt}",
54
+ "negative_prompt": "",
55
+ },
56
+ {
57
+ "name": "Cinematic",
58
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
59
+ "cinemascope, moody, epic, gorgeous, film grain, grainy",
60
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
61
+ },
62
+ {
63
+ "name": "Photographic",
64
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
65
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
66
+ },
67
+ {
68
+ "name": "Anime",
69
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
70
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
71
+ },
72
+ {
73
+ "name": "Manga",
74
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
75
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
76
+ },
77
+ {
78
+ "name": "Digital Art",
79
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
80
+ "negative_prompt": "photo, photorealistic, realism, ugly",
81
+ },
82
+ {
83
+ "name": "Pixel art",
84
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
85
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
86
+ },
87
+ {
88
+ "name": "Fantasy art",
89
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
90
+ "majestic, magical, fantasy art, cover art, dreamy",
91
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
92
+ "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
93
+ "disfigured, sloppy, duplicate, mutated, black and white",
94
+ },
95
+ {
96
+ "name": "Neonpunk",
97
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
98
+ "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
99
+ "ultra detailed, intricate, professional",
100
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
101
+ },
102
+ {
103
+ "name": "3D Model",
104
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
105
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
106
+ },
107
+ ]
108
+
109
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
110
+ STYLE_NAMES = list(styles.keys())
111
+ DEFAULT_STYLE_NAME = "(No style)"
112
+ SCHEDULE_NAME = ["Flow_DPM_Solver"]
113
+ DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
114
+ NUM_IMAGES_PER_PROMPT = 1
115
+ INFER_SPEED = 0
116
+
117
+
118
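+ # Clamp a tensor to [low, high] and rescale it in place to [0, 1] before converting to a display image.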
+ def norm_ip(img, low, high):
119
+ img.clamp_(min=low, max=high)
120
+ img.sub_(low).div_(max(high - low, 1e-5))
121
+ return img
122
+
123
+
124
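+ # SQLite-backed counter (COUNTER_DB, ".count.db" by default) that persists the total number of inference runs shown in the demo header.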
+ def open_db():
125
+ db = sqlite3.connect(COUNTER_DB)
126
+ db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
127
+ db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
128
+ return db
129
+
130
+
131
+ def read_inference_count():
132
+ with open_db() as db:
133
+ cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
134
+ db.commit()
135
+ return cur.fetchone()[0]
136
+
137
+
138
+ def write_inference_count(count):
139
+ count = max(0, int(count))
140
+ with open_db() as db:
141
+ db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
142
+ db.commit()
143
+
144
+
145
+ def run_inference(num_imgs=1):
146
+ write_inference_count(num_imgs)
147
+ count = read_inference_count()
148
+
149
+ return (
150
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
151
+ f"16px; color:red; font-weight: bold;'>{count}</span>"
152
+ )
153
+
154
+
155
+ def update_inference_count():
156
+ count = read_inference_count()
157
+ return (
158
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
159
+ f"16px; color:red; font-weight: bold;'>{count}</span>"
160
+ )
161
+
162
+
163
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
164
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
165
+ if not negative:
166
+ negative = ""
167
+ return p.replace("{prompt}", positive), n + negative
168
+
169
+
170
+ def get_args():
171
+ parser = argparse.ArgumentParser()
172
+ parser.add_argument("--config", type=str, help="config")
173
+ parser.add_argument(
174
+ "--model_path",
175
+ nargs="?",
176
+ default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
177
+ type=str,
178
+ help="Path to the model file (positional)",
179
+ )
180
+ parser.add_argument("--output", default="./", type=str)
181
+ parser.add_argument("--bs", default=1, type=int)
182
+ parser.add_argument("--image_size", default=1024, type=int)
183
+ parser.add_argument("--cfg_scale", default=5.0, type=float)
184
+ parser.add_argument("--pag_scale", default=2.0, type=float)
185
+ parser.add_argument("--seed", default=42, type=int)
186
+ parser.add_argument("--step", default=-1, type=int)
187
+ parser.add_argument("--custom_image_size", default=None, type=int)
188
+ parser.add_argument("--share", action="store_true")
189
+ parser.add_argument(
190
+ "--shield_model_path",
191
+ type=str,
192
+ help="The path to shield model, we employ ShieldGemma-2B by default.",
193
+ default="google/shieldgemma-2b",
194
+ )
195
+
196
+ return parser.parse_known_args()[0]
197
+
198
+
199
+ args = get_args()
200
+
201
+ if torch.cuda.is_available():
202
+ model_path = args.model_path
203
+ pipe = SanaPipeline(args.config)
204
+ pipe.from_pretrained(model_path)
205
+ pipe.register_progress_bar(gr.Progress())
206
+
207
+ # safety checker
208
+ safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
209
+ safety_checker_model = AutoModelForCausalLM.from_pretrained(
210
+ args.shield_model_path,
211
+ device_map="auto",
212
+ torch_dtype=torch.bfloat16,
213
+ ).to(device)
214
+
215
+
216
+ def save_image_sana(img, seed="", save_img=False):
217
+ unique_name = f"{str(uuid.uuid4())}_{seed}.png"
218
+ save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
219
+ os.umask(0o000) # file permission: 666; dir permission: 777
220
+ os.makedirs(save_path, exist_ok=True)
221
+ unique_name = os.path.join(save_path, unique_name)
222
+ if save_img:
223
+ save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
224
+
225
+ return unique_name
226
+
227
+
228
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
229
+ if randomize_seed:
230
+ seed = random.randint(0, MAX_SEED)
231
+ return seed
232
+
233
+
234
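+ # Main text-to-image entry point: bumps the run counter, screens the prompt with the ShieldGemma safety checker, applies the selected style template, samples with Flow-DPM-Solver, and returns the images plus seed, speed, and counter markup.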
+ @torch.no_grad()
235
+ @torch.inference_mode()
236
+ @spaces.GPU(enable_queue=True)
237
+ def generate(
238
+ prompt: str = None,
239
+ negative_prompt: str = "",
240
+ style: str = DEFAULT_STYLE_NAME,
241
+ use_negative_prompt: bool = False,
242
+ num_imgs: int = 1,
243
+ seed: int = 0,
244
+ height: int = 1024,
245
+ width: int = 1024,
246
+ flow_dpms_guidance_scale: float = 5.0,
247
+ flow_dpms_pag_guidance_scale: float = 2.0,
248
+ flow_dpms_inference_steps: int = 20,
249
+ randomize_seed: bool = False,
250
+ ):
251
+ global INFER_SPEED
252
+ # seed = 823753551
253
+ box = run_inference(num_imgs)
254
+ seed = int(randomize_seed_fn(seed, randomize_seed))
255
+ generator = torch.Generator(device=device).manual_seed(seed)
256
+ print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
257
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
258
+ prompt = "A red heart."
259
+
260
+ print(prompt)
261
+
262
+ num_inference_steps = flow_dpms_inference_steps
263
+ guidance_scale = flow_dpms_guidance_scale
264
+ pag_guidance_scale = flow_dpms_pag_guidance_scale
265
+
266
+ if not use_negative_prompt:
267
+ negative_prompt = None # type: ignore
268
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
269
+
270
+ pipe.progress_fn(0, desc="Sana Start")
271
+
272
+ time_start = time.time()
273
+ images = pipe(
274
+ prompt=prompt,
275
+ height=height,
276
+ width=width,
277
+ negative_prompt=negative_prompt,
278
+ guidance_scale=guidance_scale,
279
+ pag_guidance_scale=pag_guidance_scale,
280
+ num_inference_steps=num_inference_steps,
281
+ num_images_per_prompt=num_imgs,
282
+ generator=generator,
283
+ )
284
+
285
+ pipe.progress_fn(1.0, desc="Sana End")
286
+ INFER_SPEED = (time.time() - time_start) / num_imgs
287
+
288
+ save_img = False
289
+ if save_img:
290
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
291
+ print(img)
292
+ else:
293
+ img = [
294
+ Image.fromarray(
295
+ norm_ip(img, -1, 1)
296
+ .mul(255)
297
+ .add_(0.5)
298
+ .clamp_(0, 255)
299
+ .permute(1, 2, 0)
300
+ .to("cpu", torch.uint8)
301
+ .numpy()
302
+ .astype(np.uint8)
303
+ )
304
+ for img in images
305
+ ]
306
+
307
+ torch.cuda.empty_cache()
308
+
309
+ return (
310
+ img,
311
+ seed,
312
+ f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
313
+ box,
314
+ )
315
+
316
+
317
+ model_size = "1.6" if "1600M" in args.model_path else "0.6"
318
+ title = f"""
319
+ <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
320
+ <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
321
+ </div>
322
+ """
323
+ DESCRIPTION = f"""
324
+ <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
325
+ <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
326
+ <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
327
+ <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, running on node {socket.gethostname()}.</p>
328
+ <p style="font-size: 16px; font-weight: bold;">An unsafe prompt will give you a 'Red Heart' image instead.</p>
329
+ """
330
+ if model_size == "0.6":
331
+ DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
332
+ if not torch.cuda.is_available():
333
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
334
+
335
+ examples = [
336
+ 'a cyberpunk cat with a neon sign that says "Sana"',
337
+ "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
338
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
339
+ "portrait photo of a girl, photograph, highly detailed face, depth of field",
340
+ 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
341
+ "🐶 Wearing 🕶 flying on the 🌈",
342
+ "👧 with 🌹 in the ❄️",
343
+ "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
344
+ "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
345
+ "Astronaut in a jungle, cold color palette, muted colors, detailed",
346
+ "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
347
+ ]
348
+
349
+ css = """
350
+ .gradio-container{max-width: 640px !important}
351
+ h1{text-align:center}
352
+ """
353
+ with gr.Blocks(css=css, title="Sana") as demo:
354
+ gr.Markdown(title)
355
+ gr.HTML(DESCRIPTION)
356
+ gr.DuplicateButton(
357
+ value="Duplicate Space for private use",
358
+ elem_id="duplicate-button",
359
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
360
+ )
361
+ info_box = gr.Markdown(
362
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
363
+ )
364
+ demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
365
+ # with gr.Row(equal_height=False):
366
+ with gr.Group():
367
+ with gr.Row():
368
+ prompt = gr.Text(
369
+ label="Prompt",
370
+ show_label=False,
371
+ max_lines=1,
372
+ placeholder="Enter your prompt",
373
+ container=False,
374
+ )
375
+ run_button = gr.Button("Run", scale=0)
376
+ result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
377
+ speed_box = gr.Markdown(
378
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
379
+ )
380
+ with gr.Accordion("Advanced options", open=False):
381
+ with gr.Group():
382
+ with gr.Row(visible=True):
383
+ height = gr.Slider(
384
+ label="Height",
385
+ minimum=256,
386
+ maximum=MAX_IMAGE_SIZE,
387
+ step=32,
388
+ value=args.image_size,
389
+ )
390
+ width = gr.Slider(
391
+ label="Width",
392
+ minimum=256,
393
+ maximum=MAX_IMAGE_SIZE,
394
+ step=32,
395
+ value=args.image_size,
396
+ )
397
+ with gr.Row():
398
+ flow_dpms_inference_steps = gr.Slider(
399
+ label="Sampling steps",
400
+ minimum=5,
401
+ maximum=40,
402
+ step=1,
403
+ value=20,
404
+ )
405
+ flow_dpms_guidance_scale = gr.Slider(
406
+ label="CFG Guidance scale",
407
+ minimum=1,
408
+ maximum=10,
409
+ step=0.1,
410
+ value=4.5,
411
+ )
412
+ flow_dpms_pag_guidance_scale = gr.Slider(
413
+ label="PAG Guidance scale",
414
+ minimum=1,
415
+ maximum=4,
416
+ step=0.5,
417
+ value=1.0,
418
+ )
419
+ with gr.Row():
420
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
421
+ negative_prompt = gr.Text(
422
+ label="Negative prompt",
423
+ max_lines=1,
424
+ placeholder="Enter a negative prompt",
425
+ visible=True,
426
+ )
427
+ style_selection = gr.Radio(
428
+ show_label=True,
429
+ container=True,
430
+ interactive=True,
431
+ choices=STYLE_NAMES,
432
+ value=DEFAULT_STYLE_NAME,
433
+ label="Image Style",
434
+ )
435
+ seed = gr.Slider(
436
+ label="Seed",
437
+ minimum=0,
438
+ maximum=MAX_SEED,
439
+ step=1,
440
+ value=0,
441
+ )
442
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
443
+ with gr.Row(visible=True):
444
+ schedule = gr.Radio(
445
+ show_label=True,
446
+ container=True,
447
+ interactive=True,
448
+ choices=SCHEDULE_NAME,
449
+ value=DEFAULT_SCHEDULE_NAME,
450
+ label="Sampler Schedule",
451
+ visible=True,
452
+ )
453
+ num_imgs = gr.Slider(
454
+ label="Num Images",
455
+ minimum=1,
456
+ maximum=6,
457
+ step=1,
458
+ value=1,
459
+ )
460
+
461
+ gr.Examples(
462
+ examples=examples,
463
+ inputs=prompt,
464
+ outputs=[result, seed],
465
+ fn=generate,
466
+ cache_examples=CACHE_EXAMPLES,
467
+ )
468
+
469
+ use_negative_prompt.change(
470
+ fn=lambda x: gr.update(visible=x),
471
+ inputs=use_negative_prompt,
472
+ outputs=negative_prompt,
473
+ api_name=False,
474
+ )
475
+
476
+ gr.on(
477
+ triggers=[
478
+ prompt.submit,
479
+ negative_prompt.submit,
480
+ run_button.click,
481
+ ],
482
+ fn=generate,
483
+ inputs=[
484
+ prompt,
485
+ negative_prompt,
486
+ style_selection,
487
+ use_negative_prompt,
488
+ num_imgs,
489
+ seed,
490
+ height,
491
+ width,
492
+ flow_dpms_guidance_scale,
493
+ flow_dpms_pag_guidance_scale,
494
+ flow_dpms_inference_steps,
495
+ randomize_seed,
496
+ ],
497
+ outputs=[result, seed, speed_box, info_box],
498
+ api_name="run",
499
+ )
500
+
501
+ if __name__ == "__main__":
502
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
apps/app_sana_4bit.py ADDED
@@ -0,0 +1,409 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ #
20
+ # SPDX-License-Identifier: Apache-2.0
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import os
25
+ import random
26
+ import time
27
+ import uuid
28
+ from datetime import datetime
29
+
30
+ import gradio as gr
31
+ import numpy as np
32
+ import spaces
33
+ import torch
34
+ from diffusers import SanaPipeline
35
+ from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
36
+ from torchvision.utils import save_image
37
+
38
+ MAX_SEED = np.iinfo(np.int32).max
39
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
40
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
41
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
42
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
43
+ DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
44
+ os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
45
+ COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
46
+ INFER_SPEED = 0
47
+
48
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
+
50
+ style_list = [
51
+ {
52
+ "name": "(No style)",
53
+ "prompt": "{prompt}",
54
+ "negative_prompt": "",
55
+ },
56
+ {
57
+ "name": "Cinematic",
58
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
59
+ "cinemascope, moody, epic, gorgeous, film grain, grainy",
60
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
61
+ },
62
+ {
63
+ "name": "Photographic",
64
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
65
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
66
+ },
67
+ {
68
+ "name": "Anime",
69
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
70
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
71
+ },
72
+ {
73
+ "name": "Manga",
74
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
75
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
76
+ },
77
+ {
78
+ "name": "Digital Art",
79
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
80
+ "negative_prompt": "photo, photorealistic, realism, ugly",
81
+ },
82
+ {
83
+ "name": "Pixel art",
84
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
85
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
86
+ },
87
+ {
88
+ "name": "Fantasy art",
89
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
90
+ "majestic, magical, fantasy art, cover art, dreamy",
91
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
92
+ "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
93
+ "disfigured, sloppy, duplicate, mutated, black and white",
94
+ },
95
+ {
96
+ "name": "Neonpunk",
97
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
98
+ "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
99
+ "ultra detailed, intricate, professional",
100
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
101
+ },
102
+ {
103
+ "name": "3D Model",
104
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
105
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
106
+ },
107
+ ]
108
+
109
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
110
+ STYLE_NAMES = list(styles.keys())
111
+ DEFAULT_STYLE_NAME = "(No style)"
112
+ SCHEDULE_NAME = ["Flow_DPM_Solver"]
113
+ DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
114
+ NUM_IMAGES_PER_PROMPT = 1
115
+
116
+
117
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
118
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
119
+ if not negative:
120
+ negative = ""
121
+ return p.replace("{prompt}", positive), n + negative
122
+
123
+
124
+ def get_args():
125
+ parser = argparse.ArgumentParser()
126
+ parser.add_argument(
127
+ "--model_path",
128
+ nargs="?",
129
+ default="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
130
+ type=str,
131
+ help="Path to the model file (positional)",
132
+ )
133
+ parser.add_argument("--share", action="store_true")
134
+
135
+ return parser.parse_known_args()[0]
136
+
137
+
138
+ args = get_args()
139
+
140
+ if torch.cuda.is_available():
141
+
142
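+ # Swap the SVDQuant INT4 Sana transformer from nunchaku into the BF16 diffusers SanaPipeline; the text encoder and VAE stay in BF16.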
+ transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
143
+ pipe = SanaPipeline.from_pretrained(
144
+ "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
145
+ transformer=transformer,
146
+ variant="bf16",
147
+ torch_dtype=torch.bfloat16,
148
+ ).to(device)
149
+
150
+ pipe.text_encoder.to(torch.bfloat16)
151
+ pipe.vae.to(torch.bfloat16)
152
+
153
+
154
+ def save_image_sana(img, seed="", save_img=False):
155
+ unique_name = f"{str(uuid.uuid4())}_{seed}.png"
156
+ save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
157
+ os.umask(0o000) # file permission: 666; dir permission: 777
158
+ os.makedirs(save_path, exist_ok=True)
159
+ unique_name = os.path.join(save_path, unique_name)
160
+ if save_img:
161
+ save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
162
+
163
+ return unique_name
164
+
165
+
166
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
167
+ if randomize_seed:
168
+ seed = random.randint(0, MAX_SEED)
169
+ return seed
170
+
171
+
172
+ @torch.no_grad()
173
+ @torch.inference_mode()
174
+ @spaces.GPU(enable_queue=True)
175
+ def generate(
176
+ prompt: str = None,
177
+ negative_prompt: str = "",
178
+ style: str = DEFAULT_STYLE_NAME,
179
+ use_negative_prompt: bool = False,
180
+ num_imgs: int = 1,
181
+ seed: int = 0,
182
+ height: int = 1024,
183
+ width: int = 1024,
184
+ flow_dpms_guidance_scale: float = 5.0,
185
+ flow_dpms_inference_steps: int = 20,
186
+ randomize_seed: bool = False,
187
+ ):
188
+ global INFER_SPEED
189
+ # seed = 823753551
190
+ seed = int(randomize_seed_fn(seed, randomize_seed))
191
+ generator = torch.Generator(device=device).manual_seed(seed)
192
+ print(f"PORT: {DEMO_PORT}, model_path: {args.model_path}")
193
+
194
+ print(prompt)
195
+
196
+ num_inference_steps = flow_dpms_inference_steps
197
+ guidance_scale = flow_dpms_guidance_scale
198
+
199
+ if not use_negative_prompt:
200
+ negative_prompt = None # type: ignore
201
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
202
+
203
+ time_start = time.time()
204
+ images = pipe(
205
+ prompt=prompt,
206
+ height=height,
207
+ width=width,
208
+ negative_prompt=negative_prompt,
209
+ guidance_scale=guidance_scale,
210
+ num_inference_steps=num_inference_steps,
211
+ num_images_per_prompt=num_imgs,
212
+ generator=generator,
213
+ ).images
214
+ INFER_SPEED = (time.time() - time_start) / num_imgs
215
+
216
+ save_img = False
217
+ if save_img:
218
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
219
+ print(img)
220
+ else:
221
+ img = images
222
+
223
+ torch.cuda.empty_cache()
224
+
225
+ return (
226
+ img,
227
+ seed,
228
+ f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
229
+ )
230
+
231
+
232
+ model_size = "1.6" if "1600M" in args.model_path else "0.6"
233
+ title = f"""
234
+ <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
235
+ <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="30%" alt="logo"/>
236
+ </div>
237
+ """
238
+ DESCRIPTION = f"""
239
+ <p style="font-size: 30px; font-weight: bold; text-align: center;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer (4bit version)</p>
240
+ """
241
+ if model_size == "0.6":
242
+ DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
243
+ if not torch.cuda.is_available():
244
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
245
+
246
+ examples = [
247
+ 'a cyberpunk cat with a neon sign that says "Sana"',
248
+ "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
249
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
250
+ "portrait photo of a girl, photograph, highly detailed face, depth of field",
251
+ 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
252
+ "🐶 Wearing 🕶 flying on the 🌈",
253
+ "👧 with 🌹 in the ❄️",
254
+ "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
255
+ "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
256
+ "Astronaut in a jungle, cold color palette, muted colors, detailed",
257
+ "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
258
+ ]
259
+
260
+ css = """
261
+ .gradio-container {max-width: 850px !important; height: auto !important;}
262
+ h1 {text-align: center;}
263
+ """
264
+ theme = gr.themes.Base()
265
+ with gr.Blocks(css=css, theme=theme, title="Sana") as demo:
266
+ gr.Markdown(title)
267
+ gr.HTML(DESCRIPTION)
268
+ gr.DuplicateButton(
269
+ value="Duplicate Space for private use",
270
+ elem_id="duplicate-button",
271
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
272
+ )
273
+ # with gr.Row(equal_height=False):
274
+ with gr.Group():
275
+ with gr.Row():
276
+ prompt = gr.Text(
277
+ label="Prompt",
278
+ show_label=False,
279
+ max_lines=1,
280
+ placeholder="Enter your prompt",
281
+ container=False,
282
+ )
283
+ run_button = gr.Button("Run", scale=0)
284
+ result = gr.Gallery(
285
+ label="Result",
286
+ show_label=False,
287
+ height=750,
288
+ columns=NUM_IMAGES_PER_PROMPT,
289
+ format="jpeg",
290
+ )
291
+
292
+ speed_box = gr.Markdown(
293
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
294
+ )
295
+ with gr.Accordion("Advanced options", open=False):
296
+ with gr.Group():
297
+ with gr.Row(visible=True):
298
+ height = gr.Slider(
299
+ label="Height",
300
+ minimum=256,
301
+ maximum=MAX_IMAGE_SIZE,
302
+ step=32,
303
+ value=1024,
304
+ )
305
+ width = gr.Slider(
306
+ label="Width",
307
+ minimum=256,
308
+ maximum=MAX_IMAGE_SIZE,
309
+ step=32,
310
+ value=1024,
311
+ )
312
+ with gr.Row():
313
+ flow_dpms_inference_steps = gr.Slider(
314
+ label="Sampling steps",
315
+ minimum=5,
316
+ maximum=40,
317
+ step=1,
318
+ value=20,
319
+ )
320
+ flow_dpms_guidance_scale = gr.Slider(
321
+ label="CFG Guidance scale",
322
+ minimum=1,
323
+ maximum=10,
324
+ step=0.1,
325
+ value=4.5,
326
+ )
327
+ with gr.Row():
328
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
329
+ negative_prompt = gr.Text(
330
+ label="Negative prompt",
331
+ max_lines=1,
332
+ placeholder="Enter a negative prompt",
333
+ visible=True,
334
+ )
335
+ style_selection = gr.Radio(
336
+ show_label=True,
337
+ container=True,
338
+ interactive=True,
339
+ choices=STYLE_NAMES,
340
+ value=DEFAULT_STYLE_NAME,
341
+ label="Image Style",
342
+ )
343
+ seed = gr.Slider(
344
+ label="Seed",
345
+ minimum=0,
346
+ maximum=MAX_SEED,
347
+ step=1,
348
+ value=0,
349
+ )
350
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
351
+ with gr.Row(visible=True):
352
+ schedule = gr.Radio(
353
+ show_label=True,
354
+ container=True,
355
+ interactive=True,
356
+ choices=SCHEDULE_NAME,
357
+ value=DEFAULT_SCHEDULE_NAME,
358
+ label="Sampler Schedule",
359
+ visible=True,
360
+ )
361
+ num_imgs = gr.Slider(
362
+ label="Num Images",
363
+ minimum=1,
364
+ maximum=6,
365
+ step=1,
366
+ value=1,
367
+ )
368
+
369
+ gr.Examples(
370
+ examples=examples,
371
+ inputs=prompt,
372
+ outputs=[result, seed],
373
+ fn=generate,
374
+ cache_examples=CACHE_EXAMPLES,
375
+ )
376
+
377
+ use_negative_prompt.change(
378
+ fn=lambda x: gr.update(visible=x),
379
+ inputs=use_negative_prompt,
380
+ outputs=negative_prompt,
381
+ api_name=False,
382
+ )
383
+
384
+ gr.on(
385
+ triggers=[
386
+ prompt.submit,
387
+ negative_prompt.submit,
388
+ run_button.click,
389
+ ],
390
+ fn=generate,
391
+ inputs=[
392
+ prompt,
393
+ negative_prompt,
394
+ style_selection,
395
+ use_negative_prompt,
396
+ num_imgs,
397
+ seed,
398
+ height,
399
+ width,
400
+ flow_dpms_guidance_scale,
401
+ flow_dpms_inference_steps,
402
+ randomize_seed,
403
+ ],
404
+ outputs=[result, seed, speed_box],
405
+ api_name="run",
406
+ )
407
+
408
+ if __name__ == "__main__":
409
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
apps/app_sana_4bit_compare_bf16.py ADDED
@@ -0,0 +1,313 @@
1
+ # Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
2
+ import argparse
3
+ import os
4
+ import random
5
+ import time
6
+ from datetime import datetime
7
+
8
+ import GPUtil
9
+
10
+ # import gradio last to avoid conflicts with other imports
11
+ import gradio as gr
12
+ import safety_check
13
+ import spaces
14
+ import torch
15
+ from diffusers import SanaPipeline
16
+ from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+
19
+ MAX_IMAGE_SIZE = 2048
20
+ MAX_SEED = 1000000000
21
+
22
+ DEFAULT_HEIGHT = 1024
23
+ DEFAULT_WIDTH = 1024
24
+
25
+ # num_inference_steps, guidance_scale, seed
26
+ EXAMPLES = [
27
+ [
28
+ "🐶 Wearing 🕶 flying on the 🌈",
29
+ 1024,
30
+ 1024,
31
+ 20,
32
+ 5,
33
+ 2,
34
+ ],
35
+ [
36
+ "大漠孤烟直, 长河落日圆",
37
+ 1024,
38
+ 1024,
39
+ 20,
40
+ 5,
41
+ 23,
42
+ ],
43
+ [
44
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, "
45
+ "volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, "
46
+ "art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
47
+ 1024,
48
+ 1024,
49
+ 20,
50
+ 5,
51
+ 233,
52
+ ],
53
+ [
54
+ "A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be "
55
+ "sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic "
56
+ "lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field "
57
+ "for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, "
58
+ "cinematic lighting, ultra-HD.",
59
+ 1024,
60
+ 1024,
61
+ 20,
62
+ 5,
63
+ 2333,
64
+ ],
65
+ [
66
+ "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. "
67
+ "She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. "
68
+ "She wears sunglasses and red lipstick. She walks confidently and casually. "
69
+ "The street is damp and reflective, creating a mirror effect of the colorful lights. "
70
+ "Many pedestrians walk about.",
71
+ 1024,
72
+ 1024,
73
+ 20,
74
+ 5,
75
+ 23333,
76
+ ],
77
+ [
78
+ "Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, "
79
+ "opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, "
80
+ "and a desk. The atmosphere is serene and natural. 8K resolution, highly detailed, photorealistic, "
81
+ "cinematic lighting, ultra-HD.",
82
+ 1024,
83
+ 1024,
84
+ 20,
85
+ 5,
86
+ 233333,
87
+ ],
88
+ ]
89
+
90
+
91
+ def hash_str_to_int(s: str) -> int:
92
+ """Hash a string to an integer."""
93
+ modulus = 10**9 + 7 # Large prime modulus
94
+ hash_int = 0
95
+ for char in s:
96
+ hash_int = (hash_int * 31 + ord(char)) % modulus
97
+ return hash_int
98
+
99
+
100
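+ # Build a Sana pipeline at the requested precision: "int4" swaps in the nunchaku SVDQuant transformer (CUDA only), "bf16" loads the stock diffusers checkpoint; pipeline_init_kwargs lets later pipelines reuse the first one's VAE and text encoder.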
+ def get_pipeline(
101
+ precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {}
102
+ ) -> SanaPipeline:
103
+ if precision == "int4":
104
+ assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
105
+ transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
106
+
107
+ pipeline_init_kwargs["transformer"] = transformer
108
+ if use_qencoder:
109
+ raise NotImplementedError("Quantized encoder not supported for Sana for now")
110
+ else:
111
+ assert precision == "bf16"
112
+ pipeline = SanaPipeline.from_pretrained(
113
+ "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
114
+ variant="bf16",
115
+ torch_dtype=torch.bfloat16,
116
+ **pipeline_init_kwargs,
117
+ )
118
+
119
+ pipeline = pipeline.to(device)
120
+ return pipeline
121
+
122
+
123
+ def get_args() -> argparse.Namespace:
124
+ parser = argparse.ArgumentParser()
125
+ parser.add_argument(
126
+ "-p",
127
+ "--precisions",
128
+ type=str,
129
+ default=["int4"],
130
+ nargs="*",
131
+ choices=["int4", "bf16"],
132
+ help="Which precisions to use",
133
+ )
134
+ parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
135
+ parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
136
+ parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
137
+ return parser.parse_args()
138
+
139
+
140
+ args = get_args()
141
+
142
+
143
+ pipelines = []
144
+ pipeline_init_kwargs = {}
145
+ for i, precision in enumerate(args.precisions):
146
+
147
+ pipeline = get_pipeline(
148
+ precision=precision,
149
+ use_qencoder=args.use_qencoder,
150
+ device="cuda",
151
+ pipeline_init_kwargs={**pipeline_init_kwargs},
152
+ )
153
+ pipelines.append(pipeline)
154
+ if i == 0:
155
+ pipeline_init_kwargs["vae"] = pipeline.vae
156
+ pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder
157
+
158
+ # safety checker
159
+ safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
160
+ safety_checker_model = AutoModelForCausalLM.from_pretrained(
161
+ args.shield_model_path,
162
+ device_map="auto",
163
+ torch_dtype=torch.bfloat16,
164
+ ).to(pipeline.device)
165
+
166
+
167
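+ # Run the same prompt through every loaded pipeline (e.g. INT4 vs. BF16) and report each image with its per-pipeline latency.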
+ @spaces.GPU(enable_queue=True)
168
+ def generate(
169
+ prompt: str = None,
170
+ height: int = 1024,
171
+ width: int = 1024,
172
+ num_inference_steps: int = 4,
173
+ guidance_scale: float = 0,
174
+ seed: int = 0,
175
+ ):
176
+ print(f"Prompt: {prompt}")
177
+ is_unsafe_prompt = False
178
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
179
+ prompt = "A peaceful world."
180
+ images, latency_strs = [], []
181
+ for i, pipeline in enumerate(pipelines):
182
+ progress = gr.Progress(track_tqdm=True)
183
+ start_time = time.time()
184
+ image = pipeline(
185
+ prompt=prompt,
186
+ height=height,
187
+ width=width,
188
+ guidance_scale=guidance_scale,
189
+ num_inference_steps=num_inference_steps,
190
+ generator=torch.Generator().manual_seed(seed),
191
+ ).images[0]
192
+ end_time = time.time()
193
+ latency = end_time - start_time
194
+ if latency < 1:
195
+ latency = latency * 1000
196
+ latency_str = f"{latency:.2f}ms"
197
+ else:
198
+ latency_str = f"{latency:.2f}s"
199
+ images.append(image)
200
+ latency_strs.append(latency_str)
201
+ if is_unsafe_prompt:
202
+ for i in range(len(latency_strs)):
203
+ latency_strs[i] += " (Unsafe prompt detected)"
204
+ torch.cuda.empty_cache()
205
+
206
+ if args.count_use:
207
+ if os.path.exists("use_count.txt"):
208
+ with open("use_count.txt") as f:
209
+ count = int(f.read())
210
+ else:
211
+ count = 0
212
+ count += 1
213
+ current_time = datetime.now()
214
+ print(f"{current_time}: {count}")
215
+ with open("use_count.txt", "w") as f:
216
+ f.write(str(count))
217
+ with open("use_record.txt", "a") as f:
218
+ f.write(f"{current_time}: {count}\n")
219
+
220
+ return *images, *latency_strs
221
+
222
+
223
+ with open("./assets/description.html") as f:
224
+ DESCRIPTION = f.read()
225
+ gpus = GPUtil.getGPUs()
226
+ if len(gpus) > 0:
227
+ gpu = gpus[0]
228
+ memory = gpu.memoryTotal / 1024
229
+ device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
230
+ else:
231
+ device_info = "Running on CPU 🥶 This demo does not work on CPU."
232
+ notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
233
+
234
+ with gr.Blocks(
235
+ css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
236
+ title=f"SVDQuant SANA-1600M Demo",
237
+ ) as demo:
238
+
239
+ def get_header_str():
240
+
241
+ if args.count_use:
242
+ if os.path.exists("use_count.txt"):
243
+ with open("use_count.txt") as f:
244
+ count = int(f.read())
245
+ else:
246
+ count = 0
247
+ count_info = (
248
+ f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
249
+ f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
250
+ f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
251
+ )
252
+ else:
253
+ count_info = ""
254
+ header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
255
+ return header_str
256
+
257
+ header = gr.HTML(get_header_str())
258
+ demo.load(fn=get_header_str, outputs=header)
259
+
260
+ with gr.Row():
261
+ image_results, latency_results = [], []
262
+ for i, precision in enumerate(args.precisions):
263
+ with gr.Column():
264
+ gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
265
+ with gr.Group():
266
+ image_result = gr.Image(
267
+ format="png",
268
+ image_mode="RGB",
269
+ label="Result",
270
+ show_label=False,
271
+ show_download_button=True,
272
+ interactive=False,
273
+ )
274
+ latency_result = gr.Text(label="Inference Latency", show_label=True)
275
+ image_results.append(image_result)
276
+ latency_results.append(latency_result)
277
+ with gr.Row():
278
+ prompt = gr.Text(
279
+ label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
280
+ )
281
+ run_button = gr.Button("Run", scale=1)
282
+
283
+ with gr.Row():
284
+ seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
285
+ randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
286
+ with gr.Accordion("Advanced options", open=False):
287
+ with gr.Group():
288
+ height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
289
+ width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
290
+ with gr.Group():
291
+ num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
292
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)
293
+
294
+ input_args = [prompt, height, width, num_inference_steps, guidance_scale, seed]
295
+
296
+ gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)
297
+
298
+ gr.on(
299
+ triggers=[prompt.submit, run_button.click],
300
+ fn=generate,
301
+ inputs=input_args,
302
+ outputs=[*image_results, *latency_results],
303
+ api_name="run",
304
+ )
305
+ randomize_seed.click(
306
+ lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
307
+ ).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)
308
+
309
+ gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
310
+
311
+
312
+ if __name__ == "__main__":
313
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
apps/app_sana_controlnet_hed.py ADDED
@@ -0,0 +1,306 @@
1
+ # Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
2
+ import argparse
3
+ import os
4
+ import random
5
+ import socket
6
+ import tempfile
7
+ import time
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
+ from app import safety_check
16
+ from app.sana_controlnet_pipeline import SanaControlNetPipeline
17
+
18
+ STYLES = {
19
+ "None": "{prompt}",
20
+ "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
21
+ "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
22
+ "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
23
+ "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
24
+ "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
25
+ "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
26
+ "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
27
+ "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
28
+ "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
29
+ }
30
+ DEFAULT_STYLE_NAME = "None"
31
+ STYLE_NAMES = list(STYLES.keys())
32
+
33
+ MAX_SEED = 1000000000
34
+ DEFAULT_SKETCH_GUIDANCE = 0.28
35
+ DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
36
+
37
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38
+
39
+ blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
40
+
41
+
42
+ def get_args():
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument("--config", type=str, help="config")
45
+ parser.add_argument(
46
+ "--model_path",
47
+ nargs="?",
48
+ default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
49
+ type=str,
50
+ help="Path to the model file (positional)",
51
+ )
52
+ parser.add_argument("--output", default="./", type=str)
53
+ parser.add_argument("--bs", default=1, type=int)
54
+ parser.add_argument("--image_size", default=1024, type=int)
55
+ parser.add_argument("--cfg_scale", default=5.0, type=float)
56
+ parser.add_argument("--pag_scale", default=2.0, type=float)
57
+ parser.add_argument("--seed", default=42, type=int)
58
+ parser.add_argument("--step", default=-1, type=int)
59
+ parser.add_argument("--custom_image_size", default=None, type=int)
60
+ parser.add_argument("--share", action="store_true")
61
+ parser.add_argument(
62
+ "--shield_model_path",
63
+ type=str,
64
+ help="The path to shield model, we employ ShieldGemma-2B by default.",
65
+ default="google/shieldgemma-2b",
66
+ )
67
+
68
+ return parser.parse_known_args()[0]
69
+
70
+
71
+ args = get_args()
72
+
73
+ if torch.cuda.is_available():
74
+ model_path = args.model_path
75
+ pipe = SanaControlNetPipeline(args.config)
76
+ pipe.from_pretrained(model_path)
77
+ pipe.register_progress_bar(gr.Progress())
78
+
79
+ # safety checker
80
+ safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
81
+ safety_checker_model = AutoModelForCausalLM.from_pretrained(
82
+ args.shield_model_path,
83
+ device_map="auto",
84
+ torch_dtype=torch.bfloat16,
85
+ ).to(device)
86
+
87
+
88
+ def save_image(img):
89
+ if isinstance(img, dict):
90
+ img = img["composite"]
91
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
92
+ img.save(temp_file.name)
93
+ return temp_file.name
94
+
95
+
96
+ def norm_ip(img, low, high):
97
+ img.clamp_(min=low, max=high)
98
+ img.sub_(low).div_(max(high - low, 1e-5))
99
+ return img
100
+
101
+
102
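+ # Sketch-to-image entry point: returns a blank image when both the prompt and canvas are empty, replaces unsafe prompts, applies the style template, and conditions the ControlNet pipeline on the drawn or uploaded reference.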
+ @torch.no_grad()
103
+ @torch.inference_mode()
104
+ def run(
105
+ image,
106
+ prompt: str,
107
+ prompt_template: str,
108
+ sketch_thickness: int,
109
+ guidance_scale: float,
110
+ inference_steps: int,
111
+ seed: int,
112
+ blend_alpha: float,
113
+ ) -> tuple[Image, str]:
114
+
115
+ print(f"Prompt: {prompt}")
116
+ image_numpy = np.array(image["composite"].convert("RGB"))
117
+
118
+ if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
119
+ return blank_image, "Please input the prompt or draw something."
120
+
121
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
122
+ prompt = "A red heart."
123
+
124
+ prompt = prompt_template.format(prompt=prompt)
125
+ pipe.set_blend_alpha(blend_alpha)
126
+ start_time = time.time()
127
+ images = pipe(
128
+ prompt=prompt,
129
+ ref_image=image["composite"],
130
+ guidance_scale=guidance_scale,
131
+ num_inference_steps=inference_steps,
132
+ num_images_per_prompt=1,
133
+ sketch_thickness=sketch_thickness,
134
+ generator=torch.Generator(device=device).manual_seed(seed),
135
+ )
136
+
137
+ latency = time.time() - start_time
138
+
139
+ if latency < 1:
140
+ latency = latency * 1000
141
+ latency_str = f"{latency:.2f}ms"
142
+ else:
143
+ latency_str = f"{latency:.2f}s"
144
+ torch.cuda.empty_cache()
145
+
146
+ img = [
147
+ Image.fromarray(
148
+ norm_ip(img, -1, 1)
149
+ .mul(255)
150
+ .add_(0.5)
151
+ .clamp_(0, 255)
152
+ .permute(1, 2, 0)
153
+ .to("cpu", torch.uint8)
154
+ .numpy()
155
+ .astype(np.uint8)
156
+ )
157
+ for img in images
158
+ ]
159
+ img = img[0]
160
+ return img, latency_str
161
+
162
+
163
+ model_size = "1.6" if "1600M" in args.model_path else "0.6"
164
+ title = f"""
165
+ <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
166
+ <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
167
+ </div>
168
+ """
169
+ DESCRIPTION = f"""
170
+ <p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
171
+ <p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
172
+ <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
173
+ <p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, running on node {socket.gethostname()}.</p>
174
+ <p style="font-size: 16px; font-weight: bold;">An unsafe prompt will give you a 'Red Heart' image instead.</p>
175
+ """
176
+ if model_size == "0.6":
177
+ DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
178
+ if not torch.cuda.is_available():
179
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
180
+
181
+
182
+ with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo:
183
+ gr.Markdown(title)
184
+ gr.HTML(DESCRIPTION)
185
+
186
+ with gr.Row(elem_id="main_row"):
187
+ with gr.Column(elem_id="column_input"):
188
+ gr.Markdown("## INPUT", elem_id="input_header")
189
+ with gr.Group():
190
+ canvas = gr.Sketchpad(
191
+ value=blank_image,
192
+ height=640,
193
+ image_mode="RGB",
194
+ sources=["upload", "clipboard"],
195
+ type="pil",
196
+ label="Sketch",
197
+ show_label=False,
198
+ show_download_button=True,
199
+ interactive=True,
200
+ transforms=[],
201
+ canvas_size=(1024, 1024),
202
+ scale=1,
203
+ brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
204
+ format="png",
205
+ layers=False,
206
+ )
207
+ with gr.Row():
208
+ prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
209
+ run_button = gr.Button("Run", scale=1, elem_id="run_button")
210
+ download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
211
+ with gr.Row():
212
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
213
+ prompt_template = gr.Textbox(
214
+ label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
215
+ )
216
+
217
+ with gr.Row():
218
+ sketch_thickness = gr.Slider(
219
+ label="Sketch Thickness",
220
+ minimum=1,
221
+ maximum=4,
222
+ step=1,
223
+ value=2,
224
+ )
225
+ with gr.Row():
226
+ inference_steps = gr.Slider(
227
+ label="Sampling steps",
228
+ minimum=5,
229
+ maximum=40,
230
+ step=1,
231
+ value=20,
232
+ )
233
+ guidance_scale = gr.Slider(
234
+ label="CFG Guidance scale",
235
+ minimum=1,
236
+ maximum=10,
237
+ step=0.1,
238
+ value=4.5,
239
+ )
240
+ blend_alpha = gr.Slider(
241
+ label="Blend Alpha",
242
+ minimum=0,
243
+ maximum=1,
244
+ step=0.1,
245
+ value=0,
246
+ )
247
+ with gr.Row():
248
+ seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
249
+ randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
250
+
251
+ with gr.Column(elem_id="column_output"):
252
+ gr.Markdown("## OUTPUT", elem_id="output_header")
253
+ with gr.Group():
254
+ result = gr.Image(
255
+ format="png",
256
+ height=640,
257
+ image_mode="RGB",
258
+ type="pil",
259
+ label="Result",
260
+ show_label=False,
261
+ show_download_button=True,
262
+ interactive=False,
263
+ elem_id="output_image",
264
+ )
265
+ latency_result = gr.Text(label="Inference Latency", show_label=True)
266
+
267
+ download_result = gr.DownloadButton("Download Result", elem_id="download_result")
268
+ gr.Markdown("### Instructions")
269
+ gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
270
+ gr.Markdown("**2**. Start sketching or upload a reference image")
271
+ gr.Markdown("**3**. Change the image style using a style template")
272
+ gr.Markdown("**4**. Try different seeds to generate different results")
273
+
274
+ run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
275
+ run_outputs = [result, latency_result]
276
+
277
+ randomize_seed.click(
278
+ lambda: random.randint(0, MAX_SEED),
279
+ inputs=[],
280
+ outputs=seed,
281
+ api_name=False,
282
+ queue=False,
283
+ ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
284
+
285
+ style.change(
286
+ lambda x: STYLES[x],
287
+ inputs=[style],
288
+ outputs=[prompt_template],
289
+ api_name=False,
290
+ queue=False,
291
+ ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
292
+ gr.on(
293
+ triggers=[prompt.submit, run_button.click, canvas.change],
294
+ fn=run,
295
+ inputs=run_inputs,
296
+ outputs=run_outputs,
297
+ api_name=False,
298
+ )
299
+
300
+ download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
301
+ download_result.click(fn=save_image, inputs=result, outputs=download_result)
302
+ gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
303
+
304
+
305
+ if __name__ == "__main__":
306
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
apps/app_sana_multithread.py ADDED
@@ -0,0 +1,565 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # SPDX-License-Identifier: Apache-2.0
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import os
21
+ import random
22
+ import uuid
23
+ from datetime import datetime
24
+
25
+ import gradio as gr
26
+ import numpy as np
27
+ import spaces
28
+ import torch
29
+ from diffusers import FluxPipeline
30
+ from PIL import Image
31
+ from torchvision.utils import make_grid, save_image
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+
34
+ from app import safety_check
35
+ from app.sana_pipeline import SanaPipeline
36
+
37
+ MAX_SEED = np.iinfo(np.int32).max
38
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
39
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
40
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
41
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
42
+ DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
43
+ os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
44
+
45
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46
+
47
+ style_list = [
48
+ {
49
+ "name": "(No style)",
50
+ "prompt": "{prompt}",
51
+ "negative_prompt": "",
52
+ },
53
+ {
54
+ "name": "Cinematic",
55
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
56
+ "cinemascope, moody, epic, gorgeous, film grain, grainy",
57
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
58
+ },
59
+ {
60
+ "name": "Photographic",
61
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
62
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
63
+ },
64
+ {
65
+ "name": "Anime",
66
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
67
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
68
+ },
69
+ {
70
+ "name": "Manga",
71
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
72
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
73
+ },
74
+ {
75
+ "name": "Digital Art",
76
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
77
+ "negative_prompt": "photo, photorealistic, realism, ugly",
78
+ },
79
+ {
80
+ "name": "Pixel art",
81
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
82
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
83
+ },
84
+ {
85
+ "name": "Fantasy art",
86
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
87
+ "majestic, magical, fantasy art, cover art, dreamy",
88
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
89
+ "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
90
+ "disfigured, sloppy, duplicate, mutated, black and white",
91
+ },
92
+ {
93
+ "name": "Neonpunk",
94
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
95
+ "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
96
+ "ultra detailed, intricate, professional",
97
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
98
+ },
99
+ {
100
+ "name": "3D Model",
101
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
102
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
103
+ },
104
+ ]
105
+
106
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
107
+ STYLE_NAMES = list(styles.keys())
108
+ DEFAULT_STYLE_NAME = "(No style)"
109
+ SCHEDULE_NAME = ["Flow_DPM_Solver"]
110
+ DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
111
+ NUM_IMAGES_PER_PROMPT = 1
112
+ TEST_TIMES = 0
113
+ FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
114
+
115
+
116
+ def set_env(seed=0):
117
+ torch.manual_seed(seed)
118
+ torch.set_grad_enabled(False)
119
+
120
+
121
+ def read_inference_count():
122
+ global TEST_TIMES
123
+ try:
124
+ with open(FILENAME) as f:
125
+ count = int(f.read().strip())
126
+ except FileNotFoundError:
127
+ count = 0
128
+ TEST_TIMES = count
129
+
130
+ return count
131
+
132
+
133
+ def write_inference_count(count):
134
+ with open(FILENAME, "w") as f:
135
+ f.write(str(count))
136
+
137
+
138
+ def run_inference(num_imgs=1):
139
+ global TEST_TIMES
+ TEST_TIMES = read_inference_count()
140
+ TEST_TIMES += int(num_imgs)
141
+ write_inference_count(TEST_TIMES)
142
+
143
+ return (
144
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
145
+ f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
146
+ )
147
+
148
+
149
+ def update_inference_count():
150
+ count = read_inference_count()
151
+ return (
152
+ f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
153
+ f"16px; color:red; font-weight: bold;'>{count}</span>"
154
+ )
155
+
156
+
157
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
158
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
159
+ if not negative:
160
+ negative = ""
161
+ return p.replace("{prompt}", positive), n + negative
162
+
163
+
164
+ def get_args():
165
+ parser = argparse.ArgumentParser()
166
+ parser.add_argument("--config", type=str, help="config")
167
+ parser.add_argument(
168
+ "--model_path",
169
+ nargs="?",
170
+ default="output/Sana_D20/SANA.pth",
171
+ type=str,
172
+ help="Path to the model file (positional)",
173
+ )
174
+ parser.add_argument("--output", default="./", type=str)
175
+ parser.add_argument("--bs", default=1, type=int)
176
+ parser.add_argument("--image_size", default=1024, type=int)
177
+ parser.add_argument("--cfg_scale", default=5.0, type=float)
178
+ parser.add_argument("--pag_scale", default=2.0, type=float)
179
+ parser.add_argument("--seed", default=42, type=int)
180
+ parser.add_argument("--step", default=-1, type=int)
181
+ parser.add_argument("--custom_image_size", default=None, type=int)
182
+ parser.add_argument(
183
+ "--shield_model_path",
184
+ type=str,
185
+ help="The path to shield model, we employ ShieldGemma-2B by default.",
186
+ default="google/shieldgemma-2b",
187
+ )
188
+
189
+ return parser.parse_args()
190
+
191
+
192
+ args = get_args()
193
+
194
+ if torch.cuda.is_available():
195
+ weight_dtype = torch.float16
196
+ model_path = args.model_path
197
+ pipe = SanaPipeline(args.config)
198
+ pipe.from_pretrained(model_path)
199
+ pipe.register_progress_bar(gr.Progress())
200
+
201
+ repo_name = "black-forest-labs/FLUX.1-dev"
202
+ pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
203
+
204
+ # safety checker
205
+ safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
206
+ safety_checker_model = AutoModelForCausalLM.from_pretrained(
207
+ args.shield_model_path,
208
+ device_map="auto",
209
+ torch_dtype=torch.bfloat16,
210
+ ).to(device)
211
+
212
+ set_env(42)
213
+
214
+
215
+ def save_image_sana(img, seed="", save_img=False):
216
+ unique_name = f"{str(uuid.uuid4())}_{seed}.png"
217
+ save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
218
+ os.umask(0o000) # file permission: 666; dir permission: 777
219
+ os.makedirs(save_path, exist_ok=True)
220
+ unique_name = os.path.join(save_path, unique_name)
221
+ if save_img:
222
+ save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
223
+
224
+ return unique_name
225
+
226
+
227
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
228
+ if randomize_seed:
229
+ seed = random.randint(0, MAX_SEED)
230
+ return seed
231
+
232
+
233
+ @spaces.GPU(enable_queue=True)
234
+ async def generate_2(
235
+ prompt: str = None,
236
+ negative_prompt: str = "",
237
+ style: str = DEFAULT_STYLE_NAME,
238
+ use_negative_prompt: bool = False,
239
+ num_imgs: int = 1,
240
+ seed: int = 0,
241
+ height: int = 1024,
242
+ width: int = 1024,
243
+ flow_dpms_guidance_scale: float = 5.0,
244
+ flow_dpms_pag_guidance_scale: float = 2.0,
245
+ flow_dpms_inference_steps: int = 20,
246
+ randomize_seed: bool = False,
247
+ ):
248
+ seed = int(randomize_seed_fn(seed, randomize_seed))
249
+ generator = torch.Generator(device=device).manual_seed(seed)
250
+ print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
251
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
252
+ prompt = "A red heart."
253
+
254
+ print(prompt)
255
+
256
+ if not use_negative_prompt:
257
+ negative_prompt = None # type: ignore
258
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
259
+
260
+ with torch.no_grad():
261
+ images = pipe2(
262
+ prompt=prompt,
263
+ height=height,
264
+ width=width,
265
+ guidance_scale=3.5,
266
+ num_inference_steps=50,
267
+ num_images_per_prompt=num_imgs,
268
+ max_sequence_length=256,
269
+ generator=generator,
270
+ ).images
271
+
272
+ save_img = False
273
+ img = images
274
+ if save_img:
275
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
276
+ print(img)
277
+ torch.cuda.empty_cache()
278
+
279
+ return img
280
+
281
+
282
+ @spaces.GPU(enable_queue=True)
283
+ async def generate(
284
+ prompt: str = None,
285
+ negative_prompt: str = "",
286
+ style: str = DEFAULT_STYLE_NAME,
287
+ use_negative_prompt: bool = False,
288
+ num_imgs: int = 1,
289
+ seed: int = 0,
290
+ height: int = 1024,
291
+ width: int = 1024,
292
+ flow_dpms_guidance_scale: float = 5.0,
293
+ flow_dpms_pag_guidance_scale: float = 2.0,
294
+ flow_dpms_inference_steps: int = 20,
295
+ randomize_seed: bool = False,
296
+ ):
297
+ global TEST_TIMES
298
+ # seed = 823753551
299
+ seed = int(randomize_seed_fn(seed, randomize_seed))
300
+ generator = torch.Generator(device=device).manual_seed(seed)
301
+ print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
302
+ if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
303
+ prompt = "A red heart."
304
+
305
+ print(prompt)
306
+
307
+ num_inference_steps = flow_dpms_inference_steps
308
+ guidance_scale = flow_dpms_guidance_scale
309
+ pag_guidance_scale = flow_dpms_pag_guidance_scale
310
+
311
+ if not use_negative_prompt:
312
+ negative_prompt = None # type: ignore
313
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
314
+
315
+ pipe.progress_fn(0, desc="Sana Start")
316
+
317
+ with torch.no_grad():
318
+ images = pipe(
319
+ prompt=prompt,
320
+ height=height,
321
+ width=width,
322
+ negative_prompt=negative_prompt,
323
+ guidance_scale=guidance_scale,
324
+ pag_guidance_scale=pag_guidance_scale,
325
+ num_inference_steps=num_inference_steps,
326
+ num_images_per_prompt=num_imgs,
327
+ generator=generator,
328
+ )
329
+
330
+ pipe.progress_fn(1.0, desc="Sana End")
331
+
332
+ save_img = False
333
+ if save_img:
334
+ img = [save_image_sana(img, seed, save_img=save_img) for img in images]
335
+ print(img)
336
+ else:
337
+ if num_imgs > 1:
338
+ nrow = 2
339
+ else:
340
+ nrow = 1
341
+ img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
342
+ img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
343
+ img = [Image.fromarray(img.astype(np.uint8))]
344
+
345
+ torch.cuda.empty_cache()
346
+
347
+ return img
348
+
349
+
350
+ TEST_TIMES = read_inference_count()
351
+ model_size = "1.6" if "D20" in args.model_path else "0.6"
352
+ title = f"""
353
+ <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
354
+ <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
355
+ </div>
356
+ """
357
+ DESCRIPTION = f"""
358
+ <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
359
+ <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
360
+ <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
361
+ <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
362
+ <p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
363
+ """
364
+ if model_size == "0.6":
365
+ DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
366
+ if not torch.cuda.is_available():
367
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
368
+
369
+ examples = [
370
+ 'a cyberpunk cat with a neon sign that says "Sana"',
371
+ "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
372
+ "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
373
+ "portrait photo of a girl, photograph, highly detailed face, depth of field",
374
+ 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
375
+ "🐶 Wearing 🕶 flying on the 🌈",
376
+ # "👧 with 🌹 in the ❄️",
377
+ # "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
378
+ # "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
379
+ # "Astronaut in a jungle, cold color palette, muted colors, detailed",
380
+ # "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
381
+ ]
382
+
383
+ css = """
384
+ .gradio-container{max-width: 1024px !important}
385
+ h1{text-align:center}
386
+ """
387
+ with gr.Blocks(css=css) as demo:
388
+ gr.Markdown(title)
389
+ gr.Markdown(DESCRIPTION)
390
+ gr.DuplicateButton(
391
+ value="Duplicate Space for private use",
392
+ elem_id="duplicate-button",
393
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
394
+ )
395
+ info_box = gr.Markdown(
396
+ value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
397
+ )
398
+ demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
399
+ # with gr.Row(equal_height=False):
400
+ with gr.Group():
401
+ with gr.Row():
402
+ prompt = gr.Text(
403
+ label="Prompt",
404
+ show_label=False,
405
+ max_lines=1,
406
+ placeholder="Enter your prompt",
407
+ container=False,
408
+ )
409
+ run_button = gr.Button("Run-sana", scale=0)
410
+ run_button2 = gr.Button("Run-flux", scale=0)
411
+
412
+ with gr.Row():
413
+ result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
414
+ result_2 = gr.Gallery(
415
+ label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
416
+ )
417
+
418
+ with gr.Accordion("Advanced options", open=False):
419
+ with gr.Group():
420
+ with gr.Row(visible=True):
421
+ height = gr.Slider(
422
+ label="Height",
423
+ minimum=256,
424
+ maximum=MAX_IMAGE_SIZE,
425
+ step=32,
426
+ value=1024,
427
+ )
428
+ width = gr.Slider(
429
+ label="Width",
430
+ minimum=256,
431
+ maximum=MAX_IMAGE_SIZE,
432
+ step=32,
433
+ value=1024,
434
+ )
435
+ with gr.Row():
436
+ flow_dpms_inference_steps = gr.Slider(
437
+ label="Sampling steps",
438
+ minimum=5,
439
+ maximum=40,
440
+ step=1,
441
+ value=18,
442
+ )
443
+ flow_dpms_guidance_scale = gr.Slider(
444
+ label="CFG Guidance scale",
445
+ minimum=1,
446
+ maximum=10,
447
+ step=0.1,
448
+ value=5.0,
449
+ )
450
+ flow_dpms_pag_guidance_scale = gr.Slider(
451
+ label="PAG Guidance scale",
452
+ minimum=1,
453
+ maximum=4,
454
+ step=0.5,
455
+ value=2.0,
456
+ )
457
+ with gr.Row():
458
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
459
+ negative_prompt = gr.Text(
460
+ label="Negative prompt",
461
+ max_lines=1,
462
+ placeholder="Enter a negative prompt",
463
+ visible=True,
464
+ )
465
+ style_selection = gr.Radio(
466
+ show_label=True,
467
+ container=True,
468
+ interactive=True,
469
+ choices=STYLE_NAMES,
470
+ value=DEFAULT_STYLE_NAME,
471
+ label="Image Style",
472
+ )
473
+ seed = gr.Slider(
474
+ label="Seed",
475
+ minimum=0,
476
+ maximum=MAX_SEED,
477
+ step=1,
478
+ value=0,
479
+ )
480
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
481
+ with gr.Row(visible=True):
482
+ schedule = gr.Radio(
483
+ show_label=True,
484
+ container=True,
485
+ interactive=True,
486
+ choices=SCHEDULE_NAME,
487
+ value=DEFAULT_SCHEDULE_NAME,
488
+ label="Sampler Schedule",
489
+ visible=True,
490
+ )
491
+ num_imgs = gr.Slider(
492
+ label="Num Images",
493
+ minimum=1,
494
+ maximum=6,
495
+ step=1,
496
+ value=1,
497
+ )
498
+
499
+ run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
500
+
501
+ gr.Examples(
502
+ examples=examples,
503
+ inputs=prompt,
504
+ outputs=[result],
505
+ fn=generate,
506
+ cache_examples=CACHE_EXAMPLES,
507
+ )
508
+ gr.Examples(
509
+ examples=examples,
510
+ inputs=prompt,
511
+ outputs=[result_2],
512
+ fn=generate_2,
513
+ cache_examples=CACHE_EXAMPLES,
514
+ )
515
+
516
+ use_negative_prompt.change(
517
+ fn=lambda x: gr.update(visible=x),
518
+ inputs=use_negative_prompt,
519
+ outputs=negative_prompt,
520
+ api_name=False,
521
+ )
522
+
523
+ run_button.click(
524
+ fn=generate,
525
+ inputs=[
526
+ prompt,
527
+ negative_prompt,
528
+ style_selection,
529
+ use_negative_prompt,
530
+ num_imgs,
531
+ seed,
532
+ height,
533
+ width,
534
+ flow_dpms_guidance_scale,
535
+ flow_dpms_pag_guidance_scale,
536
+ flow_dpms_inference_steps,
537
+ randomize_seed,
538
+ ],
539
+ outputs=[result],
540
+ queue=True,
541
+ )
542
+
543
+ run_button2.click(
544
+ fn=generate_2,
545
+ inputs=[
546
+ prompt,
547
+ negative_prompt,
548
+ style_selection,
549
+ use_negative_prompt,
550
+ num_imgs,
551
+ seed,
552
+ height,
553
+ width,
554
+ flow_dpms_guidance_scale,
555
+ flow_dpms_pag_guidance_scale,
556
+ flow_dpms_inference_steps,
557
+ randomize_seed,
558
+ ],
559
+ outputs=[result_2],
560
+ queue=True,
561
+ )
562
+
563
+
564
+ if __name__ == "__main__":
565
+ demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
apps/safety_check.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is copied from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
17
+ import json
18
+ from functools import lru_cache
19
+ from typing import Dict, Sequence
20
+
21
+ import torch
22
+ import transformers
23
+ from torch.nn.functional import softmax
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer
25
+
26
+ __all__ = ["is_dangerous"]
27
+
28
+ __rank, __local_rank, __world_size, __device = (
29
+ 0,
30
+ 0,
31
+ 1,
32
+ "cuda" if torch.cuda.is_available() else "cpu",
33
+ )
34
+
35
+
36
+ safety_policy = """
37
+ The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
38
+ """
39
+
40
+
41
+ def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
42
+ prompt = f"""You are a policy expert trying to help determine whether a user
43
+ prompt is in violation of the defined safety policies.
44
+
45
+ <start_of_turn>
46
+ Human Question: {user_prompt.strip()}
47
+ <end_of_turn>
48
+
49
+ Our safety principle is defined in the below:
50
+
51
+ {safety_policy.strip()}
52
+
53
+ Does the human question violate the above principle? Your answer must start
54
+ with 'Yes' or 'No'. And then walk through step by step to be sure we answer
55
+ correctly.
56
+ """
57
+
58
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
59
+ with torch.no_grad():
60
+ logits = model(**inputs).logits
61
+
62
+ # Extract the logits for the Yes and No tokens
63
+ vocab = tokenizer.get_vocab()
64
+ selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
65
+
66
+ # Convert these logits to a probability with softmax
67
+ probabilities = softmax(selected_logits, dim=0)
68
+
69
+ # Return probability of 'Yes'
70
+ score = probabilities[0].item()
71
+
72
+ return score > threshold
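
A minimal usage sketch for the safety gate above, mirroring how the demo apps load ShieldGemma and substitute a safe prompt; the loading options are assumptions, and `is_dangerous` itself expects a CUDA device.

```python
# Illustrative sketch only (not part of the upload): wires safety_check.is_dangerous
# the same way the demo apps do; checkpoint path matches the --shield_model_path default.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check  # import path used by the demo apps

shield_path = "google/shieldgemma-2b"
tokenizer = AutoTokenizer.from_pretrained(shield_path)
model = AutoModelForCausalLM.from_pretrained(
    shield_path, device_map="auto", torch_dtype=torch.bfloat16
)

prompt = "a cyberpunk cat with a neon sign"
if safety_check.is_dangerous(tokenizer, model, prompt, threshold=0.5):
    prompt = "A red heart."  # the demos swap in this safe prompt
```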
apps/sana_controlnet_pipeline.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ import warnings
17
+ from dataclasses import dataclass, field
18
+ from typing import Optional, Tuple
19
+
20
+ import cv2  # needed by forward() when ref_image is given as a file path
+ import numpy as np
21
+ import pyrallis
22
+ import torch
23
+ import torch.nn as nn
24
+ from PIL import Image
25
+
26
+ warnings.filterwarnings("ignore") # ignore warning
27
+
28
+
29
+ from diffusion import DPMS, FlowEuler
30
+ from diffusion.data.datasets.utils import (
31
+ ASPECT_RATIO_512_TEST,
32
+ ASPECT_RATIO_1024_TEST,
33
+ ASPECT_RATIO_2048_TEST,
34
+ ASPECT_RATIO_4096_TEST,
35
+ )
36
+ from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode
37
+ from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
38
+ from diffusion.utils.config import SanaConfig, model_init_config
39
+ from diffusion.utils.logger import get_root_logger
40
+ from tools.controlnet.utils import get_scribble_map, transform_control_signal
41
+ from tools.download import find_model
42
+
43
+
44
+ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
45
+ guidance_type = default_guidance_type
46
+ if not (pag_scale > 1.0 and attn_type == "linear"):
47
+ guidance_type = "classifier-free"
48
+ elif pag_scale > 1.0 and attn_type == "linear":
49
+ guidance_type = "classifier-free_PAG"
50
+ return guidance_type
51
+
52
+
53
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
54
+ """Returns binned height and width."""
55
+ ar = float(height / width)
56
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
57
+ default_hw = ratios[closest_ratio]
58
+ return int(default_hw[0]), int(default_hw[1])
59
+
60
+
61
+ def get_ar_from_ref_image(ref_image):
62
+ def reduce_ratio(h, w):
63
+ def gcd(a, b):
64
+ while b:
65
+ a, b = b, a % b
66
+ return a
67
+
68
+ divisor = gcd(h, w)
69
+ return f"{h // divisor}:{w // divisor}"
70
+
71
+ if isinstance(ref_image, str):
72
+ ref_image = Image.open(ref_image)
73
+ w, h = ref_image.size
74
+ return reduce_ratio(h, w)
75
+
76
+
77
+ @dataclass
78
+ class SanaControlNetInference(SanaConfig):
79
+ config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
80
+ model_path: str = field(
81
+ default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
82
+ )
83
+ output: str = "./output"
84
+ bs: int = 1
85
+ image_size: int = 1024
86
+ cfg_scale: float = 5.0
87
+ pag_scale: float = 2.0
88
+ seed: int = 42
89
+ step: int = -1
90
+ custom_image_size: Optional[int] = None
91
+ shield_model_path: str = field(
92
+ default="google/shieldgemma-2b",
93
+ metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
94
+ )
95
+
96
+
97
+ class SanaControlNetPipeline(nn.Module):
98
+ def __init__(
99
+ self,
100
+ config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
101
+ ):
102
+ super().__init__()
103
+ config = pyrallis.load(SanaControlNetInference, open(config))
104
+ self.args = self.config = config
105
+
106
+ # set some hyper-parameters
107
+ self.image_size = self.config.model.image_size
108
+
109
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
110
+ logger = get_root_logger()
111
+ self.logger = logger
112
+ self.progress_fn = lambda progress, desc: None
113
+ self.thickness = 2
114
+ self.blend_alpha = 0.0
115
+
116
+ self.latent_size = self.image_size // config.vae.vae_downsample_rate
117
+ self.max_sequence_length = config.text_encoder.model_max_length
118
+ self.flow_shift = config.scheduler.flow_shift
119
+ guidance_type = "classifier-free_PAG"
120
+
121
+ weight_dtype = get_weight_dtype(config.model.mixed_precision)
122
+ self.weight_dtype = weight_dtype
123
+ self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
124
+
125
+ self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
126
+ self.vis_sampler = self.config.scheduler.vis_sampler
127
+ logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
128
+ self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
129
+ logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
130
+
131
+ # 1. build vae and text encoder
132
+ self.vae = self.build_vae(config.vae)
133
+ self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
134
+
135
+ # 2. build Sana model
136
+ self.model = self.build_sana_model(config).to(self.device)
137
+
138
+ # 3. pre-compute null embedding
139
+ with torch.no_grad():
140
+ null_caption_token = self.tokenizer(
141
+ "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
142
+ ).to(self.device)
143
+ self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
144
+ 0
145
+ ]
146
+
147
+ def build_vae(self, config):
148
+ vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
149
+ return vae
150
+
151
+ def build_text_encoder(self, config):
152
+ tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
153
+ return tokenizer, text_encoder
154
+
155
+ def build_sana_model(self, config):
156
+ # model setting
157
+ model_kwargs = model_init_config(config, latent_size=self.latent_size)
158
+ model = build_model(
159
+ config.model.model,
160
+ use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
161
+ **model_kwargs,
162
+ )
163
+ self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
164
+ self.logger.info(
165
+ f"{model.__class__.__name__}:{config.model.model},"
166
+ f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
167
+ )
168
+ return model
169
+
170
+ def from_pretrained(self, model_path):
171
+ state_dict = find_model(model_path)
172
+ state_dict = state_dict.get("state_dict", state_dict)
173
+ if "pos_embed" in state_dict:
174
+ del state_dict["pos_embed"]
175
+ missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
176
+ self.model.eval().to(self.weight_dtype)
177
+
178
+ self.logger.info("Generating sample from ckpt: %s" % model_path)
179
+ self.logger.warning(f"Missing keys: {missing}")
180
+ self.logger.warning(f"Unexpected keys: {unexpected}")
181
+
182
+ def register_progress_bar(self, progress_fn=None):
183
+ self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
184
+
185
+ def set_blend_alpha(self, blend_alpha):
186
+ self.blend_alpha = blend_alpha
187
+
188
+ @torch.inference_mode()
189
+ def forward(
190
+ self,
191
+ prompt=None,
192
+ ref_image=None,
193
+ negative_prompt="",
194
+ num_inference_steps=20,
195
+ guidance_scale=5,
196
+ pag_guidance_scale=2.5,
197
+ num_images_per_prompt=1,
198
+ sketch_thickness=2,
199
+ generator=torch.Generator().manual_seed(42),
200
+ latents=None,
201
+ ):
202
+ self.ori_height, self.ori_width = ref_image.height, ref_image.width
203
+ self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
204
+
205
+ # 1. pre-compute negative embedding
206
+ if negative_prompt != "":
207
+ null_caption_token = self.tokenizer(
208
+ negative_prompt,
209
+ max_length=self.max_sequence_length,
210
+ padding="max_length",
211
+ truncation=True,
212
+ return_tensors="pt",
213
+ ).to(self.device)
214
+ self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
215
+ 0
216
+ ]
217
+
218
+ if prompt is None:
219
+ prompt = [""]
220
+ prompts = prompt if isinstance(prompt, list) else [prompt]
221
+ samples = []
222
+
223
+ for prompt in prompts:
224
+ # data prepare
225
+ prompts, hw, ar = (
226
+ [],
227
+ torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
228
+ num_images_per_prompt, 1
229
+ ),
230
+ torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
231
+ )
232
+
233
+ ar = get_ar_from_ref_image(ref_image)
234
+ prompt += f" --ar {ar}"
235
+ for _ in range(num_images_per_prompt):
236
+ prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(
237
+ prompt, self.base_ratios, device=self.device, show=False
238
+ )
239
+ prompts.append(prompt_clean.strip())
240
+
241
+ self.latent_size_h, self.latent_size_w = (
242
+ int(hw[0, 0] // self.config.vae.vae_downsample_rate),
243
+ int(hw[0, 1] // self.config.vae.vae_downsample_rate),
244
+ )
245
+
246
+ with torch.no_grad():
247
+ # prepare text feature
248
+ if not self.config.text_encoder.chi_prompt:
249
+ max_length_all = self.config.text_encoder.model_max_length
250
+ prompts_all = prompts
251
+ else:
252
+ chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
253
+ prompts_all = [chi_prompt + prompt for prompt in prompts]
254
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
255
+ max_length_all = (
256
+ num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
257
+ ) # magic number 2: [bos], [_]
258
+
259
+ caption_token = self.tokenizer(
260
+ prompts_all,
261
+ max_length=max_length_all,
262
+ padding="max_length",
263
+ truncation=True,
264
+ return_tensors="pt",
265
+ ).to(device=self.device)
266
+ select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
267
+ caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
268
+ :, :, select_index
269
+ ].to(self.weight_dtype)
270
+ emb_masks = caption_token.attention_mask[:, select_index]
271
+ null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
272
+
273
+ n = len(prompts)
274
+ if latents is None:
275
+ z = torch.randn(
276
+ n,
277
+ self.config.vae.vae_latent_dim,
278
+ self.latent_size_h,
279
+ self.latent_size_w,
280
+ generator=generator,
281
+ device=self.device,
282
+ )
283
+ else:
284
+ z = latents.to(self.device)
285
+ model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
286
+
287
+ # control signal
288
+ if isinstance(ref_image, str):
289
+ ref_image = cv2.imread(ref_image)
290
+ elif isinstance(ref_image, Image.Image):
291
+ ref_image = np.array(ref_image)
292
+ control_signal = get_scribble_map(
293
+ input_image=ref_image,
294
+ det="Scribble_HED",
295
+ detect_resolution=int(hw.min()),
296
+ thickness=sketch_thickness,
297
+ )
298
+
299
+ control_signal = transform_control_signal(control_signal, hw).to(self.device).to(self.weight_dtype)
300
+
301
+ control_signal_latent = vae_encode(
302
+ self.config.vae.vae_type, self.vae, control_signal, self.config.vae.sample_posterior, self.device
303
+ )
304
+
305
+ model_kwargs["control_signal"] = control_signal_latent
306
+
307
+ if self.vis_sampler == "flow_euler":
308
+ flow_solver = FlowEuler(
309
+ self.model,
310
+ condition=caption_embs,
311
+ uncondition=null_y,
312
+ cfg_scale=guidance_scale,
313
+ model_kwargs=model_kwargs,
314
+ )
315
+ sample = flow_solver.sample(
316
+ z,
317
+ steps=num_inference_steps,
318
+ )
319
+ elif self.vis_sampler == "flow_dpm-solver":
320
+ scheduler = DPMS(
321
+ self.model.forward_with_dpmsolver,
322
+ condition=caption_embs,
323
+ uncondition=null_y,
324
+ guidance_type=self.guidance_type,
325
+ cfg_scale=guidance_scale,
326
+ model_type="flow",
327
+ model_kwargs=model_kwargs,
328
+ schedule="FLOW",
329
+ )
330
+ scheduler.register_progress_bar(self.progress_fn)
331
+ sample = scheduler.sample(
332
+ z,
333
+ steps=num_inference_steps,
334
+ order=2,
335
+ skip_type="time_uniform_flow",
336
+ method="multistep",
337
+ flow_shift=self.flow_shift,
338
+ )
339
+
340
+ sample = sample.to(self.vae_dtype)
341
+ with torch.no_grad():
342
+ sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
343
+
344
+ if self.blend_alpha > 0:
345
+ print(f"blend image and mask with alpha: {self.blend_alpha}")
346
+ sample = sample * (1 - self.blend_alpha) + control_signal * self.blend_alpha
347
+
348
+ sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
349
+ samples.append(sample)
350
+
351
+ return sample
352
+
353
+ return samples
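
A minimal usage sketch for `SanaControlNetPipeline`, assuming the default config and checkpoint paths from `SanaControlNetInference`; the module path and the input sketch file are placeholders.

```python
# Illustrative sketch only (not part of the upload): keyword arguments follow
# SanaControlNetPipeline.forward; paths and the sketch image are assumptions.
import torch
from PIL import Image

from app.sana_controlnet_pipeline import SanaControlNetPipeline  # module path assumed

pipe = SanaControlNetPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("output/Sana_D20/SANA.pth")
pipe.set_blend_alpha(0.0)  # >0 blends the scribble map back into the output

sketch = Image.open("sketch.png").convert("RGB")  # placeholder input
sample = pipe(
    prompt="a cat",
    ref_image=sketch,
    num_inference_steps=20,
    guidance_scale=4.5,
    pag_guidance_scale=2.0,
    sketch_thickness=2,
    generator=torch.Generator().manual_seed(42),
)
```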
apps/sana_pipeline.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ import argparse
17
+ import warnings
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional, Tuple
20
+
21
+ import pyrallis
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ warnings.filterwarnings("ignore") # ignore warning
26
+
27
+
28
+ from diffusion import DPMS, FlowEuler
29
+ from diffusion.data.datasets.utils import (
30
+ ASPECT_RATIO_512_TEST,
31
+ ASPECT_RATIO_1024_TEST,
32
+ ASPECT_RATIO_2048_TEST,
33
+ ASPECT_RATIO_4096_TEST,
34
+ )
35
+ from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
36
+ from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
37
+ from diffusion.utils.config import SanaConfig, model_init_config
38
+ from diffusion.utils.logger import get_root_logger
39
+
40
+ # from diffusion.utils.misc import read_config
41
+ from tools.download import find_model
42
+
43
+
44
+ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
45
+ guidance_type = default_guidance_type
46
+ if not (pag_scale > 1.0 and attn_type == "linear"):
47
+ guidance_type = "classifier-free"
48
+ elif pag_scale > 1.0 and attn_type == "linear":
49
+ guidance_type = "classifier-free_PAG"
50
+ return guidance_type
51
+
52
+
53
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
54
+ """Returns binned height and width."""
55
+ ar = float(height / width)
56
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
57
+ default_hw = ratios[closest_ratio]
58
+ return int(default_hw[0]), int(default_hw[1])
59
+
60
+
61
+ @dataclass
62
+ class SanaInference(SanaConfig):
63
+ config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
64
+ model_path: str = field(
65
+ default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
66
+ )
67
+ output: str = "./output"
68
+ bs: int = 1
69
+ image_size: int = 1024
70
+ cfg_scale: float = 5.0
71
+ pag_scale: float = 2.0
72
+ seed: int = 42
73
+ step: int = -1
74
+ custom_image_size: Optional[int] = None
75
+ shield_model_path: str = field(
76
+ default="google/shieldgemma-2b",
77
+ metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
78
+ )
79
+
80
+
81
+ class SanaPipeline(nn.Module):
82
+ def __init__(
83
+ self,
84
+ config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
85
+ ):
86
+ super().__init__()
87
+ config = pyrallis.load(SanaInference, open(config))
88
+ self.args = self.config = config
89
+
90
+ # set some hyper-parameters
91
+ self.image_size = self.config.model.image_size
92
+
93
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
94
+ logger = get_root_logger()
95
+ self.logger = logger
96
+ self.progress_fn = lambda progress, desc: None
97
+
98
+ self.latent_size = self.image_size // config.vae.vae_downsample_rate
99
+ self.max_sequence_length = config.text_encoder.model_max_length
100
+ self.flow_shift = config.scheduler.flow_shift
101
+ guidance_type = "classifier-free_PAG"
102
+
103
+ weight_dtype = get_weight_dtype(config.model.mixed_precision)
104
+ self.weight_dtype = weight_dtype
105
+ self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
106
+
107
+ self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
108
+ self.vis_sampler = self.config.scheduler.vis_sampler
109
+ logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
110
+ self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
111
+ logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
112
+
113
+ # 1. build vae and text encoder
114
+ self.vae = self.build_vae(config.vae)
115
+ self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
116
+
117
+ # 2. build Sana model
118
+ self.model = self.build_sana_model(config).to(self.device)
119
+
120
+ # 3. pre-compute null embedding
121
+ with torch.no_grad():
122
+ null_caption_token = self.tokenizer(
123
+ "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
124
+ ).to(self.device)
125
+ self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
126
+ 0
127
+ ]
128
+
129
+ def build_vae(self, config):
130
+ vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
131
+ return vae
132
+
133
+ def build_text_encoder(self, config):
134
+ tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
135
+ return tokenizer, text_encoder
136
+
137
+ def build_sana_model(self, config):
138
+ # model setting
139
+ model_kwargs = model_init_config(config, latent_size=self.latent_size)
140
+ model = build_model(
141
+ config.model.model,
142
+ use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
143
+ **model_kwargs,
144
+ )
145
+ self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
146
+ self.logger.info(
147
+ f"{model.__class__.__name__}:{config.model.model},"
148
+ f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
149
+ )
150
+ return model
151
+
152
+ def from_pretrained(self, model_path):
153
+ state_dict = find_model(model_path)
154
+ state_dict = state_dict.get("state_dict", state_dict)
155
+ if "pos_embed" in state_dict:
156
+ del state_dict["pos_embed"]
157
+ missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
158
+ self.model.eval().to(self.weight_dtype)
159
+
160
+ self.logger.info("Generating sample from ckpt: %s" % model_path)
161
+ self.logger.warning(f"Missing keys: {missing}")
162
+ self.logger.warning(f"Unexpected keys: {unexpected}")
163
+
164
+ def register_progress_bar(self, progress_fn=None):
165
+ self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
166
+
167
+ @torch.inference_mode()
168
+ def forward(
169
+ self,
170
+ prompt=None,
171
+ height=1024,
172
+ width=1024,
173
+ negative_prompt="",
174
+ num_inference_steps=20,
175
+ guidance_scale=5,
176
+ pag_guidance_scale=2.5,
177
+ num_images_per_prompt=1,
178
+ generator=torch.Generator().manual_seed(42),
179
+ latents=None,
180
+ ):
181
+ self.ori_height, self.ori_width = height, width
182
+ self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
183
+ self.latent_size_h, self.latent_size_w = (
184
+ self.height // self.config.vae.vae_downsample_rate,
185
+ self.width // self.config.vae.vae_downsample_rate,
186
+ )
187
+ self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
188
+
189
+ # 1. pre-compute negative embedding
190
+ if negative_prompt != "":
191
+ null_caption_token = self.tokenizer(
192
+ negative_prompt,
193
+ max_length=self.max_sequence_length,
194
+ padding="max_length",
195
+ truncation=True,
196
+ return_tensors="pt",
197
+ ).to(self.device)
198
+ self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
199
+ 0
200
+ ]
201
+
202
+ if prompt is None:
203
+ prompt = [""]
204
+ prompts = prompt if isinstance(prompt, list) else [prompt]
205
+ samples = []
206
+
207
+ for prompt in prompts:
208
+ # data prepare
209
+ prompts, hw, ar = (
210
+ [],
211
+ torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
212
+ num_images_per_prompt, 1
213
+ ),
214
+ torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
215
+ )
216
+
217
+ for _ in range(num_images_per_prompt):
218
+ prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
219
+
220
+ with torch.no_grad():
221
+ # prepare text feature
222
+ if not self.config.text_encoder.chi_prompt:
223
+ max_length_all = self.config.text_encoder.model_max_length
224
+ prompts_all = prompts
225
+ else:
226
+ chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
227
+ prompts_all = [chi_prompt + prompt for prompt in prompts]
228
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
229
+ max_length_all = (
230
+ num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
231
+ ) # magic number 2: [bos], [_]
232
+
233
+ caption_token = self.tokenizer(
234
+ prompts_all,
235
+ max_length=max_length_all,
236
+ padding="max_length",
237
+ truncation=True,
238
+ return_tensors="pt",
239
+ ).to(device=self.device)
240
+ select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
241
+ caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
242
+ :, :, select_index
243
+ ].to(self.weight_dtype)
244
+ emb_masks = caption_token.attention_mask[:, select_index]
245
+ null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
246
+
247
+ n = len(prompts)
248
+ if latents is None:
249
+ z = torch.randn(
250
+ n,
251
+ self.config.vae.vae_latent_dim,
252
+ self.latent_size_h,
253
+ self.latent_size_w,
254
+ generator=generator,
255
+ device=self.device,
256
+ )
257
+ else:
258
+ z = latents.to(self.device)
259
+ model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
260
+ if self.vis_sampler == "flow_euler":
261
+ flow_solver = FlowEuler(
262
+ self.model,
263
+ condition=caption_embs,
264
+ uncondition=null_y,
265
+ cfg_scale=guidance_scale,
266
+ model_kwargs=model_kwargs,
267
+ )
268
+ sample = flow_solver.sample(
269
+ z,
270
+ steps=num_inference_steps,
271
+ )
272
+ elif self.vis_sampler == "flow_dpm-solver":
273
+ scheduler = DPMS(
274
+ self.model,
275
+ condition=caption_embs,
276
+ uncondition=null_y,
277
+ guidance_type=self.guidance_type,
278
+ cfg_scale=guidance_scale,
279
+ pag_scale=pag_guidance_scale,
280
+ pag_applied_layers=self.config.model.pag_applied_layers,
281
+ model_type="flow",
282
+ model_kwargs=model_kwargs,
283
+ schedule="FLOW",
284
+ )
285
+ scheduler.register_progress_bar(self.progress_fn)
286
+ sample = scheduler.sample(
287
+ z,
288
+ steps=num_inference_steps,
289
+ order=2,
290
+ skip_type="time_uniform_flow",
291
+ method="multistep",
292
+ flow_shift=self.flow_shift,
293
+ )
294
+
295
+ sample = sample.to(self.vae_dtype)
296
+ with torch.no_grad():
297
+ sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
298
+
299
+ sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
300
+ samples.append(sample)
301
+
302
+ return sample
303
+
304
+ return samples
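
And an equivalent sketch for the text-to-image `SanaPipeline`, following the constructor and `forward` signature above; paths are the `SanaInference` defaults, and the save step mirrors the demos' normalization of samples from [-1, 1].

```python
# Illustrative sketch only (not part of the upload): drives SanaPipeline the way
# app_sana_multithread.py does; config/checkpoint paths are default assumptions.
import torch
from torchvision.utils import save_image

from app.sana_pipeline import SanaPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pipe = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("output/Sana_D20/SANA.pth")

images = pipe(
    prompt="a cyberpunk cat with a neon sign that says 'Sana'",
    height=1024,
    width=1024,
    guidance_scale=5.0,
    pag_guidance_scale=2.0,
    num_inference_steps=18,
    generator=torch.Generator(device=device).manual_seed(42),
)
# Samples are in [-1, 1]; normalize on save as the demos do.
save_image(images, "sana_sample.png", nrow=1, normalize=True, value_range=(-1, 1))
```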