Theivaprakasham committed
Commit c009a30
1 Parent(s): f555c3d

added app.py

Files changed (3)
  1. app.py +129 -0
  2. packages.txt +6 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,129 @@
+ import os
+ os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
+
+ import gradio as gr
+ import numpy as np
+ from transformers import AutoModelForTokenClassification
+ from datasets.features import ClassLabel
+ from transformers import AutoProcessor
+ from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D
+ import torch
+ from datasets import load_metric
+ from transformers import LayoutLMv3ForTokenClassification
+ from transformers.data.data_collator import default_data_collator
+
+
+ from transformers import AutoModelForTokenClassification
+ from datasets import load_dataset
+ from PIL import Image, ImageDraw, ImageFont
+
+
+ processor = AutoProcessor.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-wildreceipt", apply_ocr=True)
+ model = AutoModelForTokenClassification.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-wildreceipt")
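+ # Note: with apply_ocr=True the processor runs Tesseract OCR (via pytesseract) on the input
+ # image to extract the words and bounding boxes that LayoutLMv3 expects, which is why
+ # packages.txt installs tesseract-ocr and requirements.txt pins pytesseract.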
+
+
+
+ # load three example receipt images from the WildReceipt test split
+ dataset = load_dataset("Theivaprakasham/wildreceipt", split="test")
+ Image.open(dataset[20]["image_path"]).convert("RGB").save("example1.png")
+ Image.open(dataset[13]["image_path"]).convert("RGB").save("example2.png")
+ Image.open(dataset[15]["image_path"]).convert("RGB").save("example3.png")
+
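+ # the three exported PNGs are reused below as the example gallery of the Gradio interface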
+ # define id2label and the colour used to draw each label
+ labels = dataset.features['ner_tags'].feature.names
+ id2label = {idx: label for idx, label in enumerate(labels)}
+ label2color = {
+     "Date_key": 'red',
+     "Date_value": 'green',
+     "Ignore": 'orange',
+     "Others": 'orange',
+     "Prod_item_key": 'red',
+     "Prod_item_value": 'green',
+     "Prod_price_key": 'red',
+     "Prod_price_value": 'green',
+     "Prod_quantity_key": 'red',
+     "Prod_quantity_value": 'green',
+     "Store_addr_key": 'red',
+     "Store_addr_value": 'green',
+     "Store_name_key": 'red',
+     "Store_name_value": 'green',
+     "Subtotal_key": 'red',
+     "Subtotal_value": 'green',
+     "Tax_key": 'red',
+     "Tax_value": 'green',
+     "Tel_key": 'red',
+     "Tel_value": 'green',
+     "Time_key": 'red',
+     "Time_value": 'green',
+     "Tips_key": 'red',
+     "Tips_value": 'green',
+     "Total_key": 'red',
+     "Total_value": 'blue'
+ }
+
+ def unnormalize_box(bbox, width, height):
+     return [
+         width * (bbox[0] / 1000),
+         height * (bbox[1] / 1000),
+         width * (bbox[2] / 1000),
+         height * (bbox[3] / 1000),
+     ]
+
+
+ def iob_to_label(label):
+     return label
+
+
+
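+ # Note on the helpers above: the processor returns bounding boxes normalized to a 0-1000 scale,
+ # so unnormalize_box rescales them to pixel coordinates for drawing; iob_to_label is an identity
+ # mapping because the WildReceipt tag names carry no B-/I- prefixes.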
+ def process_image(image):
+
+     print(type(image))
+     width, height = image.size
+
+     # encode: the processor OCRs the image and tokenizes the extracted words
+     encoding = processor(image, truncation=True, return_offsets_mapping=True, return_tensors="pt")
+     offset_mapping = encoding.pop('offset_mapping')
+
+     # forward pass
+     outputs = model(**encoding)
+
+     # get predictions
+     predictions = outputs.logits.argmax(-1).squeeze().tolist()
+     token_boxes = encoding.bbox.squeeze().tolist()
+
+     # only keep predictions for the first subword token of each word
+     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
+     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
+     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
+
+     # draw predictions over the image
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default()
+     for prediction, box in zip(true_predictions, true_boxes):
+         predicted_label = iob_to_label(prediction)
+         draw.rectangle(box, outline=label2color[predicted_label])
+         draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
+
+     return image
+
+
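+ # Example usage outside the Gradio UI (a minimal sketch; assumes example1.png was written by
+ # the dataset export above, and the output filename is illustrative):
+ #   annotated = process_image(Image.open("example1.png").convert("RGB"))
+ #   annotated.save("annotated_example1.png")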
+ title = "Restaurant/Hotel Bill Information Extraction using the LayoutLMv3 model"
+ description = "Restaurant/Hotel bill information extraction - we use Microsoft's LayoutLMv3 fine-tuned on the WildReceipt dataset to predict the Store_name_value, Store_name_key, Store_addr_value, Store_addr_key, Tel_value, Tel_key, Date_value, Date_key, Time_value, Time_key, Prod_item_value, Prod_item_key, Prod_quantity_value, Prod_quantity_key, Prod_price_value, Prod_price_key, Subtotal_value, Subtotal_key, Tax_value, Tax_key, Tips_value, Tips_key, Total_value and Total_key entities. To use it, simply upload an image or pick one of the example images below. Results will show up in a few seconds."
+
+ article = "<b>References</b><br>[1] Y. Xu et al., “LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking.” 2022. <a href='https://arxiv.org/abs/2204.08387'>Paper Link</a><br>[2] <a href='https://github.com/NielsRogge/Transformers-Tutorials/tree/master/LayoutLMv3'>LayoutLMv3 training and inference</a><br>[3] Hongbin Sun, Zhanghui Kuang, Xiaoyu Yue, Chenhao Lin, and Wayne Zhang. 2021. Spatial Dual-Modality Graph Reasoning for Key Information Extraction. arXiv. DOI: https://doi.org/10.48550/ARXIV.2103.14470 <a href='https://doi.org/10.48550/ARXIV.2103.14470'>Paper Link</a>"
+
+ examples = [['example1.png'], ['example2.png'], ['example3.png']]
+
+ css = """.output_image, .input_image {height: 600px !important}"""
+
+ iface = gr.Interface(fn=process_image,
+                      inputs=gr.inputs.Image(type="pil"),
+                      outputs=gr.outputs.Image(type="pil", label="annotated image"),
+                      title=title,
+                      description=description,
+                      article=article,
+                      examples=examples,
+                      css=css,
+                      analytics_enabled=True, enable_queue=True)
+
+ iface.launch(inline=False, share=False, debug=False)
packages.txt ADDED
@@ -0,0 +1,6 @@
+ ffmpeg
+ libsm6
+ libxext6
+ libgl1
+ libgl1-mesa-glx
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ git+https://github.com/huggingface/transformers.git
+ PyYAML==6.0
+ pytesseract==0.3.9
+ datasets==2.2.2
+ seqeval==1.2.2