adefossez commited on
Commit
a8ec0d0
1 Parent(s): 664552c
Files changed (1) hide show
  1. audiocraft/utils/utils.py +1 -1
audiocraft/utils/utils.py CHANGED
@@ -122,7 +122,7 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
122
  probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
123
  probs_sum = torch.cumsum(probs_sort, dim=-1)
124
  mask = probs_sum - probs_sort > p
125
- probs_sort *= (~mask).float(0)
126
  probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
127
  next_token = multinomial(probs_sort, num_samples=1)
128
  next_token = torch.gather(probs_idx, -1, next_token)
 
122
  probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
123
  probs_sum = torch.cumsum(probs_sort, dim=-1)
124
  mask = probs_sum - probs_sort > p
125
+ probs_sort *= (~mask).float()
126
  probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
127
  next_token = multinomial(probs_sort, num_samples=1)
128
  next_token = torch.gather(probs_idx, -1, next_token)