PommesPeter committed on
Commit
23dffff
1 Parent(s): c11232c

Update app.py

Files changed (1)
  1. app.py +11 -13
app.py CHANGED
@@ -6,13 +6,11 @@ subprocess.run("huggingface-cli download --resume-download Alpha-VLLM/Lumina-Nex
 import argparse
 import builtins
 import json
-import threading
 import random
 import socket
 import spaces
 import traceback
 import os
-from queue import Queue

 import fairscale.nn.model_parallel.initialize as fs_init
 import gradio as gr
@@ -23,7 +21,8 @@ import torch.distributed as dist
 from torchvision.transforms.functional import to_pil_image

 from PIL import Image
-from threading import Thread
+from queue import Queue
+from threading import Thread, Barrier

 import models

@@ -34,11 +33,8 @@ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

 description = """
 # Lumina Next Text-to-Image
-
 Lumina-Next-T2I is a 2B Next-DiT model with 2B text encoder.
-
 Demo current model: `Lumina-Next-T2I`
-
 """

 hf_token = os.environ['HF_TOKEN']
@@ -170,13 +166,15 @@ def load_model(args, master_port, rank):


 @torch.no_grad()
-def model_main(args, master_port, rank, request_queue, response_queue, text_encoder, tokenizer, vae, model):
+def model_main(args, master_port, rank, request_queue, response_queue, barrier, text_encoder, tokenizer, vae, model):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))

     with torch.autocast("cuda", dtype):
+        barrier.wait()
+
         while True:
             (
                 cap,
@@ -439,7 +437,7 @@ def main():
     processes = []
     request_queues = []
     response_queue = Queue()
-    # mp_barrier = mp.Barrier(args.num_gpus + 1)
+    barrier = Barrier(args.num_gpus + 1)
     for i in range(args.num_gpus):
         text_encoder, tokenizer, vae, model = load_model(args, master_port, i)
         request_queues.append(Queue())
@@ -449,15 +447,15 @@ def main():
             rank=i,
             request_queue=request_queues[i],
             response_queue=response_queue if i == 0 else None,
+            barrier=barrier,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             vae=vae,
             model=model
         )
-        model_main(**generation_kwargs)
-        # thread = Thread(target=model_main, kwargs=generation_kwargs)
-        # thread.start()
-        # processes.append(thread)
+        thread = Thread(target=model_main, kwargs=generation_kwargs)
+        thread.start()
+        processes.append(thread)

     with gr.Blocks() as demo:
         with gr.Row():
@@ -609,7 +607,7 @@ def main():
             [output_img],
         )

-    # mp_barrier.wait()
+    barrier.wait()
     demo.queue(max_size=20).launch()
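For context, the change replaces the direct `model_main(**generation_kwargs)` call with one worker `Thread` per GPU, each handed its own request `Queue`, the shared response `Queue` (rank 0 only), and a `Barrier` sized `args.num_gpus + 1` so the main thread only launches the Gradio demo once every worker has finished loading. The sketch below shows that coordination pattern in isolation; the worker body, prompt payloads, and names such as `worker` and `NUM_WORKERS` are illustrative placeholders, not code from `app.py`.

```python
from queue import Queue
from threading import Thread, Barrier
from typing import Optional

NUM_WORKERS = 2  # stands in for args.num_gpus


def worker(rank: int, request_queue: Queue, response_queue: Optional[Queue], barrier: Barrier) -> None:
    # Expensive setup (model loading in app.py) would happen here, before the wait.
    barrier.wait()  # tell the main thread this worker is ready
    while True:
        prompt = request_queue.get()
        if prompt is None:  # sentinel value used to shut the worker down
            break
        result = f"[rank {rank}] generated image for: {prompt!r}"
        if response_queue is not None:  # only rank 0 reports back, as in app.py
            response_queue.put(result)


def main() -> None:
    request_queues: list[Queue] = []
    response_queue: Queue = Queue()
    barrier = Barrier(NUM_WORKERS + 1)  # +1 accounts for the main thread
    threads = []
    for i in range(NUM_WORKERS):
        request_queues.append(Queue())
        t = Thread(
            target=worker,
            kwargs=dict(
                rank=i,
                request_queue=request_queues[i],
                response_queue=response_queue if i == 0 else None,
                barrier=barrier,
            ),
        )
        t.start()
        threads.append(t)

    barrier.wait()  # block until every worker has finished its setup
    # app.py launches the Gradio demo at this point; the sketch just sends one
    # request to every worker and reads rank 0's response.
    for q in request_queues:
        q.put("a photo of a corgi")
    print(response_queue.get())
    for q in request_queues:
        q.put(None)
    for t in threads:
        t.join()


if __name__ == "__main__":
    main()
```

The `+ 1` in `Barrier(args.num_gpus + 1)` counts the main thread itself: each worker calls `barrier.wait()` after its setup, and the main thread's own `barrier.wait()` before `demo.queue(max_size=20).launch()` releases everyone together, so no request can reach a worker before its model is ready.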