mp-02 committed
Commit fa99101
1 parent: 124bac3

Upload 4 files

Files changed (4)
  1. app.py +44 -0
  2. cord_inference.py +80 -0
  3. sroie_inference.py +114 -0
  4. utils.py +40 -0
app.py ADDED
@@ -0,0 +1,44 @@
+ from cord_inference import prediction as cord_prediction
+ from sroie_inference import prediction as sroie_prediction
+ import gradio as gr
+
+
+ def prediction(image):
+
+     # we first run the model fine-tuned on SROIE (for now Theivaprakasham/layoutlmv3-finetuned-sroie),
+     # which extracts date, company and address and returns them together with a blurred image
+     d, image_blurred = sroie_prediction(image)
+
+     # then we run mp-02/layoutlmv3-finetuned-cord on the blurred image for the remaining fields
+     d1, image1 = cord_prediction(image_blurred)
+
+     # finally we merge the two dictionaries (their label sets are disjoint)
+     k = {**d, **d1}
+
+     return d, image_blurred, d1, image1, k
+
+
+ title = "Interactive demo: LayoutLMv3 for receipts"
+ description = "Demo for Microsoft's LayoutLMv3, a Transformer for state-of-the-art document image understanding tasks. This particular model is fine-tuned on CORD and SROIE, two datasets of receipts.\n It first uses the model fine-tuned on SROIE to extract date, company and address, then the model fine-tuned on CORD for the other fields.\n To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
+ examples = [['image.png']]
+
+ css = """.output_image, .input_image {height: 600px !important}"""
+
+ # we use a Gradio interface that takes an image as input and returns a JSON object with the extracted info;
+ # we also show the intermediate steps (first we extract some fields with the model fine-tuned on SROIE
+ # and blur the corresponding boxes, then we pass the blurred image to the model fine-tuned on CORD)
+ iface = gr.Interface(fn=prediction,
+                      inputs=gr.Image(type="pil"),
+                      outputs=[gr.JSON(label="SROIE fields"),
+                               gr.Image(type="pil", label="blurred image"),
+                               gr.JSON(label="CORD fields"),
+                               gr.Image(type="pil", label="annotated image"),
+                               gr.JSON(label="merged JSON")],
+                      title=title,
+                      description=description,
+                      examples=examples,
+                      css=css)
+
+ iface.launch()
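
For a quick local sanity check of the same two-stage pipeline without launching the Gradio UI, a minimal sketch along these lines should work (it assumes the example image.png from the examples list is present and that both inference modules are importable):

from PIL import Image
from sroie_inference import prediction as sroie_prediction
from cord_inference import prediction as cord_prediction

img = Image.open("image.png").convert("RGB")          # example image from the demo
sroie_fields, blurred = sroie_prediction(img)         # date / company / address, with those boxes blurred
cord_fields, annotated = cord_prediction(blurred)     # remaining receipt fields from the CORD model
print({**sroie_fields, **cord_fields})                # merged result, as app.py returns it
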
cord_inference.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ import numpy as np
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+ from PIL import Image, ImageDraw, ImageFont
+ from utils import OCR, unnormalize_box
+
+
+ labels = ["O", "B-MENU.NM", "B-MENU.NUM", "B-MENU.UNITPRICE", "B-MENU.CNT", "B-MENU.DISCOUNTPRICE", "B-MENU.PRICE", "B-MENU.ITEMSUBTOTAL", "B-MENU.VATYN", "B-MENU.ETC", "B-MENU.SUB.NM", "B-MENU.SUB.UNITPRICE", "B-MENU.SUB.CNT", "B-MENU.SUB.PRICE", "B-MENU.SUB.ETC", "B-VOID_MENU.NM", "B-VOID_MENU.PRICE", "B-SUB_TOTAL.SUBTOTAL_PRICE", "B-SUB_TOTAL.DISCOUNT_PRICE", "B-SUB_TOTAL.SERVICE_PRICE", "B-SUB_TOTAL.OTHERSVC_PRICE", "B-SUB_TOTAL.TAX_PRICE", "B-SUB_TOTAL.ETC", "B-TOTAL.TOTAL_PRICE", "B-TOTAL.TOTAL_ETC", "B-TOTAL.CASHPRICE", "B-TOTAL.CHANGEPRICE", "B-TOTAL.CREDITCARDPRICE", "B-TOTAL.EMONEYPRICE", "B-TOTAL.MENUTYPE_CNT", "B-TOTAL.MENUQTY_CNT", "I-MENU.NM", "I-MENU.NUM", "I-MENU.UNITPRICE", "I-MENU.CNT", "I-MENU.DISCOUNTPRICE", "I-MENU.PRICE", "I-MENU.ITEMSUBTOTAL", "I-MENU.VATYN", "I-MENU.ETC", "I-MENU.SUB.NM", "I-MENU.SUB.UNITPRICE", "I-MENU.SUB.CNT", "I-MENU.SUB.PRICE", "I-MENU.SUB.ETC", "I-VOID_MENU.NM", "I-VOID_MENU.PRICE", "I-SUB_TOTAL.SUBTOTAL_PRICE", "I-SUB_TOTAL.DISCOUNT_PRICE", "I-SUB_TOTAL.SERVICE_PRICE", "I-SUB_TOTAL.OTHERSVC_PRICE", "I-SUB_TOTAL.TAX_PRICE", "I-SUB_TOTAL.ETC", "I-TOTAL.TOTAL_PRICE", "I-TOTAL.TOTAL_ETC", "I-TOTAL.CASHPRICE", "I-TOTAL.CHANGEPRICE", "I-TOTAL.CREDITCARDPRICE", "I-TOTAL.EMONEYPRICE", "I-TOTAL.MENUTYPE_CNT", "I-TOTAL.MENUQTY_CNT"]
+ id2label = {i: label for i, label in enumerate(labels)}
+ label2id = {label: i for i, label in enumerate(labels)}
+
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord")
+ processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
+ model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord")
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+
+ def prediction(image):
+     # run our own OCR (PaddleOCR, see utils.py) since the processor is loaded with apply_ocr=False
+     boxes, words = OCR(image)
+     encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
+     offset_mapping = encoding.pop('offset_mapping')
+
+     for k, v in encoding.items():
+         encoding[k] = v.to(device)
+
+     with torch.no_grad():
+         outputs = model(**encoding)
+
+     predictions = outputs.logits.argmax(-1).squeeze().tolist()
+     token_boxes = encoding.bbox.squeeze().tolist()
+
+     inp_ids = encoding.input_ids.squeeze().tolist()
+     inp_words = [tokenizer.decode(i) for i in inp_ids]
+
+     width, height = image.size
+     # a token is a continuation subword if its character offset does not start at 0
+     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
+
+     # keep the prediction and box only for the first subword of each word
+     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
+     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
+     true_words = []
+
+     # glue continuation subwords back onto the word they belong to
+     for idx, word in enumerate(inp_words):
+         if not is_subword[idx]:
+             true_words.append(word)
+         else:
+             true_words[-1] = true_words[-1] + word
+
+     # drop the special <s> and </s> tokens
+     true_predictions = true_predictions[1:-1]
+     true_boxes = true_boxes[1:-1]
+     true_words = true_words[1:-1]
+
+     preds = []
+     l_words = []
+     bboxes = []
+
+     # keep only the entities, discarding the "O" (outside) label
+     for i, label in enumerate(true_predictions):
+         if label != 'O':
+             preds.append(true_predictions[i])
+             l_words.append(true_words[i])
+             bboxes.append(true_boxes[i])
+
+     # group the extracted words by label into a dictionary
+     d = {}
+     for idx, label in enumerate(preds):
+         if label not in d:
+             d[label] = l_words[idx]
+         else:
+             d[label] = d[label] + ", " + l_words[idx]
+     d = {k: v.strip() for (k, v) in d.items()}
+
+     # TODO: process the json
+
+     # annotate the image with the predicted boxes and labels
+     draw = ImageDraw.Draw(image, "RGBA")
+     font = ImageFont.load_default()
+
+     for pred_label, box in zip(preds, bboxes):
+         draw.rectangle(box, outline="red")
+         draw.text((box[0] + 10, box[1] - 10), text=pred_label, font=font, fill="black")
+
+     return d, image
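
The subword handling in prediction() keeps the label and box only for the first piece of each word and glues the remaining pieces back together. A small self-contained illustration of that merging step, with made-up tokens and a made-up is_subword mask rather than real model output:

inp_words = ["CHEESE", "BURG", "ER", "2", ".", "50"]
is_subword = [False, False, True, False, True, True]

true_words = []
for idx, piece in enumerate(inp_words):
    if not is_subword[idx]:
        true_words.append(piece)                 # first piece starts a new word
    else:
        true_words[-1] = true_words[-1] + piece  # continuation pieces are appended

print(true_words)  # ['CHEESE', 'BURGER', '2.50']
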
sroie_inference.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ import cv2
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+ from utils import OCR, unnormalize_box
+
+
+ labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
+ id2label = {i: label for i, label in enumerate(labels)}
+ label2id = {label: i for i, label in enumerate(labels)}
+
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie")
+ processor = LayoutLMv3Processor.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
+ model = LayoutLMv3ForTokenClassification.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie")
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+
+ def blur(image, boxes):
+     # blur every box ([x0, y0, x1, y1] in pixel coordinates) so the CORD model does not see it again
+     image = np.array(image)
+     for box in boxes:
+         blur_x = int(box[0])
+         blur_y = int(box[1])
+         blur_width = int(box[2] - box[0])
+         blur_height = int(box[3] - box[1])
+
+         roi = image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
+         blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
+         image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image
+
+     return Image.fromarray(image, 'RGB')
+
+
+ def prediction(image):
+     # run our own OCR (PaddleOCR, see utils.py) since the processor is loaded with apply_ocr=False
+     boxes, words = OCR(image)
+     encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
+     offset_mapping = encoding.pop('offset_mapping')
+
+     for k, v in encoding.items():
+         encoding[k] = v.to(device)
+
+     with torch.no_grad():
+         outputs = model(**encoding)
+
+     predictions = outputs.logits.argmax(-1).squeeze().tolist()
+     token_boxes = encoding.bbox.squeeze().tolist()
+
+     inp_ids = encoding.input_ids.squeeze().tolist()
+     inp_words = [tokenizer.decode(i) for i in inp_ids]
+
+     width, height = image.size
+     # a token is a continuation subword if its character offset does not start at 0
+     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
+
+     # keep the prediction and box only for the first subword of each word
+     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
+     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
+     true_words = []
+
+     # glue continuation subwords back onto the word they belong to
+     for idx, word in enumerate(inp_words):
+         if not is_subword[idx]:
+             true_words.append(word)
+         else:
+             true_words[-1] = true_words[-1] + word
+
+     # drop the special <s> and </s> tokens
+     true_predictions = true_predictions[1:-1]
+     true_boxes = true_boxes[1:-1]
+     true_words = true_words[1:-1]
+
+     preds = []
+     l_words = []
+     bboxes = []
+
+     # keep only the entities, discarding the "O" (outside) label
+     for i, label in enumerate(true_predictions):
+         if label != 'O':
+             preds.append(true_predictions[i])
+             l_words.append(true_words[i])
+             bboxes.append(true_boxes[i])
+
+     # group the extracted words by label into a dictionary
+     d = {}
+     for idx, label in enumerate(preds):
+         if label not in d:
+             d[label] = l_words[idx]
+         else:
+             d[label] = d[label] + ", " + l_words[idx]
+
+     d = {k: v.strip() for (k, v) in d.items()}
+
+     # fold every I-XXX entry into the corresponding B-XXX entry
+     keys_to_pop = []
+     for k, v in list(d.items()):
+         if k[:2] == "I-":
+             base = "B-" + k[2:]
+             if base in d:
+                 d[base] = d[base] + ", " + v
+             else:
+                 d[base] = v
+             keys_to_pop.append(k)
+
+     # keep only date, company and address; the total is left to the CORD model
+     if "O" in d: d.pop("O")
+     if "B-TOTAL" in d: d.pop("B-TOTAL")
+     for k in keys_to_pop: d.pop(k)
+
+     # blur the regions we already extracted so the CORD model does not pick them up again
+     blur_boxes = []
+     for pred_label, box in zip(preds, bboxes):
+         if pred_label != 'O' and pred_label[2:] != 'TOTAL':
+             blur_boxes.append(box)
+
+     image = blur(image, blur_boxes)
+
+     # annotate the blurred image with the predicted boxes and labels
+     draw = ImageDraw.Draw(image, "RGBA")
+     font = ImageFont.load_default()
+
+     for pred_label, box in zip(preds, bboxes):
+         draw.rectangle(box, outline="red")
+         draw.text((box[0] + 10, box[1] - 10), text=pred_label, font=font, fill="black")
+
+     return d, image
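
The blur() helper can also be exercised on its own; a minimal sketch, assuming a local receipt image and an arbitrary [x0, y0, x1, y1] pixel box:

from PIL import Image
from sroie_inference import blur

img = Image.open("image.png").convert("RGB")   # assumed example image
hidden = blur(img, [[40, 60, 300, 110]])       # hypothetical region to hide
hidden.save("blurred.png")
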
utils.py ADDED
@@ -0,0 +1,40 @@
+ from paddleocr import PaddleOCR
+ from numpy import asarray
+
+
+ def normalize_bbox(bbox, width, height):
+     # scale pixel coordinates to the 0-1000 range expected by LayoutLMv3
+     return [
+         int(1000 * (bbox[0] / width)),
+         int(1000 * (bbox[1] / height)),
+         int(1000 * (bbox[2] / width)),
+         int(1000 * (bbox[3] / height)),
+     ]
+
+
+ def unnormalize_box(bbox, width, height):
+     # map 0-1000 coordinates back to pixel coordinates
+     return [
+         width * (bbox[0] / 1000),
+         height * (bbox[1] / 1000),
+         width * (bbox[2] / 1000),
+         height * (bbox[3] / 1000),
+     ]
+
+
+ def OCR(image):
+     ocr = PaddleOCR(use_angle_cls=True)
+     result = ocr.ocr(asarray(image), cls=True)
+     bboxes = []
+     words = []
+
+     for idx in range(len(result)):
+         res = result[idx]
+         for line in res:
+             # line[0] holds the four corner points of the detected text box;
+             # concatenating the top-left and bottom-right corners gives [x0, y0, x1, y1]
+             bboxes.append(normalize_bbox(line[0][0] + line[0][2], image.width, image.height))
+             # line[1][0] is the recognized text
+             words.append(line[1][0])
+
+     return bboxes, words
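
normalize_bbox and unnormalize_box are inverses of each other up to the integer rounding in normalize_bbox; a quick sketch with an arbitrary box on a 1000x1500 page:

from utils import normalize_bbox, unnormalize_box

box = [125, 80, 410, 132]                  # arbitrary pixel coordinates
norm = normalize_bbox(box, 1000, 1500)     # -> [125, 53, 410, 88], in the 0-1000 range
back = unnormalize_box(norm, 1000, 1500)   # -> [125.0, 79.5, 410.0, 132.0], roughly the original box
print(norm, back)
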