Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,6 +1,9 @@
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-
+
+from huggingface_hub import snapshot_download
+os.makedirs("/home/user/app/checkpoints", exist_ok=True)
+snapshot_download(repo_id="Alpha-VLLM/Lumina-Next-T2I", local_dir="/home/user/app/checkpoints")
 
 import argparse
 import builtins
@@ -151,8 +154,6 @@ def load_model(args, master_port, rank):
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
         print("Loading ema model.")
-
-    subprocess.run("huggingface-cli download --resume-download Alpha-VLLM/Lumina-Next-T2I --local-dir ./checkpoints --local-dir-use-symlinks False", shell=True)
     ckpt = torch.load(
         os.path.join(
             args.ckpt,
@@ -166,13 +167,15 @@ def load_model(args, master_port, rank):
     return text_encoder, tokenizer, vae, model
 
 
+@spaces.GPU(duration=80)
 @torch.no_grad()
 def model_main(args, master_port, rank, request_queue, response_queue, text_encoder, tokenizer, vae, model):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
-
+    text_encoder, tokenizer, vae, model = load_model(args, master_port, rank)
+
     with torch.autocast("cuda", dtype):
         # barrier.wait()
         while True:
@@ -407,7 +410,6 @@ def find_free_port() -> int:
     return port
 
 
-@spaces.GPU
 def main():
     parser = argparse.ArgumentParser()
     mode = "ODE"
@@ -439,7 +441,6 @@ def main():
     # mp_barrier = mp.Barrier(args.num_gpus + 1)
     # barrier = Barrier(args.num_gpus + 1)
     for i in range(args.num_gpus):
-        text_encoder, tokenizer, vae, model = load_model(args, master_port, i)
         request_queues.append(Queue())
         generation_kwargs = dict(
             args=args,
@@ -447,10 +448,6 @@ def main():
             rank=i,
             request_queue=request_queues[i],
             response_queue=response_queue if i == 0 else None,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            vae=vae,
-            model=model
         )
         model_main(**generation_kwargs)
         # thread = Thread(target=model_main, kwargs=generation_kwargs)