Alexander Slessor committed on
Commit
317e1b6
1 Parent(s): d37221f

currently running

Files changed (4)
  1. .gitignore +10 -0
  2. README.md +7 -0
  3. handler.py +146 -0
  4. invoice_example.png +0 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__
+ *.ipynb
+ *.pdf
+
+ test_endpoint.py
+ test_handler_local.py
+
+ setup
+ upload_to_hf
+ requirements.txt
README.md CHANGED
@@ -15,3 +15,10 @@ LayoutLMv2 is an improved version of LayoutLM with new pre-training tasks to mod
 
 [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740)
 Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou, ACL 2021
+
+
+ Examples & Guides
+
+ - https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLMv2/DocVQA/Fine_tuning_LayoutLMv2ForQuestionAnswering_on_DocVQA.ipynb
+
+ - https://mccormickml.com/2020/03/10/question-answering-with-a-fine-tuned-BERT/
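Alongside the guides above, here is a minimal client-side sketch, in the spirit of the gitignored test_endpoint.py, of how one might query a deployed question-answering endpoint backed by this handler. The endpoint URL, the access token, the content type, and the assumption that the inference toolkit forwards the raw request body to the handler as data["inputs"] bytes are placeholders and unverified assumptions, not part of this commit.

# sketch of an endpoint request -- hypothetical client-side test, not part of this commit
# Assumes the Inference Endpoint passes the raw request body through to the
# handler as data["inputs"] (bytes), which is what handler.py's __call__ expects.
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

def query_pdf(path: str):
    # send the PDF as raw bytes; the content type is a guess at what the
    # endpoint's serializer accepts and may need to be adjusted
    with open(path, "rb") as f:
        payload = f.read()
    response = requests.post(
        ENDPOINT_URL,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/pdf",
        },
        data=payload,
    )
    response.raise_for_status()
    return response.json()

if __name__ == "__main__":
    print(query_pdf("invoice_example.pdf"))  # hypothetical local file; *.pdf is gitignored

If the toolkit does not accept application/pdf payloads, the same request could instead be sent as a JSON body with base64-encoded bytes and decoded inside the handler.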
handler.py ADDED
@@ -0,0 +1,146 @@
+ import torch
+ from typing import Any
+ # from transformers import LayoutLMForTokenClassification
+
+ from transformers import LayoutLMv2ForQuestionAnswering
+ from transformers import LayoutLMv2Processor
+ from transformers import LayoutLMv2FeatureExtractor
+ from transformers import LayoutLMv2ImageProcessor
+ from transformers import LayoutLMv2TokenizerFast
+
+ from PIL import Image, ImageDraw, ImageFont
+ from subprocess import run
+ import pdf2image
+
+ from pprint import pprint
+
+ # set device
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # install tesseract-ocr and pytesseract
+ # run("apt install -y tesseract-ocr", shell=True, check=True)
+
+ feature_extractor = LayoutLMv2FeatureExtractor()
+
+ class NoOCRReaderFound(Exception):
+     def __init__(self, e):
+         self.e = e
+
+     def __str__(self):
+         return f"Could not load OCR Reader: {self.e}"
+
+ # helper function to unnormalize bboxes for drawing onto the image
+ def unnormalize_box(bbox, width, height):
+     return [
+         width * (bbox[0] / 1000),
+         height * (bbox[1] / 1000),
+         width * (bbox[2] / 1000),
+         height * (bbox[3] / 1000),
+     ]
+
+ def pdf_to_image(b: bytes):
+     # First, try to extract text directly
+     # TODO: This library requires poppler, which is not present everywhere.
+     # We should look into alternatives. We could also gracefully handle this
+     # and simply fall back to _only_ extracted text
+     images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(b)]
+     encoded_inputs = feature_extractor(images)
+     print('feature_extractor: ', encoded_inputs.keys())
+     data = {}
+     data['image'] = encoded_inputs.pixel_values
+     data['words'] = encoded_inputs.words
+     data['boxes'] = encoded_inputs.boxes
+     return data
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
+         # self.processor = LayoutLMv2Processor.from_pretrained(path)
+         self.image_processor = LayoutLMv2ImageProcessor()  # apply_ocr is set to True by default
+         self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
+         # self.processor = LayoutLMv2Processor(self.image_processor, self.tokenizer)
+         self.processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+         # processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+
+         self.model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
+
+     def __call__(self, data: dict[str, bytes]):
+         """
+         Args:
+             data (dict):
+                 the request payload; the "inputs" key holds the uploaded PDF as raw bytes
+         """
+         image = data.pop("inputs", data)
+
+         # image = pdf_to_image(image)
+         images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(image)]
+         for image in images:
+             question = "what is the invoice date"
+             encoding = self.processor(
+                 image,
+                 question,
+                 return_tensors="pt",
+             )
+             # print(encoding.keys())
+
+             outputs = self.model(**encoding)
+             # print(outputs.keys())
+             predicted_start_idx = outputs.start_logits.argmax(-1).item()
+             predicted_end_idx = outputs.end_logits.argmax(-1).item()
+
+             predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
+             predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)
+             print('answer: ', predicted_answer)
+
+             target_start_index = torch.tensor([7])  # hard-coded target span, used only as a quick sanity check
+             target_end_index = torch.tensor([14])
+
+             outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
+             predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
+             predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
+             print(predicted_answer_span_start, predicted_answer_span_end)
+
+         # pprint(image)
+         # for image, words, boxes in zip(image['image'], image['words'], image['boxes']):
+         #     print(image, words, boxes)
+
+         #     question = "what is the invoice date"
+         #     encoding = self.processor(
+         #         image,
+         #         question,
+         #         words,
+         #         boxes=boxes,
+         #         return_tensors="pt",
+         #         # apply_ocr=False
+         #     )
+         #     print(encoding.keys())
+
+
+         # process image
+         # encoding = self.processor(image, return_tensors="pt")
+
+         # # run prediction
+         # with torch.inference_mode():
+         #     outputs = self.model(
+         #         input_ids=encoding.input_ids.to(device),
+         #         bbox=encoding.bbox.to(device),
+         #         attention_mask=encoding.attention_mask.to(device),
+         #         token_type_ids=encoding.token_type_ids.to(device),
+         #     )
+         #     predictions = outputs.logits.softmax(-1)
+
+         # # post process output
+         # result = []
+         # for item, inp_ids, bbox in zip(
+         #     predictions.squeeze(0).cpu(), encoding.input_ids.squeeze(0).cpu(), encoding.bbox.squeeze(0).cpu()
+         # ):
+         #     label = self.model.config.id2label[int(item.argmax().cpu())]
+         #     if label == "O":
+         #         continue
+         #     score = item.max().item()
+         #     text = self.processor.tokenizer.decode(inp_ids)
+         #     bbox = unnormalize_box(bbox.tolist(), image.width, image.height)
+         #     result.append({"label": label, "score": score, "text": text, "bbox": bbox})
+         # return {"predictions": result}
+         return ''  # placeholder response while the handler is still a work in progress
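For local debugging, along the lines of the gitignored test_handler_local.py, a minimal sketch of driving EndpointHandler directly with raw PDF bytes. The PDF path is a placeholder and the script itself is hypothetical, not part of this commit.

# sketch of a local smoke test -- hypothetical, not part of this commit
# Instantiates EndpointHandler directly and feeds it raw PDF bytes, mirroring
# the {"inputs": <bytes>} payload shape that __call__ expects.
from handler import EndpointHandler

def main():
    handler = EndpointHandler(path=".")
    with open("invoice_example.pdf", "rb") as f:  # placeholder PDF; *.pdf is gitignored
        pdf_bytes = f.read()
    result = handler({"inputs": pdf_bytes})
    # the current handler prints per-page answers and returns an empty string
    print("handler returned:", repr(result))

if __name__ == "__main__":
    main()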
invoice_example.png ADDED