PommesPeter committed
Commit
005776b
1 Parent(s): 3d04fb6

Update app.py

Files changed (1)
  1. app.py +30 -43
app.py CHANGED
```diff
@@ -7,9 +7,8 @@ subprocess.run(
     shell=True,
 )
 
-from huggingface_hub import snapshot_download
-
 os.makedirs("/home/user/app/checkpoints", exist_ok=True)
+from huggingface_hub import snapshot_download
 snapshot_download(
     repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints"
 )
```
```diff
@@ -32,8 +31,7 @@ import torch.distributed as dist
 from torchvision.transforms.functional import to_pil_image
 
 from PIL import Image
-from queue import Queue
-from threading import Thread, Barrier
+from safetensors.torch import load_file
 
 import models
 
```
```diff
@@ -50,7 +48,6 @@ description = """
 #### Demo current model: `Lumina-Next-T2I`
 
 """
-
 hf_token = os.environ["HF_TOKEN"]
 
 
```
```diff
@@ -161,12 +158,11 @@ def load_models(args, master_port, rank):
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
         print("Loading ema model.")
-        ckpt = torch.load(
+        ckpt = load_file(
             os.path.join(
                 args.ckpt,
-                f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.pth",
+                f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.safetensors",
             ),
-            map_location="cpu",
         )
         model.load_state_dict(ckpt, strict=True)
 
```
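For context on the checkpoint change above: `safetensors.torch.load_file` returns a plain `{name: tensor}` dict loaded on CPU by default, which is why the `map_location="cpu"` argument could be dropped. A minimal round-trip sketch (file name hypothetical):

```python
import torch
from safetensors.torch import load_file, save_file

model = torch.nn.Linear(4, 4)
save_file(model.state_dict(), "linear.safetensors")  # write a tiny checkpoint
state_dict = load_file("linear.safetensors")         # {name: tensor} dict, on CPU
model.load_state_dict(state_dict, strict=True)       # no map_location needed
```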
```diff
@@ -179,17 +175,15 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-
-    print(args)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     torch.cuda.set_device(0)
-
+
     # loading model to gpu
     # text_encoder = text_encoder.cuda()
     # vae = vae.cuda()
     # model = model.to("cuda", dtype=dtype)
 
-    with torch.autocast("cuda", dtype):
+    with torch.autocast(device, dtype):
         (
             cap,
             resolution,
```
```diff
@@ -202,18 +196,19 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
             proportional_attn,
         ) = infer_args
 
-        print(
-            "> params:",
-            cap,
-            resolution,
-            num_sampling_steps,
-            cfg_scale,
-            solver,
-            t_shift,
-            seed,
-            ntk_scaling,
-            proportional_attn,
+        metadata = dict(
+            cap=cap,
+            resolution=resolution,
+            num_sampling_steps=num_sampling_steps,
+            cfg_scale=cfg_scale,
+            solver=solver,
+            t_shift=t_shift,
+            seed=seed,
+            ntk_scaling=ntk_scaling,
+            proportional_attn=proportional_attn,
         )
+        print("> params:", json.dumps(metadata, indent=2))
+
         try:
             # begin sampler
             transport = create_transport(
```
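The parameter dump now goes through a single `metadata` dict, printed here and returned to the UI later in the diff. The `json.dumps` call assumes `json` is imported elsewhere in app.py (not shown in this hunk); a minimal sketch with hypothetical values:

```python
import json

metadata = dict(cap="a photo of a cat", num_sampling_steps=30, cfg_scale=4.0, seed=42)
print("> params:", json.dumps(metadata, indent=2))  # pretty-printed, one key per line
```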
```diff
@@ -249,7 +244,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         latent_w, latent_h = w // 8, h // 8
         if int(seed) != 0:
             torch.random.manual_seed(int(seed))
-        z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
+        z = torch.randn([1, 4, latent_h, latent_w], device=device).to(dtype)
         z = z.repeat(2, 1, 1, 1)
 
         with torch.no_grad():
```
```diff
@@ -276,13 +271,8 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
             ntk_factor=ntk_factor,
         )
 
-        print(f"caption: {cap}")
-        print(f"num_sampling_steps: {num_sampling_steps}")
-        print(f"cfg_scale: {cfg_scale}")
-
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            print("> [debug] start sample")
-            samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
+        print("> start sample")
+        samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
         samples = samples[:1]
 
         factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
```
```diff
@@ -294,7 +284,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
 
         img = to_pil_image(samples[0].float())
 
-        return img
+        return img, metadata
     except Exception:
         print(traceback.format_exc())
         return ModelFailure()
```
```diff
@@ -505,18 +495,15 @@ def main():
             )
             with gr.Row():
                 submit_btn = gr.Button("Submit", variant="primary")
-                # reset_btn = gr.ClearButton([
-                #     cap, resolution,
-                #     num_sampling_steps, cfg_scale, solver,
-                #     t_shift, seed,
-                #     ntk_scaling, proportional_attn
-                # ])
             with gr.Column():
                 output_img = gr.Image(
                     label="Lumina Generated image",
                     interactive=False,
                     format="png",
+                    show_label=False
                 )
+                with gr.Accordion(label="Generation Parameters", open=False):
+                    gr_metadata = gr.JSON(label="metadata", show_label=False)
 
         with gr.Row():
             gr.Examples(
```
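A standalone sketch of the new metadata panel pattern, assuming a current Gradio 4.x API: a collapsed `gr.Accordion` keeps the parameter dump out of the way until the user opens it, and `gr.JSON` pretty-prints whatever dict it receives.

```python
import gradio as gr

with gr.Blocks() as demo:
    img = gr.Image(interactive=False, format="png", show_label=False)
    with gr.Accordion(label="Generation Parameters", open=False):
        meta = gr.JSON(show_label=False)  # filled from the handler's second return value

demo.launch()
```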
```diff
@@ -582,8 +569,8 @@ def main():
             examples_per_page=22,
         )
 
-        @spaces.GPU(duration=200)
-        def on_submit(*infer_args):
+        @spaces.GPU(duration=80)
+        def on_submit(*infer_args, progress=gr.Progress(track_tqdm=True),):
             result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
                 raise RuntimeError("Model failed to generate the image.")
```
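On ZeroGPU Spaces, `@spaces.GPU(duration=...)` reserves a GPU only while the decorated call runs, with `duration` the expected runtime in seconds, so lowering it from 200 to 80 shortens the reservation. `gr.Progress(track_tqdm=True)` mirrors any tqdm bars inside the call (such as the sampler loop) into the Gradio progress display. A minimal sketch with a hypothetical handler:

```python
import gradio as gr
import spaces

@spaces.GPU(duration=80)  # request a GPU allocation of ~80 s per call
def generate(prompt, progress=gr.Progress(track_tqdm=True)):
    # any tqdm loop executed here is reflected in the UI progress bar
    return prompt.upper()
```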
```diff
@@ -602,10 +589,10 @@ def main():
                 ntk_scaling,
                 proportional_attn,
             ],
-            [output_img],
+            [output_img, gr_metadata],
         )
 
-    demo.queue(max_size=20).launch()
+    demo.queue().launch()
 
 
 if __name__ == "__main__":
```
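`Button.click` maps the handler's return values onto the output components in order, so `on_submit`'s `(image, metadata)` pair fills `output_img` and `gr_metadata`. Dropping `max_size=20` from `queue()` falls back to the default unbounded queue. A self-contained sketch with hypothetical names:

```python
import gradio as gr

def run(prompt):
    return prompt[::-1], {"prompt": prompt}  # two values -> two outputs, in order

with gr.Blocks() as demo:
    box = gr.Textbox()
    btn = gr.Button("Submit")
    out = gr.Textbox()
    meta = gr.JSON()
    btn.click(run, [box], [out, meta])

demo.queue().launch()  # default queue; no max_size cap
```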
 