PommesPeter committed on
Commit f50300f
1 Parent(s): 148edb3

Update app.py

Files changed (1): app.py (+10 -9)
app.py CHANGED
@@ -117,13 +117,14 @@ def load_models(args, master_port, rank):
     dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
         args.precision
     ]
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     print(f"Creating lm: Gemma-2B")
     text_encoder = (
         AutoModelForCausalLM.from_pretrained(
             "google/gemma-2b",
             torch_dtype=dtype,
-            device_map="cpu",
+            device_map=device,
             # device_map="cuda",
             token=hf_token,
         )
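The `device_map` argument to `transformers`' `from_pretrained` accepts a plain device string, so computing the device once makes the loader work on both CPU-only and CUDA hosts. A minimal sketch of the pattern, assuming an `HF_TOKEN` environment variable for the gated gemma-2b repo (not taken verbatim from app.py):

    import os
    import torch
    from transformers import AutoModelForCausalLM

    # Pick the device once; "cuda" only when a GPU is actually visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    # device_map accepts a device string, so the weights are placed on the
    # target device at load time instead of needing a .to(device) copy after.
    text_encoder = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        torch_dtype=dtype,
        device_map=device,
        token=os.environ.get("HF_TOKEN"),  # gemma-2b is a gated repo
    )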
@@ -146,7 +147,7 @@ def load_models(args, master_port, rank):
     vae = AutoencoderKL.from_pretrained(
         "stabilityai/sdxl-vae",
         torch_dtype=torch.float32,
-    )
+    ).to(device)
 
     print(f"Creating DiT: Next-DiT")
     # latent_size = train_args.image_size // 8
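`AutoencoderKL` here is the diffusers VAE, kept in float32 because the stock SDXL VAE is known to be numerically unstable in half precision. A minimal sketch of how such a VAE decodes latents, with an illustrative latent shape (the `scaling_factor` lookup comes from the model config, not from app.py):

    import torch
    from diffusers import AutoencoderKL

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Kept in float32: the stock SDXL VAE tends to overflow in fp16.
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sdxl-vae",
        torch_dtype=torch.float32,
    ).to(device)

    latents = torch.randn(1, 4, 128, 128, device=device)  # illustrative shape
    with torch.no_grad():
        # Latents are divided by the config's scaling factor before decoding.
        image = vae.decode(latents / vae.config.scaling_factor).sample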
@@ -155,7 +156,7 @@ def load_models(args, master_port, rank):
         cap_feat_dim=cap_feat_dim,
     )
     # model.eval().to("cuda", dtype=dtype)
-    model.eval()
+    model.eval().to(device, dtype=dtype)
 
     assert train_args.model_parallel_size == args.num_gpus
     if args.ema:
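The replacement line works because `nn.Module.eval()` and `.to()` both return the module itself, so inference mode, device placement, and dtype casting chain in one expression. A toy illustration with a stand-in module (not the actual Next-DiT):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    net = nn.Linear(8, 8)  # stand-in module, not Next-DiT

    # eval() disables dropout/batch-norm updates and returns `net`, so .to()
    # can be chained to move the weights and cast their dtype in one line.
    net = net.eval().to(device, dtype=dtype)
    assert not net.training and net.weight.dtype == dtype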
@@ -169,7 +170,6 @@ def load_models(args, master_port, rank):
     )
     model.load_state_dict(ckpt, strict=True)
 
-    # barrier.wait()
     return text_encoder, tokenizer, vae, model
 
 
@@ -181,12 +181,13 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
 
     print(args)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.cuda.set_device(0)
 
     # loading model to gpu
-    text_encoder = text_encoder.cuda()
-    vae = vae.cuda()
-    model = model.to("cuda", dtype=dtype)
+    # text_encoder = text_encoder.cuda()
+    # vae = vae.cuda()
+    # model = model.to("cuda", dtype=dtype)
 
     with torch.autocast("cuda", dtype):
         (
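The commented-out `.cuda()` calls leave device placement to the loader; the surrounding `torch.autocast("cuda", dtype)` block governs compute precision, not placement. Inside it, autocast-eligible ops run in the given low-precision dtype while the rest stay in float32. A minimal sketch of the mechanism (module and shapes are illustrative):

    import torch
    import torch.nn as nn

    if torch.cuda.is_available():
        net = nn.Linear(16, 16).cuda()         # weights stored in float32
        x = torch.randn(2, 16, device="cuda")

        # Matmul-like ops execute in bfloat16 inside the context even though
        # the parameters themselves remain float32.
        with torch.autocast("cuda", torch.bfloat16):
            y = net(x)
        print(y.dtype)  # torch.bfloat16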
@@ -581,7 +582,7 @@ def main():
         examples_per_page=22,
     )
 
-    @spaces.GPU(duration=240)
+    @spaces.GPU(duration=200)
     def on_submit(*infer_args):
         result = infer_ode(args, infer_args, text_encoder, tokenizer, vae, model)
         if isinstance(result, ModelFailure):
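`spaces.GPU` is the Hugging Face ZeroGPU decorator: the Space holds a GPU only while a decorated function runs, and `duration` is the requested maximum per call in seconds, so dropping it from 240 to 200 shortens the reservation window. A hedged usage sketch (the handler body is illustrative, not app.py's):

    import spaces  # provided in Hugging Face ZeroGPU Spaces
    import torch

    @spaces.GPU(duration=200)  # request the GPU for at most ~200 s per call
    def on_submit(prompt: str) -> float:
        # CUDA is only guaranteed inside a @spaces.GPU function on ZeroGPU.
        return torch.zeros(1, device="cuda").item()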
 