ing0 commited on
Commit
cbb2aab
·
1 Parent(s): 5cb345f
Files changed (1) hide show
  1. diffrhythm/infer/infer_utils.py +1 -0
diffrhythm/infer/infer_utils.py CHANGED
@@ -250,6 +250,7 @@ def get_reference_latent(device, max_frames, edit, pred_segments, ref_song, vae_
250
  mean, scale = latent.chunk(2, dim=1)
251
  prompt, _ = vae_sample(mean, scale)
252
  prompt = prompt.transpose(1, 2) # [b t d]
 
253
 
254
  pred_segments = json.loads(pred_segments)
255
  # import pdb; pdb.set_trace()
 
250
  mean, scale = latent.chunk(2, dim=1)
251
  prompt, _ = vae_sample(mean, scale)
252
  prompt = prompt.transpose(1, 2) # [b t d]
253
+ prompt = prompt[:,:max_frames,:] if prompt.shape[1] >= max_frames else torch.nn.functional.pad(prompt, (0, 0, 0, max_frames - prompt.shape[1]), mode="constant", value=0)
254
 
255
  pred_segments = json.loads(pred_segments)
256
  # import pdb; pdb.set_trace()