Ruining Li commited on
Commit
91159ab
1 Parent(s): ae76e1f
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -185,8 +185,11 @@ def single_image_sample(
185
  drags,
186
  hidden_cls,
187
  num_steps=50,
 
188
  ):
189
  z = torch.randn(2, 4, 32, 32).to("cuda")
 
 
190
 
191
  # Prepare input for classifer-free guidance
192
  rel = torch.cat([rel, rel], dim=0).to("cuda")
@@ -223,6 +226,13 @@ def single_image_sample(
223
 
224
  samples, _ = samples.chunk(2, dim=0)
225
 
 
 
 
 
 
 
 
226
  return samples
227
 
228
  @spaces.GPU
@@ -273,7 +283,7 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
273
  if idx == 9:
274
  break
275
 
276
- samples = single_image_sample(
277
  model.to("cuda"),
278
  diffusion,
279
  x_cond,
@@ -284,10 +294,8 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
284
  drags,
285
  cls_embedding,
286
  num_steps=50,
 
287
  )
288
- with torch.no_grad():
289
- images = vae.decode(samples / 0.18215).sample
290
- return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
291
 
292
 
293
  sam_predictor = sam_init()
 
185
  drags,
186
  hidden_cls,
187
  num_steps=50,
188
+ vae=None,
189
  ):
190
  z = torch.randn(2, 4, 32, 32).to("cuda")
191
+ if vae is not None:
192
+ vae = vae.to("cuda")
193
 
194
  # Prepare input for classifer-free guidance
195
  rel = torch.cat([rel, rel], dim=0).to("cuda")
 
226
 
227
  samples, _ = samples.chunk(2, dim=0)
228
 
229
+ if vae is not None:
230
+ with torch.no_grad():
231
+ images = vae.decode(samples / 0.18215).sample
232
+ else:
233
+ images = samples
234
+ return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
235
+
236
  return samples
237
 
238
  @spaces.GPU
 
283
  if idx == 9:
284
  break
285
 
286
+ return single_image_sample(
287
  model.to("cuda"),
288
  diffusion,
289
  x_cond,
 
294
  drags,
295
  cls_embedding,
296
  num_steps=50,
297
+ vae=vae,
298
  )
 
 
 
299
 
300
 
301
  sam_predictor = sam_init()