ischemist commited on
Commit
595f994
·
verified ·
1 Parent(s): 98da73e

downcast the scores after upcasting to prevent runtime errors

Browse files

this change fixes https://github.com/huggingface/transformers/issues/41238

Files changed (1) hide show
  1. custom_generate/generate.py +2 -1
custom_generate/generate.py CHANGED
@@ -379,7 +379,8 @@ def _group_beam_search(
379
  next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
380
 
381
  if output_scores:
382
- processed_score[batch_group_indices] = next_token_scores_processed
 
383
 
384
  # reshape for beam search
385
  next_token_scores = next_token_scores.view(
 
379
  next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
380
 
381
  if output_scores:
382
+ processed_score[batch_group_indices] = next_token_scores_processed.to(processed_score.dtype)
383
+
384
 
385
  # reshape for beam search
386
  next_token_scores = next_token_scores.view(