PommesPeter committed on
Commit a935b35
1 Parent(s): 7690a4d

Update app.py

Files changed (1)
  1. app.py +198 -321
app.py CHANGED
@@ -13,44 +13,29 @@ snapshot_download(
13
  repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints"
14
  )
15
 
 
 
16
  import argparse
17
  import builtins
18
  import json
19
  import math
 
 
20
  import random
21
  import socket
22
-
23
- import spaces
24
  import traceback
25
 
26
- import fairscale.nn.model_parallel.initialize as fs_init
 
27
  import gradio as gr
28
  import numpy as np
29
-
30
  import torch
31
  import torch.distributed as dist
32
  from torchvision.transforms.functional import to_pil_image
33
 
34
- from PIL import Image
35
- from safetensors.torch import load_file
36
-
37
  import models
38
- from transport import create_transport, Sampler
39
-
40
- print(f"Is CUDA available: {torch.cuda.is_available()}")
41
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
42
-
43
- description = """
44
- # Lumina Next Text-to-Image
45
-
46
- #### Lumina-Next-T2I is a 2B `Next-DiT` model with `Gemma-2B` text encoder.
47
-
48
- #### Demo current model: `Lumina-Next-T2I`
49
-
50
- #### Lumina-Next supports higher-order solvers. <span style='color: orange;'>It can generate images with merely 10 steps without any distillation.
51
-
52
- """
53
- hf_token = os.environ["HF_TOKEN"]
54
 
55
 
56
  class ModelFailure:
@@ -58,10 +43,7 @@ class ModelFailure:
58
 
59
 
60
  # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
61
- def encode_prompt(
62
- prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True
63
- ):
64
-
65
  captions = []
66
  for caption in prompt_batch:
67
  if random.random() < proportion_empty_prompts:
@@ -94,10 +76,11 @@ def encode_prompt(
94
  return prompt_embeds, prompt_masks
95
 
96
 
 
97
  def load_models(args, master_port, rank):
98
  # import here to avoid huggingface Tokenizer parallelism warnings
99
  from diffusers.models import AutoencoderKL
100
- from transformers import AutoModelForCausalLM, AutoTokenizer
101
 
102
  # override the default print function since the delay can be large for child process
103
  original_print = builtins.print
@@ -111,127 +94,96 @@ def load_models(args, master_port, rank):
111
  builtins.print = print
112
 
113
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
114
- print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
115
-
116
- dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
117
- args.precision
118
- ]
119
  device = "cuda" if torch.cuda.is_available() else "cpu"
120
 
 
 
121
  print(f"Creating lm: Gemma-2B")
122
- text_encoder = (
123
- AutoModelForCausalLM.from_pretrained(
124
- "google/gemma-2b",
125
- torch_dtype=dtype,
126
- device_map=device,
127
- # device_map="cuda",
128
- token=hf_token,
129
- )
130
- .get_decoder()
131
- .eval()
132
- )
133
  cap_feat_dim = text_encoder.config.hidden_size
134
- if args.num_gpus > 1:
135
- raise NotImplementedError("Inference with >1 GPUs not yet supported")
136
-
137
- tokenizer = AutoTokenizer.from_pretrained(
138
- "google/gemma-2b",
139
- add_bos_token=True,
140
- add_eos_token=True,
141
- token=hf_token,
142
- )
143
  tokenizer.padding_side = "right"
144
 
145
- print(f"Creating vae: sdxl-vae")
 
146
  vae = AutoencoderKL.from_pretrained(
147
- "stabilityai/sdxl-vae",
148
  torch_dtype=torch.float32,
149
- ).to(device)
150
 
151
- print(f"Creating DiT: Next-DiT")
152
  # latent_size = train_args.image_size // 8
153
- model = models.__dict__["NextDiT_2B_GQA_patch2"](
154
  qk_norm=train_args.qk_norm,
155
  cap_feat_dim=cap_feat_dim,
156
  )
157
- # model.eval().to("cuda", dtype=dtype)
158
  model.eval().to(device, dtype=dtype)
159
 
160
- assert train_args.model_parallel_size == args.num_gpus
161
  if args.ema:
162
  print("Loading ema model.")
