Emanuela Boros
commited on
Commit
•
a945a9c
1
Parent(s):
ddc72bb
updated device
Browse files- generic_nel.py +1 -2
generic_nel.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
from transformers import Pipeline
|
2 |
import nltk
|
3 |
-
import torch
|
4 |
|
5 |
nltk.download("averaged_perceptron_tagger")
|
6 |
nltk.download("averaged_perceptron_tagger_eng")
|
@@ -90,7 +89,7 @@ class NelPipeline(Pipeline):
|
|
90 |
def preprocess(self, text, **kwargs):
|
91 |
|
92 |
outputs = self.model.generate(
|
93 |
-
**self.tokenizer([text], return_tensors="pt"),
|
94 |
num_beams=5,
|
95 |
num_return_sequences=5,
|
96 |
max_new_tokens=30,
|
|
|
1 |
from transformers import Pipeline
|
2 |
import nltk
|
|
|
3 |
|
4 |
nltk.download("averaged_perceptron_tagger")
|
5 |
nltk.download("averaged_perceptron_tagger_eng")
|
|
|
89 |
def preprocess(self, text, **kwargs):
|
90 |
|
91 |
outputs = self.model.generate(
|
92 |
+
**self.tokenizer([text], return_tensors="pt").to(self.device),
|
93 |
num_beams=5,
|
94 |
num_return_sequences=5,
|
95 |
max_new_tokens=30,
|