dhigurashi hibikaze commited on
Commit
d155d07
1 Parent(s): e28ebda

Fixed beam search error when using multiple GPUs (#6)

Browse files

- Fixed beam search error when using multiple GPUs (14a911cdcbe50c9fbaae49af6f26b3d981b30cde)


Co-authored-by: Hiroki Yamaguchi <hibikaze@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_plamo.py +1 -1
modeling_plamo.py CHANGED
@@ -701,5 +701,5 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
701
  def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
702
  reordered_past: Tuple[Any, ...] = ()
703
  for layer_past in past_key_values:
704
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
705
  return reordered_past
 
701
  def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
702
  reordered_past: Tuple[Any, ...] = ()
703
  for layer_past in past_key_values:
704
+ reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
705
  return reordered_past