Update handler.py
Browse files- handler.py +5 -5
handler.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor, PegasusForConditionalGeneration, AlbertTokenizer
|
2 |
import torch
|
3 |
|
@@ -6,11 +7,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
6 |
class EndpointHandler():
|
7 |
def __init__(self, path=""):
|
8 |
# load model and processor from path
|
9 |
-
|
10 |
-
self.
|
11 |
-
self.tokenizer = AlbertTokenizer.from_pretrained(model_id)
|
12 |
|
13 |
-
def __call__(self, data):
|
14 |
|
15 |
# process input
|
16 |
wrongSentence = data.pop("inputs", data)
|
@@ -22,4 +22,4 @@ class EndpointHandler():
|
|
22 |
# run prediction
|
23 |
output = self.model.generate(input['input_ids'])
|
24 |
|
25 |
-
return self.tokenizer.batch_decode(output, skip_special_tokens=True)
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor, PegasusForConditionalGeneration, AlbertTokenizer
|
3 |
import torch
|
4 |
|
|
|
7 |
class EndpointHandler():
|
8 |
def __init__(self, path=""):
|
9 |
# load model and processor from path
|
10 |
+
self.model = PegasusForConditionalGeneration.from_pretrained(path).to(device)
|
11 |
+
self.tokenizer = AlbertTokenizer.from_pretrained(path)
|
|
|
12 |
|
13 |
+
def __call__(self, data: Dict[str, str]) -> Dict[str, List[str]]:
|
14 |
|
15 |
# process input
|
16 |
wrongSentence = data.pop("inputs", data)
|
|
|
22 |
# run prediction
|
23 |
output = self.model.generate(input['input_ids'])
|
24 |
|
25 |
+
return {"predictions": self.tokenizer.batch_decode(output, skip_special_tokens=True)}
|