Spaces: Running on Zero

Update app.py
app.py CHANGED
@@ -117,13 +117,14 @@ def load_models(args, master_port, rank):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     print(f"Creating lm: Gemma-2B")
     text_encoder = (
         AutoModelForCausalLM.from_pretrained(
             "google/gemma-2b",
             torch_dtype=dtype,
-            device_map=
+            device_map=device,
             # device_map="cuda",
             token=hf_token,
         )
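For context, a minimal standalone sketch of the device-fallback loading pattern this hunk introduces. It is not the app's exact code: reading the token from an HF_TOKEN secret, the bf16/fp32 default, and the trailing eval() are assumptions for illustration; google/gemma-2b is a gated repo, so the token needs access, and device_map requires accelerate to be installed.

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the token is provided via the Space's HF_TOKEN secret.
hf_token = os.environ.get("HF_TOKEN")

# Prefer CUDA when present; a ZeroGPU Space boots without a GPU, so fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32  # assumed precision default

# device_map accepts a plain device string, so the same call works on both paths.
text_encoder = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=dtype,
    device_map=device,
    token=hf_token,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=hf_token)
text_encoder.eval()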
@@ -146,7 +147,7 @@ def load_models(args, master_port, rank):
     vae = AutoencoderKL.from_pretrained(
         "stabilityai/sdxl-vae",
         torch_dtype=torch.float32,
-    )
+    ).to(device)
 
     print(f"Creating DiT: Next-DiT")
     # latent_size = train_args.image_size // 8
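A similarly self-contained sketch of the VAE hunk above: the fp32 load is unchanged and only an explicit .to(device) is appended (the eval() call here is illustrative, not part of the diff).

import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# The VAE stays in float32; only its device now follows the CUDA/CPU fallback.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sdxl-vae",
    torch_dtype=torch.float32,
).to(device)
vae.eval()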
@@ -155,7 +156,7 @@ def load_models(args, master_port, rank):
         cap_feat_dim=cap_feat_dim,
     )
     # model.eval().to("cuda", dtype=dtype)
-    model.eval()
+    model.eval().to(device, dtype=dtype)
 
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
@@ -169,7 +170,6 @@ def load_models(args, master_port, rank):
     )
     model.load_state_dict(ckpt, strict=True)
 
-    # barrier.wait()
     return text_encoder, tokenizer, vae, model
 
 
@@ -181,12 +181,13 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
 
     print(args)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.cuda.set_device(0)
-
+
     # loading model to gpu
-    text_encoder = text_encoder.cuda()
-    vae = vae.cuda()
-    model = model.to("cuda", dtype=dtype)
+    # text_encoder = text_encoder.cuda()
+    # vae = vae.cuda()
+    # model = model.to("cuda", dtype=dtype)
 
     with torch.autocast("cuda", dtype):
         (
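The inference hunk above keeps with torch.autocast("cuda", dtype): while commenting out the explicit .cuda() moves. A runnable sketch of that mixed-precision pattern, with a toy nn.Linear standing in for the real DiT (all names here are illustrative):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16  # corresponds to args.precision == "bf16"

# Toy stand-in for the diffusion transformer returned by load_models().
net = nn.Linear(16, 16).to(device)
x = torch.randn(4, 16, device=device)

# Autocast runs matmul-heavy ops in the reduced dtype on CUDA; on the CPU
# fallback it is simply disabled so the toy runs in plain fp32.
with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype, enabled=(device == "cuda")):
    y = net(x)

print(y.dtype)  # torch.bfloat16 on GPU, torch.float32 on CPU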
@@ -581,7 +582,7 @@ def main():
             examples_per_page=22,
         )
 
-        @spaces.GPU(duration=
+        @spaces.GPU(duration=200)
         def on_submit(*infer_args):
             result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
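The final hunk sets duration=200 on the ZeroGPU decorator. As a hedged, standalone illustration of that mechanism: on a Space "Running on Zero" the GPU is attached only while a @spaces.GPU-decorated function executes, and duration (in seconds) raises the per-call time budget. The Gradio wiring below is invented for the example and is not the app's actual UI.

import gradio as gr
import spaces
import torch

@spaces.GPU(duration=200)  # request the GPU for up to ~200 s per call on a ZeroGPU Space
def on_submit(prompt: str) -> str:
    # Inside the decorated call CUDA is available, even though the Space
    # process itself boots without a GPU.
    return f"device seen at call time: {'cuda' if torch.cuda.is_available() else 'cpu'}"

demo = gr.Interface(fn=on_submit, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()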