doc2txt committed on
Commit
1eedb65
1 Parent(s): ee6fdb6

Upload LayoutLMv2Main_cord2_gOcr_folder.py

Files changed (1)
  1. LayoutLMv2Main_cord2_gOcr_folder.py +234 -0
LayoutLMv2Main_cord2_gOcr_folder.py ADDED
@@ -0,0 +1,234 @@
# -*- coding: utf-8 -*-
"""inference with LayoutLMv2ForTokenClassification.ipynb

Automatically generated by Colaboratory.

Original file is located at
https://colab.research.google.com/drive/1nhfx6XRncq2XsOBREZGJI7tRByt_TJIa

## Inference with LayoutLMv2ForTokenClassification + Gradio demo

In this notebook, we perform inference with `LayoutLMv2ForTokenClassification` on new document images for which no label information is available. At the end, we also build a [Gradio](https://gradio.app/) demo that turns our inference code into a web interface.

## Install libraries

Let's first install the required libraries:
* HuggingFace Transformers + Detectron2 (for the model)
* HuggingFace Datasets (for getting the data)
* PyTesseract (for OCR)
"""

# !pip install -q transformers
# !pip install -q gradio

# !pip install 'git+https://github.com/facebookresearch/detectron2.git'

# !pip install -q datasets

# !sudo apt install tesseract-ocr
# !pip install -q pytesseract
# !pip install torchvision

# import gradio as gr
import json
import os
import time

import cv2
import numpy as np
import pytesseract
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv2ForTokenClassification, LayoutLMv2Processor

from GoogleVisionService import GoogleVisionService
from getTextHelper import (cord_label_to_color, get_word_boxes_google,
                           get_word_boxes_tesseract, getImg, getImgAndPath,
                           normalize_bbox, unnormalize_box)
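
# `normalize_bbox` / `unnormalize_box` come from the local `getTextHelper`
# module, which is not included in this commit. A minimal sketch of what they
# are assumed to do, following the LayoutLM convention of scaling boxes to a
# 0-1000 coordinate space (illustrative only, not the actual helpers):
#
# def normalize_bbox(box, width, height):
#     return [int(1000 * box[0] / width), int(1000 * box[1] / height),
#             int(1000 * box[2] / width), int(1000 * box[3] / height)]
#
# def unnormalize_box(box, width, height):
#     return [width * box[0] / 1000, height * box[1] / 1000,
#             width * box[2] / 1000, height * box[3] / 1000]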


class labelCounter:
    """Module-level state for the label currently being visualized."""
    lbl_i = 0
    lbl = None


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

apply_ocr = False
# apply_ocr = True
processor = LayoutLMv2Processor.from_pretrained(
    "microsoft/layoutlmv2-base-uncased", apply_ocr=apply_ocr)

workingPath = "/Users/eliaweiss/Documents/doc2txt/lineCv/"
docNumberList = [6, 7, 10, 21, 25, 29, 48, 50]
# load the fine-tuned model from the hub
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "doc2txt/layoutlmv2-finetuned-cord")
model.to(device)
model.eval()  # inference only: disable dropout


# datasets = load_dataset("MarkusDressel/cord")

"""Let's create a list containing all unique labels, as well as dictionaries mapping integers to their label names and vice versa. This will be useful to convert the model's predictions to actual label names."""

# labels = datasets['train'].features['ner_tags'].feature.names
labels = ['I-menu.cnt', 'I-menu.discountprice', 'I-menu.etc', 'I-menu.itemsubtotal', 'I-menu.nm', 'I-menu.num', 'I-menu.price', 'I-menu.sub_cnt', 'I-menu.sub_etc', 'I-menu.sub_nm', 'I-menu.sub_price', 'I-menu.sub_unitprice', 'I-menu.unitprice', 'I-menu.vatyn', 'I-sub_total.discount_price', 'I-sub_total.etc',
          'I-sub_total.othersvc_price', 'I-sub_total.service_price', 'I-sub_total.subtotal_price', 'I-sub_total.tax_price', 'I-total.cashprice', 'I-total.changeprice', 'I-total.creditcardprice', 'I-total.emoneyprice', 'I-total.menuqty_cnt', 'I-total.menutype_cnt', 'I-total.total_etc', 'I-total.total_price', 'I-void_menu.nm', 'I-void_menu.price']
# print(labels)

id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}
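
# For example, id2label[0] == 'I-menu.cnt' and label2id['I-menu.cnt'] == 0; the
# model's integer predictions are decoded back to label names via id2label.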


"""## Inference"""

# font = ImageFont.load_default()
font = ImageFont.truetype(
    "/Users/eliaweiss/work/ocrPlus/ocrPlus/DejaVuSans.ttf", 50)


# print(example.keys())
# for nn in docNumberList:
def getNextLabel(labels, true_predictions):
    """Advance labelCounter to the next label that occurs in the predictions."""
    labelCounter.lbl = None
    i = 0
    while labelCounter.lbl not in true_predictions:
        labelCounter.lbl_i += 1
        if labelCounter.lbl_i >= len(labels) - 1:
            labelCounter.lbl_i = 0
        labelCounter.lbl = labels[labelCounter.lbl_i]
        i += 1
        if i >= len(labels):  # no label matched; avoid an infinite loop
            break
    return labelCounter.lbl


def iob_to_label(label):
    """Strip the IOB prefix, e.g. 'I-menu.nm' -> 'menu.nm'."""
    label = label[2:]
    if not label:
        return 'other'
    return label
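
# `cord_label_to_color` (also from `getTextHelper`) is assumed to map each
# label string to a drawing color; a hypothetical stand-in, not the real helper:
#
# def cord_label_to_color(label):
#     palette = ['red', 'green', 'blue', 'orange', 'violet', 'brown']
#     return palette[sum(label.encode()) % len(palette)]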


def drawLabels(image, true_predictions, true_boxes):
    """Draw boxes for all predictions that match the currently selected label."""
    image_tmp = image.copy()
    draw = ImageDraw.Draw(image_tmp)

    color = cord_label_to_color(labelCounter.lbl)
    draw.text((10, 10), text=labelCounter.lbl, fill=color, font=font)

    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = iob_to_label(prediction).lower()
        color = cord_label_to_color(prediction)

        if labelCounter.lbl not in prediction:
            continue
        # color = label2color[predicted_label] if predicted_label in label2color else 'black'
        draw.rectangle(box, outline=color, width=5)
    return image_tmp


# folder = "/Users/eliaweiss/.cache/huggingface/datasets/downloads/extracted/87634c2ab68012df3def8353986bcb092170ef7341c69e1a9cd97be52e513079/CORD/test/image/"
# folder = "/Users/eliaweiss/Documents/doc2txt/en_invoice_printed"
folder = "/Users/eliaweiss/ai/ICDAR-2019-SROIE/data/img"
for img_name in os.listdir(folder):
    labelCounter.lbl_i = 0
    start_time = time.time()
    img_path = os.path.join(folder, img_name)
    image = Image.open(img_path)
    # image = Image.open(example['image_path'])

    # pathOcr = workingPath + docNumber + ".json"
    # with open(pathOcr, encoding="utf-8") as f:
    #     gOcrJson = json.load(f)
    gOcr = GoogleVisionService(img_path)
    gOcrJson = gOcr.googleOcr()

    image = image.convert("RGB")
    width, height = image.size

    """We prepare it for the model using `LayoutLMv2Processor`."""

    # Extract words and bounding boxes
    words, boxes = get_word_boxes_google(gOcrJson)
    boxes = [normalize_bbox(box, width, height) for box in boxes]
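
    # `get_word_boxes_google` is a local helper (not part of this commit). A
    # rough sketch of what it is assumed to do with a standard Vision API
    # annotate response: skip the first textAnnotation (the whole page) and
    # turn each word-level boundingPoly into an [x0, y0, x1, y1] box:
    #
    # def get_word_boxes_google(gOcrJson):
    #     words, boxes = [], []
    #     for ann in gOcrJson['textAnnotations'][1:]:
    #         xs = [v.get('x', 0) for v in ann['boundingPoly']['vertices']]
    #         ys = [v.get('y', 0) for v in ann['boundingPoly']['vertices']]
    #         words.append(ann['description'])
    #         boxes.append([min(xs), min(ys), max(xs), max(ys)])
    #     return words, boxes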

    # # Use pytesseract to perform OCR on the image
    # cv_image = cv2.imread(img_path)
    # gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)

    # # Get word-level bounding boxes using pytesseract
    # data = pytesseract.image_to_data(gray, output_type=pytesseract.Output.DICT)
    # words, boxes = get_word_boxes_tesseract(data)
    # boxes = [normalize_bbox(box, width, height) for box in boxes]

    if not apply_ocr:
        encoding = processor(image, words, boxes=boxes,
                             return_offsets_mapping=True, return_tensors="pt")
    else:
        encoding = processor(
            image, return_offsets_mapping=True, return_tensors="pt")
    offset_mapping = encoding.pop('offset_mapping')
    print(encoding.keys())
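    # per the LayoutLMv2 docs, this should print: input_ids, token_type_ids,
    # attention_mask, bbox, image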

    """Next, let's move everything to the GPU, if it's available."""

    for k, v in encoding.items():
        encoding[k] = v.to(device)

    # forward pass (no gradients needed at inference time)
    with torch.no_grad():
        outputs = model(**encoding)
    # print(outputs.logits.shape)
    print("Time: " + str(time.time() - start_time))

    """Let's create the true predictions, as well as the true boxes. By "true", we mean only taking into account tokens that are at the start of a given word. We can use the `offset_mapping` returned by the processor to determine which tokens are a subword."""
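
    # Example: if "total" is tokenized into "to" + "##tal", their offsets are
    # (0, 2) and (2, 5); "##tal" starts at character 2 (not 0) of the word, so
    # it is a subword and its prediction is dropped below.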
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()

    # a token is a subword if its offset does not start at character 0
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    true_predictions = [id2label[pred]
                        for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(
        token_boxes) if not is_subword[idx]]

    # print(true_predictions)
    # print(true_boxes)

    """Let's visualize the result!"""

    labelCounter.lbl = "I-total.total_price"  # getNextLabel(labels, true_predictions)

    image_tmp = drawLabels(image, true_predictions, true_boxes)

    # display the image with cv2 (convert RGB -> BGR for correct colors)
    image_np = cv2.cvtColor(np.array(image_tmp), cv2.COLOR_RGB2BGR)
    # cv2.imwrite('output_image_v2.jpg', image_np)
    cv2.imshow('Window Name', image_np)

    while True:
        k = cv2.waitKey(0) & 0xFF
        if k == 255:
            continue
        if k == 126:  # shift + ` -> move on to the next image
            break
        print("k", k)
        if k == 9:  # tab -> cycle to the next label
            labelCounter.lbl = getNextLabel(labels, true_predictions)

            image_tmp = drawLabels(image, true_predictions, true_boxes)

            # display the updated image (convert RGB -> BGR)
            image_np = cv2.cvtColor(np.array(image_tmp), cv2.COLOR_RGB2BGR)
            # cv2.imwrite('output_image_v2.jpg', image_np)
            cv2.imshow('Window Name', image_np)