163
  ckpt = load_file(
164
  os.path.join(
165
  args.ckpt,
166
  f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
167
- ),
168
  )
169
  model.load_state_dict(ckpt, strict=True)
170
-
171
  return text_encoder, tokenizer, vae, model
172
 
173
-
174
  @torch.no_grad()
175
  def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
176
  dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
177
  args.precision
178
  ]
179
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
180
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
181
  torch.cuda.set_device(0)
182
-
183
- # loading model to gpu
184
- # text_encoder = text_encoder.cuda()
185
- # vae = vae.cuda()
186
- # model = model.to("cuda", dtype=dtype)
187
-
188
  with torch.autocast("cuda", dtype):
189
- (
190
- cap,
191
- neg_cap,
192
- resolution,
193
- num_sampling_steps,
194
- cfg_scale,
195
- solver,
196
- t_shift,
197
- seed,
198
- scaling_method,
199
- proportional_attn,
200
- ) = infer_args
201
-
202
- metadata = dict(
203
- cap=cap,
204
- neg_cap=neg_cap,
205
- resolution=resolution,
206
- num_sampling_steps=num_sampling_steps,
207
- cfg_scale=cfg_scale,
208
- solver=solver,
209
- t_shift=t_shift,
210
- seed=seed,
211
- scaling_method=scaling_method,
212
- proportional_attn=proportional_attn,
213
- )
214
- print("> params:", json.dumps(metadata, indent=2))
215
-
216
- try:
217
- # begin sampler
218
- transport = create_transport(
219
- args.path_type,
220
- args.prediction,
221
- args.loss_weight,
222
- args.train_eps,
223
- args.sample_eps,
224
  )
225
- sampler = Sampler(transport)
226
- if args.likelihood:
227
- # assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" # todo
228
- sample_fn = sampler.sample_ode_likelihood(
229
- sampling_method=solver,
230
- num_steps=num_sampling_steps,
231
- atol=args.atol,
232
- rtol=args.rtol,
 
 
233
  )
234
- else:
235
  sample_fn = sampler.sample_ode(
236
  sampling_method=solver,
237
  num_steps=num_sampling_steps,
@@ -240,70 +192,67 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
240
  reverse=args.reverse,
241
  time_shifting_factor=t_shift,
242
  )
243
- # end sampler
244
-
245
- do_extrapolation = "Extrapolation" in resolution
246
- resolution = resolution.split(" ")[-1]
247
- w, h = resolution.split("x")
248
- w, h = int(w), int(h)
249
-
250
-
251
- latent_w, latent_h = w // 8, h // 8
252
- if int(seed) != 0:
253
- torch.random.manual_seed(int(seed))
254
- z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
255
- z = z.repeat(2, 1, 1, 1)
256
-
257
- with torch.no_grad():
258
- if neg_cap != "":
259
- cap_feats, cap_mask = encode_prompt(
260
- [cap] + [neg_cap],
261
- text_encoder,
262
- tokenizer,
263
- 0.0,
264
- )
 
 
 
265
  else:
266
- cap_feats, cap_mask = encode_prompt(
267
- [cap] + [""],
268
- text_encoder,
269
- tokenizer,
270
- 0.0,
271
- )
272
- cap_mask = cap_mask.to(cap_feats.device)
273
-
274
- model_kwargs = dict(
275
- cap_feats=cap_feats,
276
- cap_mask=cap_mask,
277
- cfg_scale=cfg_scale,
278
- )
279
-
280
- if proportional_attn:
281
- model_kwargs["proportional_attn"] = True
282
- model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
283
- if do_extrapolation and scaling_method == "Time-aware":
284
- model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size ** 2)
285
- else:
286
- model_kwargs["scale_factor"] = 1.0
287
 
288
- print(f"> scale factor: {model_kwargs['scale_factor']}")
 
 
 
 
 
289
 
290
- print("> start sample")
291
- samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
292
- samples = samples[:1]
 
293
 
294
- factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
295
- print(f"vae factor: {factor}")
 
296
 
297
- samples = vae.decode(samples / factor).sample
298
- samples = (samples + 1.0) / 2.0
299
- samples.clamp_(0.0, 1.0)
 
 
300
 
