Send attention_mask to device

#9
by mverrilli - opened
Files changed (1) hide show
  1. instruct_pipeline.py +1 -1
instruct_pipeline.py CHANGED
@@ -131,7 +131,7 @@ class InstructionTextGenerationPipeline(Pipeline):
131
 
132
  generated_sequence = self.model.generate(
133
  input_ids=input_ids.to(self.model.device),
134
- attention_mask=attention_mask,
135
  pad_token_id=self.tokenizer.pad_token_id,
136
  **generate_kwargs,
137
  )
 
131
 
132
  generated_sequence = self.model.generate(
133
  input_ids=input_ids.to(self.model.device),
134
+ attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
135
  pad_token_id=self.tokenizer.pad_token_id,
136
  **generate_kwargs,
137
  )