Alexander Slessor committed
Commit: 7086666
1 Parent(s): 317e1b6

completed initial template

Files changed (3)
  1. .gitignore +2 -0
  2. README.md +7 -0
  3. handler.py +170 -95
.gitignore CHANGED
@@ -1,6 +1,7 @@
 __pycache__
 *.ipynb
 *.pdf
+*.log
 
 test_endpoint.py
 test_handler_local.py
@@ -8,3 +9,4 @@ test_handler_local.py
 setup
 upload_to_hf
 requirements.txt
+notes.md
README.md CHANGED
@@ -22,3 +22,10 @@ Examples & Guides
 - https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLMv2/DocVQA/Fine_tuning_LayoutLMv2ForQuestionAnswering_on_DocVQA.ipynb
 
 - https://mccormickml.com/2020/03/10/question-answering-with-a-fine-tuned-BERT/
+
+
+# Errors
+
+```
+The class LayoutLMv2FeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use LayoutLMv2ImageProcessor instead.
+```
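The warning noted above is emitted because handler.py still instantiates `LayoutLMv2FeatureExtractor` at module level; the replacement class named in the message is `LayoutLMv2ImageProcessor`. A minimal sketch of the swap, assuming a recent transformers release in which the processor takes an `image_processor` argument (the checkpoint name is the base model already referenced in handler.py):

```python
# Hedged sketch: compose the processor from LayoutLMv2ImageProcessor instead of
# the deprecated LayoutLMv2FeatureExtractor.
from transformers import (
    LayoutLMv2ImageProcessor,
    LayoutLMv2Processor,
    LayoutLMv2TokenizerFast,
)

image_processor = LayoutLMv2ImageProcessor()  # apply_ocr=True by default, so tesseract/pytesseract are still required
tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
processor = LayoutLMv2Processor(image_processor=image_processor, tokenizer=tokenizer)
```

Behaviour should be unchanged; only the deprecated class (and its warning) goes away.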
handler.py CHANGED
@@ -1,27 +1,35 @@
 import torch
-from typing import Any
-# from transformers import LayoutLMForTokenClassification
-
+from typing import Any, Optional
 from transformers import LayoutLMv2ForQuestionAnswering
 from transformers import LayoutLMv2Processor
 from transformers import LayoutLMv2FeatureExtractor
 from transformers import LayoutLMv2ImageProcessor
 from transformers import LayoutLMv2TokenizerFast
-
+from transformers.tokenization_utils_base import BatchEncoding
+from transformers.tokenization_utils_base import TruncationStrategy
+from transformers.utils import TensorType
+from transformers.modeling_outputs import (
+    QuestionAnsweringModelOutput as QuestionAnsweringModelOutputBase
+)
+import numpy as np
 from PIL import Image, ImageDraw, ImageFont
 from subprocess import run
 import pdf2image
-
 from pprint import pprint
+import logging
+from os import environ
+from dataclasses import dataclass
 
-# set device
 # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 # install tesseract-ocr and pytesseract
 # run("apt install -y tesseract-ocr", shell=True, check=True)
 
 feature_extractor = LayoutLMv2FeatureExtractor()
 
+# @dataclass
+# class QuestionAnsweringModelOutput(QuestionAnsweringModelOutputBase):
+#     token_logits: Optional[torch.FloatTensor] = None
+
 class NoOCRReaderFound(Exception):
     def __init__(self, e):
         self.e = e
@@ -29,15 +37,6 @@ class NoOCRReaderFound(Exception):
     def __str__(self):
         return f"Could not load OCR Reader: {self.e}"
 
-# helper function to unnormalize bboxes for drawing onto the image
-def unnormalize_box(bbox, width, height):
-    return [
-        width * (bbox[0] / 1000),
-        height * (bbox[1] / 1000),
-        width * (bbox[2] / 1000),
-        height * (bbox[3] / 1000),
-    ]
-
 def pdf_to_image(b: bytes):
     # First, try to extract text directly
     # TODO: This library requires poppler, which is not present everywhere.
@@ -53,17 +52,109 @@ def pdf_to_image(b: bytes):
     return data
 
 
-class EndpointHandler:
-    def __init__(self, path=""):
-        # self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
-        # self.processor = LayoutLMv2Processor.from_pretrained(path)
-        self.image_processor = LayoutLMv2ImageProcessor() # apply_ocr is set to True by default
-        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
-        # self.processor = LayoutLMv2Processor(self.image_processor, self.tokenizer)
-        self.processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
-        # processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+def setup_logger(which_logger: Optional[str] = None):
+    lib_level = logging.DEBUG # Default level for your logger
+    root_level = logging.INFO
+    log_format = '%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s'
+    logging.basicConfig(
+        filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'),# taken from loca .env file, not set in settings.py
+        format=log_format,
+        datefmt='%d-%b-%y %H:%M:%S',
+        level=root_level,
+        force=True
+    )
+    log = logging.getLogger(which_logger)
+    log.setLevel(lib_level)
+    return log
+
+logger = setup_logger(__name__)
+
+
+class Funcs:
+    # helper function to unnormalize bboxes for drawing onto the image
+    @staticmethod
+    def unnormalize_box(bbox, width, height):
+        return [
+            width * (bbox[0] / 1000),
+            height * (bbox[1] / 1000),
+            width * (bbox[2] / 1000),
+            height * (bbox[3] / 1000),
+        ]
+
+    @staticmethod
+    def num_spans(encoding: BatchEncoding) -> int:
+        return len(encoding["input_ids"])
+
+    @staticmethod
+    def p_mask(num_spans: int, encoding: BatchEncoding) -> list:
+        try:
+            return [
+                [tok != 1 for tok in encoding.sequence_ids(span_id)] \
+                for span_id in range(num_spans)
+            ]
+        except Exception as e:
+            raise
+
+    @staticmethod
+    def token_start_end(encoding, tokenizer):
+        sequence_ids = encoding.sequence_ids()
+
+        # Start token index of the current span in the text.
+        token_start_index = 0
+        while sequence_ids[token_start_index] != 1:
+            token_start_index += 1
+
+        # End token index of the current span in the text.
+        token_end_index = len(encoding.input_ids) - 1
+        while sequence_ids[token_end_index] != 1:
+            token_end_index -= 1
+
+        print("Token start index:", token_start_index)
+        print("Token end index:", token_end_index)
+        print('token_start_end: ', tokenizer.decode(encoding.input_ids[token_start_index:token_end_index+1]))
+        return token_start_index, token_end_index
+
+    @staticmethod
+    def reconstruct_answer(word_idx_start, word_idx_end, encoding, tokenizer):
+        word_ids = encoding.word_ids()[token_start_index:token_end_index+1]
+        print("Word ids:", word_ids)
+        for id in word_ids:
+            if id == word_idx_start:
+                start_position = token_start_index
+            else:
+                token_start_index += 1
+
+        for id in word_ids[::-1]:
+            if id == word_idx_end:
+                end_position = token_end_index
+            else:
+                token_end_index -= 1
+
+        print("Reconstructed answer:",
+            tokenizer.decode(encoding.input_ids[start_position:end_position+1])
+        )
+        return start_position, end_position
+
+    @staticmethod
+    def sigmoid(_outputs):
+        return 1.0 / (1.0 + np.exp(-_outputs))
+
+    @staticmethod
+    def softmax(_outputs):
+        maxes = np.max(_outputs, axis=-1, keepdims=True)
+        shifted_exp = np.exp(_outputs - maxes)
+        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
 
-        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
+class EndpointHandler:
+    def __init__(self, path="./"):
+        # self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path).to(device)
+        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
+        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
+        # self.image_processor = LayoutLMv2ImageProcessor() # apply_ocr is set to True by default
+        self.processor = LayoutLMv2Processor.from_pretrained(
+            path,
+            # image_processor=self.image_processor,
+            tokenizer=self.tokenizer)
 
     def __call__(self, data: dict[str, bytes]):
         """
@@ -75,72 +166,56 @@ class EndpointHandler:
 
         # image = pdf_to_image(image)
         images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(image)]
-        for image in images:
-            question = "what is the invoice date"
-            encoding = self.processor(
-                image,
-                question,
-                return_tensors="pt",
-            )
-            # print(encoding.keys())
-
-            outputs = self.model(**encoding)
-            # print(outputs.keys())
-            predicted_start_idx = outputs.start_logits.argmax(-1).item()
-            predicted_end_idx = outputs.end_logits.argmax(-1).item()
-
-            predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
-            predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)
-            print('answer: ',predicted_answer)
-
-            target_start_index = torch.tensor([7])
-            target_end_index = torch.tensor([14])
-
-            outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
-            predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
-            predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
-            print(predicted_answer_span_start, predicted_answer_span_end)
-
-        # pprint(image)
-        # for image, words, boxes in zip(image['image'], image['words'], image['boxes']):
-        #     print(image, words, boxes)
-
-            # question = "what is the invoice date"
-            # encoding = self.processor(
-            #     image,
-            #     question,
-            #     words,
-            #     boxes=boxes,
-            #     return_tensors="pt",
-            #     # apply_ocr=False
-            # )
-            # print(encoding.keys())
-
-
-        # process image
-        # encoding = self.processor(image, return_tensors="pt")
-
-        # # run prediction
-        # with torch.inference_mode():
-        #     outputs = self.model(
-        #         input_ids=encoding.input_ids.to(device),
-        #         bbox=encoding.bbox.to(device),
-        #         attention_mask=encoding.attention_mask.to(device),
-        #         token_type_ids=encoding.token_type_ids.to(device),
-        #     )
-        #     predictions = outputs.logits.softmax(-1)
-
-        # # post process output
-        # result = []
-        # for item, inp_ids, bbox in zip(
-        #     predictions.squeeze(0).cpu(), encoding.input_ids.squeeze(0).cpu(), encoding.bbox.squeeze(0).cpu()
-        # ):
-        #     label = self.model.config.id2label[int(item.argmax().cpu())]
-        #     if label == "O":
-        #         continue
-        #     score = item.max().item()
-        #     text = self.processor.tokenizer.decode(inp_ids)
-        #     bbox = unnormalize_box(bbox.tolist(), image.width, image.height)
-        #     result.append({"label": label, "score": score, "text": text, "bbox": bbox})
-        # return {"predictions": result}
-        return ''
+        question = "what is the bill date"
+        with torch.no_grad():
+            for image in images:
+                # max_seq_len = min(self.tokenizer.model_max_length, 512)
+                # doc_stride = min(max_seq_len // 2, 256)
+                encoding = self.processor(
+                    image,
+                    question,
+                    # max_length=max_seq_len,
+                    # stride=doc_stride,
+                    truncation=True,
+                    # truncation=TruncationStrategy.ONLY_SECOND,
+                    # return_offsets_mapping=True,
+                    # return_token_type_ids=True,
+                    # return_overflowing_tokens=True,
+                    return_tensors=TensorType.PYTORCH
+                )
+                print('encoding: ', encoding.keys())
+
+                # for k, v in encoding.items():
+                #     encoding[k] = v.to(self.model.device)
+
+                # num_spans = Funcs.num_spans(encoding)
+                # p_mask = Funcs.p_mask(num_spans, encoding)
+                # offset_mapping = encoding.pop('offset_mapping')
+                # smaple_mapping = encoding.pop('overflow_to_sample_mapping')
+
+                outputs = self.model(**encoding)
+                # print('model outputs: ', outputs.keys())
+                start_logits = outputs.start_logits
+                end_logits = outputs.end_logits
+
+                predicted_start_idx = start_logits.argmax(-1).item()
+                predicted_end_idx = end_logits.argmax(-1).item()
+
+                predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
+                predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)
+                # print('answer: ', predicted_answer)
+                target_start_index = torch.tensor([7])
+                target_end_index = torch.tensor([14])
+                outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
+                predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
+                predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
+                # print(predicted_answer_span_start, predicted_answer_span_end)
+                logger.info(f'''
+                START
+                predicted_start_idx: {predicted_start_idx}
+                predicted_end_idx: {predicted_end_idx}
+                ---
+                answer: {predicted_answer}
+
+                END''')
+        return {'data': 'success'}
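For a quick local check of the template, a small driver like the one below can exercise the handler end to end. This is a sketch under assumptions: the payload key (`"inputs"`) and the sample file name are hypothetical, since the part of `__call__` that reads the PDF bytes out of `data` is not shown in this diff, and the current template returns `{'data': 'success'}` rather than the decoded answer.

```python
# Hypothetical local smoke test for the custom endpoint handler.
from handler import EndpointHandler

# Loads the model, tokenizer, and processor from the repository root.
handler = EndpointHandler(path="./")

# "sample_invoice.pdf" is a placeholder; any small PDF will do.
with open("sample_invoice.pdf", "rb") as f:
    pdf_bytes = f.read()

# Payload shape assumed from the usual custom-handler convention.
result = handler({"inputs": pdf_bytes})
print(result)  # expected: {'data': 'success'}; the predicted answer is only logged
```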