Spaces:
Running
Running
Upload 4 files
Browse files- app.py +44 -0
- cord_inference.py +80 -0
- sroie_inference.py +114 -0
- utils.py +40 -0
app.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cord_inference import prediction as cord_prediction
|
2 |
+
from sroie_inference import prediction as sroie_prediction
|
3 |
+
import gradio as gr
|
4 |
+
import json
|
5 |
+
|
6 |
+
def prediction(image):
|
7 |
+
|
8 |
+
#we first use mp-02/layoutlmv3-finetuned-cord on the image, which gives us a JSON with some info and a blurred image
|
9 |
+
d, image_blurred = sroie_prediction(image)
|
10 |
+
|
11 |
+
#then we use the model fine-tuned on sroie (for now it is Theivaprakasham/layoutlmv3-finetuned-sroie)
|
12 |
+
d1, image1 = cord_prediction(image_blurred)
|
13 |
+
|
14 |
+
#we then link the two json files
|
15 |
+
if len(d) == 0:
|
16 |
+
k = d1
|
17 |
+
else:
|
18 |
+
k = json.dumps(d).split('}')[0] + ', ' + json.dumps(d1).split('{')[1]
|
19 |
+
|
20 |
+
return d, image_blurred, d1, image1, k
|
21 |
+
|
22 |
+
|
23 |
+
title = "Interactive demo: LayoutLMv3 for receipts"
|
24 |
+
description = "Demo for Microsoft's LayoutLMv3, a Transformer for state-of-the-art document image understanding tasks. This particular model is fine-tuned on CORD and SROIE, which are datasets of receipts.\n It firsts uses the fine-tune on SROIE to extract date, company and address, then the fine-tune on CORD for the other info.\n To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
|
25 |
+
examples = [['image.png']]
|
26 |
+
|
27 |
+
css = """.output_image, .input_image {height: 600px !important}"""
|
28 |
+
|
29 |
+
# we use a gradio interface that takes in input an image and return a JSON file that contains its info
|
30 |
+
# we show also the intermediate steps (first we take some info with the model fine-tuned on SROIE and we blur the relative boxes
|
31 |
+
# then we pass the image to the model fine-tuned on CORD
|
32 |
+
iface = gr.Interface(fn=prediction,
|
33 |
+
inputs=gr.Image(type="pil"),
|
34 |
+
outputs=[gr.JSON(label="json parsing"),
|
35 |
+
gr.Image(type="pil", label="blurred image"),
|
36 |
+
gr.JSON(label="json parsing"),
|
37 |
+
gr.Image(type="pil", label="annotated image"),
|
38 |
+
gr.JSON(label="json parsing")],
|
39 |
+
title=title,
|
40 |
+
description=description,
|
41 |
+
examples=examples,
|
42 |
+
css=css)
|
43 |
+
|
44 |
+
iface.launch()
|
cord_inference.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
from utils import OCR, unnormalize_box
|
6 |
+
|
7 |
+
|
8 |
+
labels = ["O", "B-MENU.NM", "B-MENU.NUM", "B-MENU.UNITPRICE", "B-MENU.CNT", "B-MENU.DISCOUNTPRICE", "B-MENU.PRICE", "B-MENU.ITEMSUBTOTAL", "B-MENU.VATYN", "B-MENU.ETC", "B-MENU.SUB.NM", "B-MENU.SUB.UNITPRICE", "B-MENU.SUB.CNT", "B-MENU.SUB.PRICE", "B-MENU.SUB.ETC", "B-VOID_MENU.NM", "B-VOID_MENU.PRICE", "B-SUB_TOTAL.SUBTOTAL_PRICE", "B-SUB_TOTAL.DISCOUNT_PRICE", "B-SUB_TOTAL.SERVICE_PRICE", "B-SUB_TOTAL.OTHERSVC_PRICE", "B-SUB_TOTAL.TAX_PRICE", "B-SUB_TOTAL.ETC", "B-TOTAL.TOTAL_PRICE", "B-TOTAL.TOTAL_ETC", "B-TOTAL.CASHPRICE", "B-TOTAL.CHANGEPRICE", "B-TOTAL.CREDITCARDPRICE", "B-TOTAL.EMONEYPRICE", "B-TOTAL.MENUTYPE_CNT", "B-TOTAL.MENUQTY_CNT", "I-MENU.NM", "I-MENU.NUM", "I-MENU.UNITPRICE", "I-MENU.CNT", "I-MENU.DISCOUNTPRICE", "I-MENU.PRICE", "I-MENU.ITEMSUBTOTAL", "I-MENU.VATYN", "I-MENU.ETC", "I-MENU.SUB.NM", "I-MENU.SUB.UNITPRICE", "I-MENU.SUB.CNT", "I-MENU.SUB.PRICE", "I-MENU.SUB.ETC", "I-VOID_MENU.NM", "I-VOID_MENU.PRICE", "I-SUB_TOTAL.SUBTOTAL_PRICE", "I-SUB_TOTAL.DISCOUNT_PRICE", "I-SUB_TOTAL.SERVICE_PRICE", "I-SUB_TOTAL.OTHERSVC_PRICE", "I-SUB_TOTAL.TAX_PRICE", "I-SUB_TOTAL.ETC", "I-TOTAL.TOTAL_PRICE", "I-TOTAL.TOTAL_ETC", "I-TOTAL.CASHPRICE", "I-TOTAL.CHANGEPRICE", "I-TOTAL.CREDITCARDPRICE", "I-TOTAL.EMONEYPRICE", "I-TOTAL.MENUTYPE_CNT", "I-TOTAL.MENUQTY_CNT"]
|
9 |
+
id2label = {v: k for v, k in enumerate(labels)}
|
10 |
+
label2id = {k: v for v, k in enumerate(labels)}
|
11 |
+
|
12 |
+
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
|
13 |
+
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
|
14 |
+
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord")
|
15 |
+
|
16 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
model.to(device)
|
18 |
+
|
19 |
+
|
20 |
+
def prediction(image):
|
21 |
+
boxes, words = OCR(image)
|
22 |
+
encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
|
23 |
+
offset_mapping = encoding.pop('offset_mapping')
|
24 |
+
|
25 |
+
for k, v in encoding.items():
|
26 |
+
encoding[k] = v.to(device)
|
27 |
+
|
28 |
+
outputs = model(**encoding)
|
29 |
+
|
30 |
+
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
31 |
+
token_boxes = encoding.bbox.squeeze().tolist()
|
32 |
+
|
33 |
+
inp_ids = encoding.input_ids.squeeze().tolist()
|
34 |
+
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
35 |
+
|
36 |
+
width, height = image.size
|
37 |
+
is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
|
38 |
+
|
39 |
+
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
40 |
+
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
41 |
+
true_words = []
|
42 |
+
|
43 |
+
for id, i in enumerate(inp_words):
|
44 |
+
if not is_subword[id]:
|
45 |
+
true_words.append(i)
|
46 |
+
else:
|
47 |
+
true_words[-1] = true_words[-1]+i
|
48 |
+
|
49 |
+
true_predictions = true_predictions[1:-1]
|
50 |
+
true_boxes = true_boxes[1:-1]
|
51 |
+
true_words = true_words[1:-1]
|
52 |
+
|
53 |
+
preds = []
|
54 |
+
l_words = []
|
55 |
+
bboxes = []
|
56 |
+
|
57 |
+
for i, j in enumerate(true_predictions):
|
58 |
+
if j != 'others':
|
59 |
+
preds.append(true_predictions[i])
|
60 |
+
l_words.append(true_words[i])
|
61 |
+
bboxes.append(true_boxes[i])
|
62 |
+
|
63 |
+
d = {}
|
64 |
+
for id, i in enumerate(preds):
|
65 |
+
if i not in d.keys():
|
66 |
+
d[i] = l_words[id]
|
67 |
+
else:
|
68 |
+
d[i] = d[i] + ", " + l_words[id]
|
69 |
+
d = {k: v.strip() for (k, v) in d.items()}
|
70 |
+
|
71 |
+
# TODO: process the json
|
72 |
+
|
73 |
+
draw = ImageDraw.Draw(image, "RGBA")
|
74 |
+
font = ImageFont.load_default()
|
75 |
+
|
76 |
+
for prediction, box in zip(preds, bboxes):
|
77 |
+
draw.rectangle(box)
|
78 |
+
draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black")
|
79 |
+
|
80 |
+
return d, image
|
sroie_inference.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
6 |
+
from utils import OCR, unnormalize_box
|
7 |
+
|
8 |
+
|
9 |
+
labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
|
10 |
+
id2label = {v: k for v, k in enumerate(labels)}
|
11 |
+
label2id = {k: v for v, k in enumerate(labels)}
|
12 |
+
|
13 |
+
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
14 |
+
processor = LayoutLMv3Processor.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
15 |
+
model = LayoutLMv3ForTokenClassification.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie")
|
16 |
+
|
17 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
model.to(device)
|
19 |
+
|
20 |
+
|
21 |
+
def blur(image, boxes):
|
22 |
+
image = np.array(image)
|
23 |
+
for box in boxes:
|
24 |
+
|
25 |
+
blur_x = int(box[0])
|
26 |
+
blur_y = int(box[1])
|
27 |
+
blur_width = int(box[2]-box[0])
|
28 |
+
blur_height = int(box[3]-box[1])
|
29 |
+
|
30 |
+
roi = image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
|
31 |
+
blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
|
32 |
+
image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image
|
33 |
+
|
34 |
+
return Image.fromarray(image, 'RGB')
|
35 |
+
|
36 |
+
|
37 |
+
def prediction(image):
|
38 |
+
boxes, words = OCR(image)
|
39 |
+
encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
|
40 |
+
offset_mapping = encoding.pop('offset_mapping')
|
41 |
+
|
42 |
+
for k, v in encoding.items():
|
43 |
+
encoding[k] = v.to(device)
|
44 |
+
|
45 |
+
outputs = model(**encoding)
|
46 |
+
|
47 |
+
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
48 |
+
token_boxes = encoding.bbox.squeeze().tolist()
|
49 |
+
|
50 |
+
inp_ids = encoding.input_ids.squeeze().tolist()
|
51 |
+
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
52 |
+
|
53 |
+
width, height = image.size
|
54 |
+
is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
|
55 |
+
|
56 |
+
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
57 |
+
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
58 |
+
true_words = []
|
59 |
+
|
60 |
+
for id, i in enumerate(inp_words):
|
61 |
+
if not is_subword[id]:
|
62 |
+
true_words.append(i)
|
63 |
+
else:
|
64 |
+
true_words[-1] = true_words[-1]+i
|
65 |
+
|
66 |
+
true_predictions = true_predictions[1:-1]
|
67 |
+
true_boxes = true_boxes[1:-1]
|
68 |
+
true_words = true_words[1:-1]
|
69 |
+
|
70 |
+
preds = []
|
71 |
+
l_words = []
|
72 |
+
bboxes = []
|
73 |
+
|
74 |
+
for i, j in enumerate(true_predictions):
|
75 |
+
if j != 'others':
|
76 |
+
preds.append(true_predictions[i])
|
77 |
+
l_words.append(true_words[i])
|
78 |
+
bboxes.append(true_boxes[i])
|
79 |
+
|
80 |
+
d = {}
|
81 |
+
for id, i in enumerate(preds):
|
82 |
+
if i not in d.keys():
|
83 |
+
d[i] = l_words[id]
|
84 |
+
else:
|
85 |
+
d[i] = d[i] + ", " + l_words[id]
|
86 |
+
|
87 |
+
d = {k: v.strip() for (k, v) in d.items()}
|
88 |
+
|
89 |
+
keys_to_pop = []
|
90 |
+
for k, v in d.items():
|
91 |
+
if k[:2] == "I-":
|
92 |
+
d["B-" + k[2:]] = d["B-" + k[2:]] + ", " + v
|
93 |
+
keys_to_pop.append(k)
|
94 |
+
|
95 |
+
if "O" in d: d.pop("O")
|
96 |
+
if "B-TOTAL" in d: d.pop("B-TOTAL")
|
97 |
+
for k in keys_to_pop: d.pop(k)
|
98 |
+
|
99 |
+
blur_boxes = []
|
100 |
+
for prediction, box in zip(preds, bboxes):
|
101 |
+
if prediction != 'O' and prediction[2:] != 'TOTAL':
|
102 |
+
blur_boxes.append(box)
|
103 |
+
|
104 |
+
image = (blur(image, blur_boxes))
|
105 |
+
|
106 |
+
draw = ImageDraw.Draw(image, "RGBA")
|
107 |
+
font = ImageFont.load_default()
|
108 |
+
|
109 |
+
for prediction, box in zip(preds, bboxes):
|
110 |
+
draw.rectangle(box)
|
111 |
+
draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
|
112 |
+
|
113 |
+
return d, image
|
114 |
+
|
utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from paddleocr import PaddleOCR
|
2 |
+
from PIL import Image
|
3 |
+
from numpy import asarray
|
4 |
+
|
5 |
+
def normalize_bbox(bbox, width, height):
|
6 |
+
|
7 |
+
return [
|
8 |
+
int(1000 * (bbox[0] / width)),
|
9 |
+
int(1000 * (bbox[1] / height)),
|
10 |
+
int(1000 * (bbox[2] / width)),
|
11 |
+
int(1000 * (bbox[3] / height)),
|
12 |
+
]
|
13 |
+
|
14 |
+
def unnormalize_box(bbox, width, height):
|
15 |
+
|
16 |
+
return [
|
17 |
+
width * (bbox[0] / 1000),
|
18 |
+
height * (bbox[1] / 1000),
|
19 |
+
width * (bbox[2] / 1000),
|
20 |
+
height * (bbox[3] / 1000),
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
def OCR(image):
|
25 |
+
ocr = PaddleOCR(use_angle_cls=True)
|
26 |
+
result = ocr.ocr(asarray(image), cls=True)
|
27 |
+
bboxes = []
|
28 |
+
words = []
|
29 |
+
|
30 |
+
for idx in range(len(result)):
|
31 |
+
res = result[idx]
|
32 |
+
|
33 |
+
for line in res:
|
34 |
+
# print(line)
|
35 |
+
# print(line[0][0] + line[0][2])
|
36 |
+
bboxes.append(normalize_bbox(line[0][0]+line[0][2], image.width, image.height))
|
37 |
+
# print(line[1][0])
|
38 |
+
words.append(line[1][0])
|
39 |
+
|
40 |
+
return bboxes, words
|