Nekochu committed on
Commit
4b2f4ad
·
1 Parent(s): 6bfdc38

fix understand_audio: clone tensors for inference mode, working on GPU (52s)

Browse files
Files changed (1) hide show
  1. train_engine.py +2 -2
train_engine.py CHANGED
@@ -3414,6 +3414,7 @@ def understand_audio(
3414
  past_think = False
3415
 
3416
  for step in range(max_new_tokens):
 
3417
  # After </think>: block audio codes so the LM only generates text
3418
  if past_think:
3419
  logits[0, _AUDIO_CODE_BASE:] = float("-inf")
@@ -3422,9 +3423,8 @@ def understand_audio(
3422
  if temperature <= 0:
3423
  next_id = int(logits[0].argmax().item())
3424
  else:
3425
- scaled = logits[0] / temperature
3426
  if top_k > 0:
3427
- # Zero out everything below top_k
3428
  topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
3429
  scaled[scaled < topk_vals[-1]] = float("-inf")
3430
  if top_p > 0 and top_p < 1.0:
 
3414
  past_think = False
3415
 
3416
  for step in range(max_new_tokens):
3417
+ logits = logits.clone()
3418
  # After </think>: block audio codes so the LM only generates text
3419
  if past_think:
3420
  logits[0, _AUDIO_CODE_BASE:] = float("-inf")
 
3423
  if temperature <= 0:
3424
  next_id = int(logits[0].argmax().item())
3425
  else:
3426
+ scaled = logits[0].clone() / temperature
3427
  if top_k > 0:
 
3428
  topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
3429
  scaled[scaled < topk_vals[-1]] = float("-inf")
3430
  if top_p > 0 and top_p < 1.0: