wondervictor commited on
Commit
0ba2ca3
·
verified ·
1 Parent(s): b8ad3ef

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +3 -3
autoregressive/models/generate.py CHANGED
@@ -60,7 +60,9 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
60
  logits = logits[:, -1, :] / max(temperature, 1e-5)
61
  if top_k > 0 or top_p < 1.0:
62
  logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
63
- probs = F.softmax(logits.to(torch.float32), dim=-1)
 
 
64
  # values, indices = torch.max(probs, dim=1, keepdim=True)
65
  # mask = (probs == values).float()
66
  # probs = probs * (1 - mask)
@@ -71,8 +73,6 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
71
  # add to fix 'nan' and 'inf'
72
  probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
73
  probs = torch.clamp(probs, min=0, max=None)
74
- print(probs.sum())
75
- print(probs)
76
  print(f'inf:{torch.any(torch.isinf(probs))}')
77
  print(f'nan: {torch.any(torch.isnan(probs))}')
78
 
 
60
  logits = logits[:, -1, :] / max(temperature, 1e-5)
61
  if top_k > 0 or top_p < 1.0:
62
  logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
63
+ probs = F.softmax(logits, dim=-1)
64
+ print(probs.sum())
65
+ print(probs)
66
  # values, indices = torch.max(probs, dim=1, keepdim=True)
67
  # mask = (probs == values).float()
68
  # probs = probs * (1 - mask)
 
73
  # add to fix 'nan' and 'inf'
74
  probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
75
  probs = torch.clamp(probs, min=0, max=None)
 
 
76
  print(f'inf:{torch.any(torch.isinf(probs))}')
77
  print(f'nan: {torch.any(torch.isnan(probs))}')
78