Jeney committed on
Commit
d75ed94
1 Parent(s): 1c7db93

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -10,8 +10,8 @@ from transformers import DonutProcessor, VisionEncoderDecoderModel
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  # load model and processor from path
13
- self.processor = DonutProcessor.from_pretrained("debu-das/donut_receipt_v2.29")
14
- self.model = VisionEncoderDecoderModel.from_pretrained("debu-das/donut_receipt_v2.29")
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -22,7 +22,7 @@ class EndpointHandler:
22
  return self.process_document(image)
23
 
24
 
25
- def process_document(self, image):
26
  # prepare encoder inputs
27
  pixel_values = self.processor(image, return_tensors="pt").pixel_values
28
 
@@ -49,4 +49,4 @@ class EndpointHandler:
49
  sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
50
  sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
51
 
52
- return self.processor.token2json(sequence)
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  # load model and processor from path
13
+ self.processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
14
+ self.model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
22
  return self.process_document(image)
23
 
24
 
25
+ def process_document(self, image:Image) -> dict[str, Any]:
26
  # prepare encoder inputs
27
  pixel_values = self.processor(image, return_tensors="pt").pixel_values
28
 
 
49
  sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
50
  sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
51
 
52
+ return self.processor.token2json(sequence)