PommesPeter committed
Commit 95eaf4d
1 parent: 94a82b6

Update app.py

Files changed (1): app.py (+7 -10)
app.py CHANGED
@@ -1,6 +1,9 @@
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-subprocess.run("mkdir -p ./checkpoints", shell=True)
+
+from huggingface_hub import snapshot_download
+os.makedirs("/home/user/app/checkpoints", exist_ok=True)
+snapshot_download(repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints")
 
 import argparse
 import builtins
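
Note: the added lines call os.makedirs() at new-file line 5, but no `import os` appears in the first nine lines of the new file, and these statements execute top to bottom at import time, so the Space would crash with a NameError at startup unless `import os` is added above them. Separately, passing `env=` to subprocess.run replaces the subprocess's entire environment rather than extending it. A minimal corrected sketch of the new startup block, assuming no earlier `import os`:

import os
import subprocess

# Install flash-attn without compiling its CUDA extension; merging into
# os.environ (rather than replacing it) keeps PATH and friends intact.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from huggingface_hub import snapshot_download

# Download the Lumina-Next-T2I weights into the Space's checkpoint directory.
CKPT_DIR = "/home/user/app/checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)
snapshot_download(repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir=CKPT_DIR)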
@@ -151,8 +154,6 @@ def load_model(args, master_port, rank):
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
         print("Loading ema model.")
-
-    subprocess.run("huggingface-cli download --resume-download Alpha-VLLM/Lumina-Next-T2I --local-dir ./checkpoints --local-dir-use-symlinks False", shell=True)
     ckpt = torch.load(
         os.path.join(
             args.ckpt,
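
The removed shell-out duplicated, on every load_model() call (once per GPU rank), the download that the new version performs once at import time. For reference, a sketch of the correspondence between the old CLI invocation and the Python API (parameter names as in huggingface_hub):

from huggingface_hub import snapshot_download

# Roughly equivalent to the removed:
#   huggingface-cli download --resume-download Alpha-VLLM/Lumina-Next-T2I \
#       --local-dir ./checkpoints --local-dir-use-symlinks False
# Resuming partial downloads is snapshot_download's default behaviour, and
# recent huggingface_hub versions place real files (not symlinks) in
# local_dir by default, so no symlink flag is needed.
snapshot_download(
    repo_id="Alpha-VLLM/Lumina-Next-T2I",
    local_dir="./checkpoints",
)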
@@ -166,13 +167,15 @@ def load_model(args, master_port, rank):
     return text_encoder, tokenizer, vae, model
 
 
+@spaces.GPU(duration=80)
 @torch.no_grad()
 def model_main(args, master_port, rank, request_queue, response_queue, text_encoder, tokenizer, vae, model):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-
+    text_encoder, tokenizer, vae, model = load_model(args, master_port, rank)
+
     with torch.autocast("cuda", dtype):
         # barrier.wait()
         while True:
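
These two additions implement the Spaces ZeroGPU pattern: @spaces.GPU(duration=80) requests a GPU slice for up to 80 seconds per call, and because CUDA is only usable inside the decorated function, model loading moves from main() into model_main(). A minimal runnable toy of the same pattern, with load_model_stub as a hypothetical stand-in for this app's load_model():

import spaces
import torch

def load_model_stub() -> torch.nn.Module:
    # Hypothetical stand-in for load_model(): any CUDA-touching setup.
    return torch.nn.Linear(4, 4).cuda()

@spaces.GPU(duration=80)  # a GPU is attached only while this call runs
@torch.no_grad()
def run_once(x: torch.Tensor) -> torch.Tensor:
    model = load_model_stub()  # load inside the decorated scope, as in this commit
    return model(x.cuda()).cpu()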
@@ -407,7 +410,6 @@ def find_free_port() -> int:
     return port
 
 
-@spaces.GPU
 def main():
     parser = argparse.ArgumentParser()
     mode = "ODE"
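
Dropping @spaces.GPU from main() complements the move above: main() wires up the worker queues and, in the surrounding code, runs the app for the lifetime of the process, so decorating it would try to hold a ZeroGPU slot far beyond any allowed duration instead of requesting one per generation call.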
@@ -439,7 +441,6 @@ def main():
     # mp_barrier = mp.Barrier(args.num_gpus + 1)
     # barrier = Barrier(args.num_gpus + 1)
     for i in range(args.num_gpus):
-        text_encoder, tokenizer, vae, model = load_model(args, master_port, i)
         request_queues.append(Queue())
         generation_kwargs = dict(
             args=args,
@@ -447,10 +448,6 @@ def main():
             rank=i,
             request_queue=request_queues[i],
             response_queue=response_queue if i == 0 else None,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            vae=vae,
-            model=model
         )
         model_main(**generation_kwargs)
         # thread = Thread(target=model_main, kwargs=generation_kwargs)
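
A loose end worth flagging: the unchanged def model_main(...) context line in the earlier hunk still declares text_encoder, tokenizer, vae and model as required parameters, while this hunk removes them from generation_kwargs, so model_main(**generation_kwargs) should raise a TypeError for missing arguments as written. A hypothetical fix is to give those parameters defaults (or drop them outright), since the function now recomputes them via load_model():

def model_main(args, master_port, rank, request_queue, response_queue,
               text_encoder=None, tokenizer=None, vae=None, model=None):
    # the values are reloaded inside via load_model(), so the parameters go unused
    ...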