301
- img = to_pil_image(samples[0].float())
 
302
 
303
- return img, metadata
304
- except Exception:
305
- print(traceback.format_exc())
306
- return ModelFailure()
307
 
308
 
309
  def none_or_str(value):
@@ -335,12 +284,8 @@ def parse_transport_args(parser):
335
  choices=[None, "velocity", "likelihood"],
336
  help="the weighting of different components in the loss function, can be 'velocity' for dynamic modeling, 'likelihood' for statistical consistency, or None for no weighting.",
337
  )
338
- group.add_argument(
339
- "--sample-eps", type=float, help="sampling in the transport model."
340
- )
341
- group.add_argument(
342
- "--train-eps", type=float, help="training to stabilize the learning process."
343
- )
344
 
345
 
346
  def parse_ode_args(parser):
@@ -357,9 +302,7 @@ def parse_ode_args(parser):
357
  default=1e-3,
358
  help="Relative tolerance for the ODE solver.",
359
  )
360
- group.add_argument(
361
- "--reverse", action="store_true", help="run the ODE solver in reverse."
362
- )
363
  group.add_argument(
364
  "--likelihood",
365
  action="store_true",
@@ -367,47 +310,6 @@ def parse_ode_args(parser):
367
  )
368
 
369
 
370
- def parse_sde_args(parser):
371
- group = parser.add_argument_group("SDE arguments")
372
- group.add_argument(
373
- "--sampling-method",
374
- type=str,
375
- default="Euler",
376
- choices=["Euler", "Heun"],
377
- help="the numerical method used for sampling the stochastic differential equation: 'Euler' for simplicity or 'Heun' for improved accuracy.",
378
- )
379
- group.add_argument(
380
- "--diffusion-form",
381
- type=str,
382
- default="sigma",
383
- choices=[
384
- "constant",
385
- "SBDM",
386
- "sigma",
387
- "linear",
388
- "decreasing",
389
- "increasing-decreasing",
390
- ],
391
- help="form of diffusion coefficient in the SDE",
392
- )
393
- group.add_argument(
394
- "--diffusion-norm",
395
- type=float,
396
- default=1.0,
397
- help="Normalizes the diffusion coefficient, affecting the scale of the stochastic component.",
398
- )
399
- group.add_argument(
400
- "--last-step",
401
- type=none_or_str,
402
- default="Mean",
403
- choices=[None, "Mean", "Tweedie", "Euler"],
404
- help="form of last step taken in the SDE",
405
- )
406
- group.add_argument(
407
- "--last-step-size", type=float, default=0.04, help="size of the last step taken"
408
- )
409
-
410
-
411
  def find_free_port() -> int:
412
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
413
  sock.bind(("", 0))
@@ -418,7 +320,6 @@ def find_free_port() -> int:
418
 
419
  def main():
420
  parser = argparse.ArgumentParser()
421
- mode = "ODE"
422
 
423
  parser.add_argument("--num_gpus", type=int, default=1)
424
  parser.add_argument("--ckpt", type=str, default="/home/user/app/checkpoints")
@@ -427,15 +328,32 @@ def main():
427
 
428
  parse_transport_args(parser)
429
  parse_ode_args(parser)
 
430
  args = parser.parse_known_args()[0]
 
431
 
432
  if args.num_gpus != 1:
433
  raise NotImplementedError("Multi-GPU Inference is not yet supported")
434
 
435
- args.sampler_mode = mode
436
-
437
  text_encoder, tokenizer, vae, model = load_models(args, 60001, 0)
438
 
 
 
439
  with gr.Blocks() as demo:
440
  with gr.Row():
441
  gr.Markdown(description)
@@ -457,22 +375,17 @@ def main():
457
  placeholder="Enter a negative caption.",
458
  )
459
  with gr.Row():
460
- res_choices = ["1024x1024", "512x2048", "2048x512"] + [
461
- "(Extrapolation) 2048x1920",
462
- "(Extrapolation) 1920x2048",
463
- "(Extrapolation) 1664x1664",
464
- "(Extrapolation) 1536x2560",
465
- "(Extrapolation) 2048x1024",
466
- "(Extrapolation) 1024x2048",
467
  ]
