downcast the scores after upcasting to prevent runtime errors
Browse filesthis change fixes https://github.com/huggingface/transformers/issues/41238
custom_generate/generate.py
CHANGED
|
@@ -379,7 +379,8 @@ def _group_beam_search(
|
|
| 379 |
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
| 380 |
|
| 381 |
if output_scores:
|
| 382 |
-
processed_score[batch_group_indices] = next_token_scores_processed
|
|
|
|
| 383 |
|
| 384 |
# reshape for beam search
|
| 385 |
next_token_scores = next_token_scores.view(
|
|
|
|
| 379 |
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
| 380 |
|
| 381 |
if output_scores:
|
| 382 |
+
processed_score[batch_group_indices] = next_token_scores_processed.to(processed_score.dtype)
|
| 383 |
+
|
| 384 |
|
| 385 |
# reshape for beam search
|
| 386 |
next_token_scores = next_token_scores.view(
|