Emanuela Boros commited on
Commit
a945a9c
1 Parent(s): ddc72bb

updated device

Browse files
Files changed (1) hide show
  1. generic_nel.py +1 -2
generic_nel.py CHANGED
@@ -1,6 +1,5 @@
1
  from transformers import Pipeline
2
  import nltk
3
- import torch
4
 
5
  nltk.download("averaged_perceptron_tagger")
6
  nltk.download("averaged_perceptron_tagger_eng")
@@ -90,7 +89,7 @@ class NelPipeline(Pipeline):
90
  def preprocess(self, text, **kwargs):
91
 
92
  outputs = self.model.generate(
93
- **self.tokenizer([text], return_tensors="pt"),
94
  num_beams=5,
95
  num_return_sequences=5,
96
  max_new_tokens=30,
 
1
  from transformers import Pipeline
2
  import nltk
 
3
 
4
  nltk.download("averaged_perceptron_tagger")
5
  nltk.download("averaged_perceptron_tagger_eng")
 
89
  def preprocess(self, text, **kwargs):
90
 
91
  outputs = self.model.generate(
92
+ **self.tokenizer([text], return_tensors="pt").to(self.device),
93
  num_beams=5,
94
  num_return_sequences=5,
95
  max_new_tokens=30,