Spaces:
Running
Running
fix understand_audio: clone tensors for inference mode, working on GPU (52s)
Browse files
- train_engine.py +2 -2
train_engine.py
CHANGED
|
@@ -3414,6 +3414,7 @@ def understand_audio(
|
|
| 3414 |
past_think = False
|
| 3415 |
|
| 3416 |
for step in range(max_new_tokens):
|
|
|
|
| 3417 |
# After </think>: block audio codes so the LM only generates text
|
| 3418 |
if past_think:
|
| 3419 |
logits[0, _AUDIO_CODE_BASE:] = float("-inf")
|
|
@@ -3422,9 +3423,8 @@ def understand_audio(
|
|
| 3422 |
if temperature <= 0:
|
| 3423 |
next_id = int(logits[0].argmax().item())
|
| 3424 |
else:
|
| 3425 |
-
scaled = logits[0] / temperature
|
| 3426 |
if top_k > 0:
|
| 3427 |
-
# Zero out everything below top_k
|
| 3428 |
topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
|
| 3429 |
scaled[scaled < topk_vals[-1]] = float("-inf")
|
| 3430 |
if top_p > 0 and top_p < 1.0:
|
|
|
|
| 3414 |
past_think = False
|
| 3415 |
|
| 3416 |
for step in range(max_new_tokens):
|
| 3417 |
+
logits = logits.clone()
|
| 3418 |
# After </think>: block audio codes so the LM only generates text
|
| 3419 |
if past_think:
|
| 3420 |
logits[0, _AUDIO_CODE_BASE:] = float("-inf")
|
|
|
|
| 3423 |
if temperature <= 0:
|
| 3424 |
next_id = int(logits[0].argmax().item())
|
| 3425 |
else:
|
| 3426 |
+
scaled = logits[0].clone() / temperature
|
| 3427 |
if top_k > 0:
|
|
|
|
| 3428 |
topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
|
| 3429 |
scaled[scaled < topk_vals[-1]] = float("-inf")
|
| 3430 |
if top_p > 0 and top_p < 1.0:
|