468
- resolution = gr.Dropdown(
469
- value=res_choices[0], choices=res_choices, label="Resolution"
470
- )
471
  with gr.Row():
472
  num_sampling_steps = gr.Slider(
473
  minimum=1,
474
  maximum=70,
475
- value=10,
476
  step=1,
477
  interactive=True,
478
  label="Sampling steps",
@@ -480,41 +393,48 @@ def main():
480
  seed = gr.Slider(
481
  minimum=0,
482
  maximum=int(1e5),
483
- value=1,
484
  step=1,
485
  interactive=True,
486
  label="Seed (0 for random)",
487
  )
488
- with gr.Accordion(
489
- "Advanced Settings for Resolution Extrapolation", open=False
490
- ):
 
 
 
491
  with gr.Row():
492
- solver = gr.Dropdown(
493
- value="midpoint",
494
- choices=["euler", "midpoint", "rk4"],
495
- label="solver",
496
- )
497
- t_shift = gr.Slider(
498
- minimum=1,
499
- maximum=20,
500
- value=6,
501
- step=1,
502
- interactive=True,
503
- label="Time shift",
504
  )
505
- cfg_scale = gr.Slider(
506
- minimum=1.0,
507
- maximum=20.0,
508
- value=4.0,
509
  interactive=True,
510
- label="CFG scale",
 
511
  )
512
  with gr.Row():
513
- scaling_method = gr.Dropdown(
514
- value="Time-aware",
515
- choices=["Time-aware", "None"],
516
- label="Rope scaling method",
517
- )
518
  proportional_attn = gr.Checkbox(
519
  value=True,
520
  interactive=True,
@@ -523,11 +443,12 @@ def main():
523
  with gr.Row():
524
  submit_btn = gr.Button("Submit", variant="primary")
525
  with gr.Column():
 
526
  output_img = gr.Image(
527
- label="Lumina Generated image",
528
  interactive=False,
529
  format="png",
530
- show_label=False
531
  )
532
  with gr.Accordion(label="Generation Parameters", open=True):
533
  gr_metadata = gr.JSON(label="metadata", show_label=False)
@@ -535,65 +456,15 @@ def main():
535
  with gr.Row():
536
  gr.Examples(
537
  [
538
- ["👽🤖👹👻"],
539
- ["🐔 playing 🏀"],
540
- ["☃️ with 🌹 in the ❄️"],
541
- ["🐶 wearing 😎 flying on 🌈 "],
542
- ["A small 🍎 and 🍊 with 😁 emoji in the Sahara desert"],
543
- ["Astronaut on Mars During sunset"],
544
- [
545
- "A scared cute rabbit in Happy Tree Friends style and punk vibe."
546
- ],
547
- ["A humanoid eagle soldier of the First World War."], # noqa
548
- [
549
- "A cute Christmas mockup on an old wooden industrial desk table with Christmas decorations and bokeh lights in the background."
550
- ],
551
- [
552
- "A front view of a romantic flower shop in France filled with various blooming flowers including lavenders and roses."
553
- ],
554
- [
555
- "An old man, portrayed as a retro superhero, stands in the streets of New York City at night"
556
- ],
557
- [
558
- "many trees are surrounded by a lake in autumn colors, in the style of nature-inspired imagery, havencore, brightly colored, dark white and dark orange, bright primary colors, environmental activism, forestpunk"
559
- ],
560
- [
561
- "A fluffy mouse holding a watermelon, in a magical and colorful setting, illustrated in the style of Hayao Miyazaki anime by Studio Ghibli."
562
- ],
563
- ["孤舟蓑笠翁"],
564
- ["两只黄鹂鸣翠柳"],
565
- ["大漠孤烟直,长河落日圆"],
566
- ["秋风起兮白云飞,草木黄落兮雁南归"],
567
- ["味噌ラーメン, 最高品質の浮世絵、江戸時代。"],
568
- ["東京タワー、最高品質の浮世絵、江戸時代。"],
569
- ["도쿄 타워, 최고 품질의 우키요에, 에도 시대"],
570
- [
571
- "Tour de Tokyo, estampes ukiyo-e de la plus haute qualité, période Edo"
572
- ],
573
- ["Токийская башня, лучшие укиё-э, период Эдо"],
574
- ["Tokio-Turm, hochwertigste Ukiyo-e, Edo-Zeit"],
575
- [
576
- "Inka warrior with a war make up, medium shot, natural light, Award winning wildlife photography, hyperrealistic, 8k resolution"
577
- ],
578
- [
579
- "Character of lion in style of saiyan, mafia, gangsta, citylights background, Hyper detailed, hyper realistic, unreal engine ue5, cgi 3d, cinematic shot, 8k"
580
- ],
581
- [
582
- "In the sky above, a giant, whimsical cloud shaped like the 😊 emoji casts a soft, golden light over the scene"
583
- ],
584
- [
585
- "Cyberpunk eagle, neon ambiance, abstract black oil, gear mecha, detailed acrylic, grunge, intricate complexity, rendered in unreal engine 5, photorealistic, 8k"
586
- ],
587
- [
588
- "close-up photo of a beautiful red rose breaking through a cube made of ice , splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light"
589
- ],
590
- [
591
- "3D cartoon Fox Head with Human Body, Wearing Iridescent Holographic Liquid Texture & Translucent Material Sun Protective Shirt, Boss Feel, Nike or Addidas Sun Protective Shirt, WitchPunk, Y2K Style, Green and blue, Blue, Metallic Feel, Strong Reflection, plain background, no background, pure single color background, Digital Fashion, Surreal Futurism, Supreme Kong NFT Artwork Style, disney style, headshot photography for portrait studio shoot, fashion editorial aesthetic, high resolution in the style of HAPE PRIME NFT, NFT 3D IP Feel, Bored Ape Yacht Club NFT project Feel, high detail, fine luster, 3D render, oc render, best quality, 8K, bright, front lighting, Face Shot, fine luster, ultra detailed"
592
- ],
593
  ],
594
  [cap],
595
  label="Examples",
596
- examples_per_page=22,
597
  )
598
 
599
  @spaces.GPU(duration=200)
@@ -601,7 +472,7 @@ def main():
601
  result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
602
  if isinstance(result, ModelFailure):
603
  raise RuntimeError("Model failed to generate the image.")
604
- return result
605
 
606
  submit_btn.click(
607
  on_submit,
@@ -615,12 +486,18 @@ def main():
615
  t_shift,
616
  seed,
617
  scaling_method,
 
618
  proportional_attn,
619
  ],
620
  [output_img, gr_metadata],
621
  )
622
 
623
- demo.queue().launch()
 
 
624
 
625
 
626
  if __name__ == "__main__":
 
13
  repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints"
14
  )
15
 
16
+ hf_token = os.environ["HF_TOKEN"]
17
+
18
  import argparse
19
  import builtins
20
  import json
21
  import math
22
+ import multiprocessing as mp
23
+ import os
24
  import random
25
  import socket
 
 
26
  import traceback
27
 
28
+ from PIL import Image
29
+ import spaces
30
  import gradio as gr
31
  import numpy as np
32
+ from safetensors.torch import load_file
33
  import torch
34
  import torch.distributed as dist
35
  from torchvision.transforms.functional import to_pil_image
36
 
 
 
 
37
  import models
38
+ from transport import Sampler, create_transport
 
 
 
39
 
40
 
41
  class ModelFailure:
 
43
 
44
 
45
  # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
46
+ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
 
 
 
47
  captions = []
48
  for caption in prompt_batch:
49
  if random.random() < proportion_empty_prompts:
 
76
  return prompt_embeds, prompt_masks
77
 
78
 
79
+ @torch.no_grad()
80
  def load_models(args, master_port, rank):
81
  # import here to avoid huggingface Tokenizer parallelism warnings
82
  from diffusers.models import AutoencoderKL
83
+ from transformers import AutoModel, AutoTokenizer
84
 
85
  # override the default print function since the delay can be large for child process
86
  original_print = builtins.print
 
94
  builtins.print = print
95
 
96
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
97
+ dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]
 
 
 
 
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
 
100
+ print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))
101
+
102
  print(f"Creating lm: Gemma-2B")
103
+ text_encoder = AutoModel.from_pretrained(
104
+ "google/gemma-2b", torch_dtype=dtype, device_map=device, token=hf_token
105
+ ).eval()
 
 
106
  cap_feat_dim = text_encoder.config.hidden_size
107
+
108
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=hf_token)
 
 
 
 
109
  tokenizer.padding_side = "right"
110
 
111
+
112
+ print(f"Creating vae: {train_args.vae}")
113
  vae = AutoencoderKL.from_pretrained(
114
+ (f"stabilityai/sd-vae-ft-{train_args.vae}" if train_args.vae != "sdxl" else "stabilityai/sdxl-vae"),
115
  torch_dtype=torch.float32,
116
+ ).cuda()
117
 
118
+ print(f"Creating Next-DiT: {train_args.model}")
119
  # latent_size = train_args.image_size // 8
120
+ model = models.__dict__[train_args.model](
121
  qk_norm=train_args.qk_norm,
122
  cap_feat_dim=cap_feat_dim,
123
  )
 
124
  model.eval().to(device, dtype=dtype)
125
 
 
126
  if args.ema:
127
  print("Loading ema model.")
128
  ckpt = load_file(
129
  os.path.join(
130
  args.ckpt,
131
  f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
132
+ )
133
  )
134
  model.load_state_dict(ckpt, strict=True)
135
+
136
  return text_encoder, tokenizer, vae, model
137
 
 
138
  @torch.no_grad()
139
  def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
140
  dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
141
  args.precision
142
  ]
143
  train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
 
144
  torch.cuda.set_device(0)
145
+
 
146
  with torch.autocast("cuda", dtype):
147
+ while True:
148
+ (
149
+ cap,
150
+ neg_cap,
151
+ resolution,
152
+ num_sampling_steps,
153
+ cfg_scale,
154
+ solver,
155
+ t_shift,
156
+ seed,
157
+ scaling_method,
158
+ scaling_watershed,
159
+ proportional_attn,
160
+ ) = infer_args
161
+
162
+ metadata = dict(
163
+ cap=cap,
164
+ neg_cap=neg_cap,
165
+ resolution=resolution,
166
+ num_sampling_steps=num_sampling_steps,
167
+ cfg_scale=cfg_scale,
168
+ solver=solver,
169
+ t_shift=t_shift,
170
+ seed=seed,
171
+ scaling_method=scaling_method,
172
+ scaling_watershed=scaling_watershed,
173
+ proportional_attn=proportional_attn,
 
 
 
174
  )
175
+ print("> params:", json.dumps(metadata, indent=2))
176
+
177
+ try:
178
+ # begin sampler
179
+ transport = create_transport(
180
+ args.path_type,
181
+ args.prediction,
182
+ args.loss_weight,
183
+ args.train_eps,
184
+ args.sample_eps,
185
  )
186
+ sampler = Sampler(transport)
187
  sample_fn = sampler.sample_ode(
188
  sampling_method=solver,
189
  num_steps=num_sampling_steps,
 
192
  reverse=args.reverse,
193
  time_shifting_factor=t_shift,
194
  )
195
+ # end sampler
196
+
197
+ do_extrapolation = "Extrapolation" in resolution
198
+ resolution = resolution.split(" ")[-1]
199
+ w, h = resolution.split("x")
200
+ w, h = int(w), int(h)
201
+ latent_w, latent_h = w // 8, h // 8
202
+ if int(seed) != 0:
203
+ torch.random.manual_seed(int(seed))
204
+ z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
205
+ z = z.repeat(2, 1, 1, 1)
206
+
207
+ with torch.no_grad():
208
+ if neg_cap != "":
209
+ cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
210
+ else:
211
+ cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
212
+
213
+ cap_mask = cap_mask.to(cap_feats.device)
214
+
215
+ model_kwargs = dict(
216
+ cap_feats=cap_feats,
217
+ cap_mask=cap_mask,
218
+ cfg_scale=cfg_scale,
219
+ )
220
+ if proportional_attn:
221
+ model_kwargs["proportional_attn"] = True
222
+ model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
223
  else:
224
+ model_kwargs["proportional_attn"] = False
225
+ model_kwargs["base_seqlen"] = None
 
 
226
 
