ankur-bohra committed
Commit
4affd67
1 Parent(s): 8a18a68

Update app.py

Files changed (1)
  1. app.py +129 -1
app.py CHANGED
@@ -1,3 +1,131 @@
+ # https://huggingface.co/spaces/Theivaprakasham/wildreceipt/raw/main/app.py
+
+ import os
+ # install CPU-only PyTorch wheels at startup
+ os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
+
  import gradio as gr
+ import numpy as np
+ from transformers import AutoProcessor, AutoModelForTokenClassification
+ from datasets import load_dataset
+ from PIL import Image, ImageDraw, ImageFont
+
+ # apply_ocr=True: the processor runs OCR on the input image itself
+ processor = AutoProcessor.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-wildreceipt", apply_ocr=True)
+ model = AutoModelForTokenClassification.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-wildreceipt")
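+ # With apply_ocr=True the encoding already holds input_ids, attention_mask,
+ # bbox and pixel_values; the Tesseract binary must be present on the host
+ # (on Spaces this is usually declared in packages.txt).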
+
+ # save example images from the test split for the Gradio examples widget
+ dataset = load_dataset("Theivaprakasham/wildreceipt", split="test")
+ Image.open(dataset[20]["image_path"]).convert("RGB").save("example1.png")
+ Image.open(dataset[13]["image_path"]).convert("RGB").save("example2.png")
+ Image.open(dataset[15]["image_path"]).convert("RGB").save("example3.png")
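+ # (The dataset stores image file paths rather than decoded images, hence the
+ # explicit Image.open/convert above.)
+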
+ # map label ids to names, and give each label a box colour
+ # (keys are drawn in red, extracted values in green, the grand total in blue)
+ labels = dataset.features['ner_tags'].feature.names
+ id2label = {idx: label for idx, label in enumerate(labels)}
+ label2color = {
+     "Date_key": 'red',
+     "Date_value": 'green',
+     "Ignore": 'orange',
+     "Others": 'orange',
+     "Prod_item_key": 'red',
+     "Prod_item_value": 'green',
+     "Prod_price_key": 'red',
+     "Prod_price_value": 'green',
+     "Prod_quantity_key": 'red',
+     "Prod_quantity_value": 'green',
+     "Store_addr_key": 'red',
+     "Store_addr_value": 'green',
+     "Store_name_key": 'red',
+     "Store_name_value": 'green',
+     "Subtotal_key": 'red',
+     "Subtotal_value": 'green',
+     "Tax_key": 'red',
+     "Tax_value": 'green',
+     "Tel_key": 'red',
+     "Tel_value": 'green',
+     "Time_key": 'red',
+     "Time_value": 'green',
+     "Tips_key": 'red',
+     "Tips_value": 'green',
+     "Total_key": 'red',
+     "Total_value": 'blue'
+ }
+
+
+ def unnormalize_box(bbox, width, height):
+     # model boxes are normalised to a 0-1000 grid; scale back to pixel coordinates
+     return [
+         width * (bbox[0] / 1000),
+         height * (bbox[1] / 1000),
+         width * (bbox[2] / 1000),
+         height * (bbox[3] / 1000),
+     ]
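+ # e.g. unnormalize_box([250, 100, 500, 200], width=800, height=1000)
+ # returns [200.0, 100.0, 400.0, 200.0]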
+
+
+ def iob_to_label(label):
+     # WildReceipt tags carry no IOB prefix, so this is the identity mapping
+     return label
+
+
+ def process_image(image):
+     width, height = image.size
+
+     # encode: OCR the image and tokenise the recognised words with their boxes
+     encoding = processor(image, truncation=True, return_offsets_mapping=True, return_tensors="pt")
+     offset_mapping = encoding.pop('offset_mapping')
+
+     # forward pass
+     outputs = model(**encoding)
+
+     # get predictions
+     predictions = outputs.logits.argmax(-1).squeeze().tolist()
+     token_boxes = encoding.bbox.squeeze().tolist()
+
+     # keep only the first sub-token of each word (offset 0 marks a word start)
+     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
+     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
+     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
+
+     # draw each predicted label and its box over the image
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default()
+     for prediction, box in zip(true_predictions, true_boxes):
+         predicted_label = iob_to_label(prediction)
+         draw.rectangle(box, outline=label2color[predicted_label])
+         draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
+
+     return image
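+
+ # Hypothetical local smoke test (not part of the Space's runtime path):
+ # if __name__ == "__main__":
+ #     process_image(Image.open("example1.png").convert("RGB")).save("annotated_example1.png")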
+
+
+ title = "Restaurant/Hotel Bill Information Extraction using LayoutLMv3"
+ description = "We use Microsoft's LayoutLMv3, fine-tuned on the WildReceipt dataset, to label receipt fields: Store_name_value, Store_name_key, Store_addr_value, Store_addr_key, Tel_value, Tel_key, Date_value, Date_key, Time_value, Time_key, Prod_item_value, Prod_item_key, Prod_quantity_value, Prod_quantity_key, Prod_price_value, Prod_price_key, Subtotal_value, Subtotal_key, Tax_value, Tax_key, Tips_value, Tips_key, Total_value, Total_key. Upload an image or pick one of the examples below; results appear within a few seconds."
+
+ article = "<b>References</b><br>[1] Y. Xu et al., “LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking.” 2022. <a href='https://arxiv.org/abs/2204.08387'>Paper Link</a><br>[2] <a href='https://github.com/NielsRogge/Transformers-Tutorials/tree/master/LayoutLMv3'>LayoutLMv3 training and inference</a><br>[3] H. Sun, Z. Kuang, X. Yue, C. Lin, and W. Zhang. “Spatial Dual-Modality Graph Reasoning for Key Information Extraction.” 2021. <a href='https://doi.org/10.48550/ARXIV.2103.14470'>Paper Link</a>"
+
+ examples = [['example1.png'], ['example2.png'], ['example3.png']]
+
+ css = """.output_image, .input_image {height: 600px !important}"""
+
+ iface = gr.Interface(fn=process_image,
+                      inputs=gr.inputs.Image(type="pil"),
+                      outputs=gr.outputs.Image(type="pil", label="annotated image"),
+                      title=title,
+                      description=description,
+                      article=article,
+                      examples=examples,
+                      css=css,
+                      analytics_enabled=True,
+                      enable_queue=True)

- gr.Interface.load("models/AliShaker/layoutlmv3-finetuned-wildreceipt").launch()
+ iface.launch(inline=False, share=False, debug=False)
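+
+ # Note: gr.inputs / gr.outputs and enable_queue date from the Gradio 2.x/3.x API;
+ # on current Gradio the rough equivalent (untested sketch) is gr.Image(type="pil")
+ # on both sides, with queuing enabled via iface.queue().launch().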