matthayes mverrilli commited on
Commit
e19a525
1 Parent(s): 0cb09cc

Send attention_mask to device (#9)

Browse files

- Send attention_mask to device (76fb033a3c0c4b1d6764fcf69699ba60f6ad942e)


Co-authored-by: Michael Verrilli <mverrilli@users.noreply.huggingface.co>

Files changed (1) hide show
  1. instruct_pipeline.py +1 -1
instruct_pipeline.py CHANGED
@@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline):
131
 
132
  generated_sequence = self.model.generate(
133
  input_ids=input_ids.to(self.model.device),
134
- attention_mask=attention_mask,
135
  pad_token_id=self.tokenizer.pad_token_id,
136
  **generate_kwargs,
137
  )
 
131
 
132
  generated_sequence = self.model.generate(
133
  input_ids=input_ids.to(self.model.device),
134
+ attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
135
  pad_token_id=self.tokenizer.pad_token_id,
136
  **generate_kwargs,
137
  )