227
+ if do_extrapolation and scaling_method == "Time-aware":
228
+ model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size**2)
229
+ model_kwargs["scale_watershed"] = scaling_watershed
230
+ else:
231
+ model_kwargs["scale_factor"] = 1.0
232
+ model_kwargs["scale_watershed"] = 1.0
233
 
234
+ if dist.get_rank() == 0:
235
+ print(f"> caption: {cap}")
236
+ print(f"> num_sampling_steps: {num_sampling_steps}")
237
+ print(f"> cfg_scale: {cfg_scale}")
238
 
239
+ print("> start sample")
240
+ samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
241
+ samples = samples[:1]
242
 
243
+ factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
244
+ print(f"> vae factor: {factor}")
245
+ samples = vae.decode(samples / factor).sample
246
+ samples = (samples + 1.0) / 2.0
247
+ samples.clamp_(0.0, 1.0)
248
 
249
+ img = to_pil_image(samples[0].float())
250
+ print("> generated image, done.")
251
 
252
+ return img, metadata
253
+ except Exception:
254
+ print(traceback.format_exc())
255
+ return ModelFailure()
256
 
257
 
258
  def none_or_str(value):
 
284
  choices=[None, "velocity", "likelihood"],
285
  help="the weighting of different components in the loss function, can be 'velocity' for dynamic modeling, 'likelihood' for statistical consistency, or None for no weighting.",
286
  )
287
+ group.add_argument("--sample-eps", type=float, help="sampling in the transport model.")
288
+ group.add_argument("--train-eps", type=float, help="training to stabilize the learning process.")
 
 
 
 
289
 
290
 
291
  def parse_ode_args(parser):
 
302
  default=1e-3,
303
  help="Relative tolerance for the ODE solver.",
304
  )
305
+ group.add_argument("--reverse", action="store_true", help="run the ODE solver in reverse.")
 
 
306
  group.add_argument(
307
  "--likelihood",
308
  action="store_true",
 
310
  )
311
 
312
 
 
 
313
  def find_free_port() -> int:
314
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
315
  sock.bind(("", 0))
 
320
 
321
  def main():
322
  parser = argparse.ArgumentParser()
 
323
 
324
  parser.add_argument("--num_gpus", type=int, default=1)
325
  parser.add_argument("--ckpt", type=str, default="/home/user/app/checkpoints")
 
328
 
329
  parse_transport_args(parser)
330
  parse_ode_args(parser)
331
+
332
  args = parser.parse_known_args()[0]
333
+ args.sampler_mode = "ODE"
334
 
335
  if args.num_gpus != 1:
336
  raise NotImplementedError("Multi-GPU Inference is not yet supported")
337
 
 
 
338
  text_encoder, tokenizer, vae, model = load_models(args, 60001, 0)
339
 
340
+ description = """
341
+ # Lumina Next Text-to-Image
342
+
343
+ Lumina-Next-T2I is a 2B Next-DiT model with 2B text encoder.
344
+
345
+ Demo current model: `Lumina-Next-T2I 1k Resolution`
346
+
347
+ ### <span style='color: red;'> Lumina-Next-T2I enables zero-shot resolution extrapolation to 2k.
348
+
349
+ ### Lumina-Next supports higher-order solvers ["euler", "midpoint"].
350
+ ### <span style='color: orange;'>It can generate images with merely 10 steps without any distillation for 1K resolution generation.
351
+
352
+ ### To reduce waiting times, we are offering three parallel demos:
353
+
354
+ Lumina-T2I 2B model: [[demo (supported 2k inference)](http://106.14.2.150:10020/)] [[demo](http://106.14.2.150:10021/)] [[demo](http://106.14.2.150:10022/)] [[demo (compositional generation)](http://106.14.2.150:10023/)]
355
+
356
+ """
357
  with gr.Blocks() as demo:
358
  with gr.Row():
359
  gr.Markdown(description)
 
375
  placeholder="Enter a negative caption.",
376
  )
377
  with gr.Row():
378
+ res_choices = [
379
+ "1024x1024",
380
+ "512x2048",
381
+ "2048x512",
 
 
 
382
  ]
383
+ resolution = gr.Dropdown(value=res_choices[0], choices=res_choices, label="Resolution")
 
 
384
  with gr.Row():
