Jonathan Fly commited on
Commit
8764625
1 Parent(s): d56cc00

Use greedy sampling path when temp is 0.0 to avoid division by zero (#53)

Browse files
Files changed (1) hide show
  1. audiocraft/models/lm.py +2 -1
audiocraft/models/lm.py CHANGED
@@ -363,7 +363,8 @@ class LMModel(StreamingModule):
363
  logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
364
  logits = logits[..., -1] # [B x K x card]
365
 
366
- if use_sampling:
 
367
  probs = torch.softmax(logits / temp, dim=-1)
368
  if top_p > 0.0:
369
  next_token = utils.sample_top_p(probs, p=top_p)
 
363
  logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
364
  logits = logits[..., -1] # [B x K x card]
365
 
366
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
367
+ if use_sampling and temp > 0.0:
368
  probs = torch.softmax(logits / temp, dim=-1)
369
  if top_p > 0.0:
370
  next_token = utils.sample_top_p(probs, p=top_p)