Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
3
+ import torch
4
+
5
+
6
+ # check for GPU
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+
10
+ class EndpointHandler:
11
+ def __init__(self, path=""):
12
+ # load the model
13
+ self.processor = DonutProcessor.from_pretrained(path)
14
+ self.model = VisionEncoderDecoderModel.from_pretrained(path)
15
+ # move model to device
16
+ self.model.to(device)
17
+ self.decoder_input_ids = self.processor.tokenizer(
18
+ "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
19
+ ).input_ids
20
+
21
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
22
+
23
+ inputs = data.pop("inputs", data)
24
+
25
+
26
+ # preprocess the input
27
+ pixel_values = self.processor(inputs, return_tensors="pt").pixel_values
28
+
29
+ # forward pass
30
+ outputs = self.model.generate(
31
+ pixel_values.to(device),
32
+ decoder_input_ids=self.decoder_input_ids.to(device),
33
+ max_length=self.model.decoder.config.max_position_embeddings,
34
+ early_stopping=True,
35
+ pad_token_id=self.processor.tokenizer.pad_token_id,
36
+ eos_token_id=self.processor.tokenizer.eos_token_id,
37
+ use_cache=True,
38
+ num_beams=1,
39
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
40
+ return_dict_in_generate=True,
41
+ )
42
+ # process output
43
+ prediction = self.processor.batch_decode(outputs.sequences)[0]
44
+ prediction = self.processor.token2json(prediction)
45
+
46
+ return prediction