385
  num_sampling_steps = gr.Slider(
386
  minimum=1,
387
  maximum=70,
388
+ value=30,
389
  step=1,
390
  interactive=True,
391
  label="Sampling steps",
 
393
  seed = gr.Slider(
394
  minimum=0,
395
  maximum=int(1e5),
396
+ value=25,
397
  step=1,
398
  interactive=True,
399
  label="Seed (0 for random)",
400
  )
401
+ with gr.Row():
402
+ solver = gr.Dropdown(
403
+ value="midpoint",
404
+ choices=["euler", "midpoint"],
405
+ label="Solver",
406
+ )
407
+ t_shift = gr.Slider(
408
+ minimum=1,
409
+ maximum=20,
410
+ value=6,
411
+ step=1,
412
+ interactive=True,
413
+ label="Time shift",
414
+ )
415
+ cfg_scale = gr.Slider(
416
+ minimum=1.0,
417
+ maximum=20.0,
418
+ value=4.0,
419
+ interactive=True,
420
+ label="CFG scale",
421
+ )
422
+ with gr.Accordion("Advanced Settings for Resolution Extrapolation", open=False, visible=False):
423
  with gr.Row():
424
+ scaling_method = gr.Dropdown(
425
+ value="None",
426
+ choices=["None"],
427
+ label="RoPE scaling method",
 
 
428
  )
429
+ scaling_watershed = gr.Slider(
430
+ minimum=0.0,
431
+ maximum=1.0,
432
+ value=0.3,
433
  interactive=True,
434
+ label="Linear/NTK watershed",
435
+ visible=False,
436
  )
437
  with gr.Row():
 
 
 
 
 
438
  proportional_attn = gr.Checkbox(
439
  value=True,
440
  interactive=True,
 
443
  with gr.Row():
444
  submit_btn = gr.Button("Submit", variant="primary")
445
  with gr.Column():
446
+ default_img = Image.open("./image.png")
447
  output_img = gr.Image(
448
+ label="Generated image",
449
  interactive=False,
450
  format="png",
451
+ value=default_img,
452
  )
453
  with gr.Accordion(label="Generation Parameters", open=True):
454
  gr_metadata = gr.JSON(label="metadata", show_label=False)
 
456
  with gr.Row():
457
  gr.Examples(
458
  [
459
+ ["An old sailor, weathered by years at sea, stands at the helm of his ship, eyes scanning the horizon for signs of land, his face lined with tales of adventure and hardship."], # noqa
460
+ ["A regal swan glides gracefully across the surface of a tranquil lake, its snowy white feathers ruffled by the gentle breeze."], # noqa
461
+ ["A cunning fox, agilely weaving through the forest, its eyes sharp and alert, always ready for prey."], # noqa
462
+ ["Inka warrior with a war make up, medium shot, natural light, Award winning wildlife photography, hyperrealistic, 8k resolution."], # noqa
463
+ ["Quaint rustic witch's cabin by the lake, autumn forest background, orange and honey colors, beautiful composition, magical, warm glowing lighting, cloudy, dreamy masterpiece, Nikon D610, photorealism, highly artistic, highly detailed, ultra high resolution, sharp focus, Mysterious."], # noqa
 
 
 
464
  ],
465
  [cap],
466
  label="Examples",
467
+ examples_per_page=80,
468
  )
469
 
470
  @spaces.GPU(duration=200)
 
472
  result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
473
  if isinstance(result, ModelFailure):
474
  raise RuntimeError("Model failed to generate the image.")
475
+ return result
476
 
477
  submit_btn.click(
478
  on_submit,
 
486
  t_shift,
487
  seed,
488
  scaling_method,
489
+ scaling_watershed,
490
  proportional_attn,
491
  ],
492
  [output_img, gr_metadata],
493
  )
494
 
495
+ def show_scaling_watershed(scaling_m):
496
+ return gr.update(visible=scaling_m == "Time-aware")
497
+
498
+ scaling_method.change(show_scaling_watershed, scaling_method, scaling_watershed)
499
+
500
+ demo.queue().launch(server_name="0.0.0.0")
501
 
502
 
503
  if __name__ == "__main__":