popaqy committed
Commit 0b048f0
1 Parent(s): 6a28def

Update handler.py

Files changed (1)
  1. handler.py +5 -5
handler.py CHANGED
@@ -1,3 +1,4 @@
+from typing import Dict, List, Any
 from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor, PegasusForConditionalGeneration, AlbertTokenizer
 import torch
 
@@ -6,11 +7,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class EndpointHandler():
     def __init__(self, path=""):
         # load model and processor from path
-        model_id = "popaqy/pegasus-base-qag-bg-finetuned-spelling6-bg"
-        self.model = PegasusForConditionalGeneration.from_pretrained(model_id).to(device)
-        self.tokenizer = AlbertTokenizer.from_pretrained(model_id)
+        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(device)
+        self.tokenizer = AlbertTokenizer.from_pretrained(path)
 
-    def __call__(self, data):
+    def __call__(self, data: Dict[str, str]) -> Dict[str, List[str]]:
 
         # process input
         wrongSentence = data.pop("inputs", data)
@@ -22,4 +22,4 @@ class EndpointHandler():
         # run prediction
         output = self.model.generate(input['input_ids'])
 
-        return self.tokenizer.batch_decode(output, skip_special_tokens=True)
+        return {"predictions": self.tokenizer.batch_decode(output, skip_special_tokens=True)}