Ruining Li committed
Commit ae76e1f
Parent: 0af5cc2

Minor update.

Files changed (1): app.py +6 -10
app.py CHANGED
@@ -22,7 +22,7 @@ import spaces
 TITLE = '''DragAPart: Learning a Part-Level Motion Prior for Articulated Objects'''
 DESCRIPTION = """
 <div>
-Try <a href='https://arxiv.org/abs/24xx.xxxxx'><b>DragAPart</b></a> yourself to manipulate your favorite articulated objects in 2 seconds!
+Try <a href='https://arxiv.org/abs/24xx.xxxxx'><b>DragAPart</b></a> yourself to manipulate your favorite articulated objects in seconds!
 </div>
 """
 INSTRUCTION = '''
@@ -185,11 +185,8 @@ def single_image_sample(
     drags,
     hidden_cls,
     num_steps=50,
-    vae=None,
 ):
     z = torch.randn(2, 4, 32, 32).to("cuda")
-    if vae is not None:
-        vae = vae.to("cuda")
 
     # Prepare input for classifier-free guidance
     rel = torch.cat([rel, rel], dim=0).to("cuda")
@@ -226,9 +223,7 @@ def single_image_sample(
 
     samples, _ = samples.chunk(2, dim=0)
 
-    with torch.no_grad():
-        images = vae.decode(samples / 0.18215).sample
-    return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+    return samples
 
 @spaces.GPU
 def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
@@ -278,7 +273,7 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
         if idx == 9:
             break
 
-    images = single_image_sample(
+    samples = single_image_sample(
         model.to("cuda"),
         diffusion,
         x_cond,
@@ -289,9 +284,10 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
         drags,
         cls_embedding,
         num_steps=50,
-        vae=vae,
     )
-    return images
+    with torch.no_grad():
+        images = vae.decode(samples / 0.18215).sample
+    return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
 
 
 sam_predictor = sam_init()
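
The net effect of the commit: single_image_sample now returns raw latents, and VAE decoding moves into generate_image, which already receives the vae handle. The division by 0.18215 undoes the Stable Diffusion VAE scaling factor applied at encode time, and (images + 1) * 127.5 maps the decoder output from [-1, 1] into [0, 255]. A minimal sketch of that decode step, assuming a diffusers-style AutoencoderKL whose decode() returns an object with a .sample tensor; the helper name latents_to_uint8 is ours, not from app.py:

import numpy as np
import torch

@torch.no_grad()
def latents_to_uint8(vae, samples, scaling_factor=0.18215):
    # Undo the latent scaling applied at encode time, then decode to pixels.
    images = vae.decode(samples / scaling_factor).sample  # NCHW, roughly in [-1, 1]
    # Map [-1, 1] -> [0, 255], take the first image, and put channels last.
    first = (images + 1)[0].permute(1, 2, 0) * 127.5
    return first.cpu().numpy().clip(0, 255).astype(np.uint8)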
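The context lines also show the classifier-free guidance batching the sampler keeps: the initial noise z has batch size 2, the conditioning tensors are duplicated along the batch dimension (torch.cat([rel, rel], dim=0)) so both guidance branches run in one forward pass, and samples.chunk(2, dim=0) keeps only the guided half afterwards. A one-step illustration of the pattern under those assumptions; denoise, cond, and null_cond are stand-ins, not names from app.py:

import torch

def guided_eps(denoise, z, cond, null_cond, cfg_scale):
    # Stack conditional and unconditional inputs into a single batch of 2N.
    z2 = torch.cat([z, z], dim=0)
    c2 = torch.cat([cond, null_cond], dim=0)
    eps_cond, eps_uncond = denoise(z2, c2).chunk(2, dim=0)
    # Classifier-free guidance: extrapolate away from the unconditional prediction.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)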