katanaml committed
Commit
7fd208a
1 Parent(s): e3fb728

CORD inference

Files changed (6)
  1. app.py +81 -0
  2. packages.txt +1 -0
  3. requirements.txt +9 -0
  4. test0.jpeg +0 -0
  5. test1.jpeg +0 -0
  6. test2.jpeg +0 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+ import numpy as np
+ from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
+ from PIL import Image, ImageColor, ImageDraw, ImageFont
+
+ processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
+ model = LayoutLMv2ForTokenClassification.from_pretrained("katanaml/layoutlmv2-finetuned-cord")
+
+ # map label ids to label names
+ id2label = model.config.id2label
+
+ # assign a random named PIL color to each of the 30 CORD classes
+ label_ints = np.random.randint(0, len(ImageColor.colormap.items()), 30)
+ label_color_pil = [k for k, _ in ImageColor.colormap.items()]
+ label_color = [label_color_pil[i] for i in label_ints]
+ label2color = {}
+ for k, v in id2label.items():
+     if v[2:] == '':
+         label2color['o'] = label_color[k]
+     else:
+         # lowercase keys so lookups match iob_to_label(...).lower() below
+         label2color[v[2:].lower()] = label_color[k]
+
+ def unnormalize_box(bbox, width, height):
+     # LayoutLMv2 boxes are normalized to a 0-1000 grid; scale back to pixels
+     return [
+         width * (bbox[0] / 1000),
+         height * (bbox[1] / 1000),
+         width * (bbox[2] / 1000),
+         height * (bbox[3] / 1000),
+     ]
+
+ def iob_to_label(label):
+     # strip the IOB prefix ("B-"/"I-"); an empty remainder means "other"
+     label = label[2:]
+     if not label:
+         return 'o'
+     return label
+
+ def process_image(image):
+     width, height = image.size
+
+     # encode: the processor runs tesseract OCR and tokenizes the words
+     encoding = processor(image, return_offsets_mapping=True, return_tensors="pt")
+     offset_mapping = encoding.pop('offset_mapping')
+
+     # forward pass
+     outputs = model(**encoding)
+
+     # get predictions
+     predictions = outputs.logits.argmax(-1).squeeze().tolist()
+     token_boxes = encoding.bbox.squeeze().tolist()
+
+     # only keep non-subword predictions (a non-zero start offset marks a subword)
+     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
+     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
+     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
+
+     # draw predictions over the image
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default()
+     for prediction, box in zip(true_predictions, true_boxes):
+         predicted_label = iob_to_label(prediction).lower()
+         draw.rectangle(box, outline=label2color[predicted_label])
+         draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
+
+     return image
+
+ title = "Interactive demo: LayoutLMv2 with CORD receipts dataset"
+ description = "Demo for Microsoft's LayoutLMv2, a Transformer for state-of-the-art document image understanding tasks. This particular model is fine-tuned on CORD, a dataset of manually annotated receipts. It annotates the words appearing in the image with up to 30 classes. To use it, simply upload an image or use the example images below. Results will show up in a few seconds. If you want to make the output bigger, right-click on it and select ‘Open image in new tab’."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.14740'>LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding</a> | <a href='https://github.com/microsoft/unilm'>Github Repo</a> | <a href='https://katanaml.io' target='_blank'>Katana ML</a> | <a href='https://github.com/katanaml/sparrow'>Sparrow Github Repo</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=abaranovskij_cord' alt='visitor badge'></center>"
+ # one input component, so each example is its own single-item list
+ examples = [['test0.jpeg'], ['test1.jpeg'], ['test2.jpeg']]
+
+ css = ".output-image, .input-image {height: 40rem !important; width: 100% !important;}"
+
+ iface = gr.Interface(fn=process_image,
+                      inputs=gr.inputs.Image(type="pil"),
+                      outputs=gr.outputs.Image(type="pil", label="annotated image"),
+                      title=title,
+                      description=description,
+                      article=article,
+                      examples=examples,
+                      css=css,
+                      enable_queue=True)
+ iface.launch(debug=True)
packages.txt ADDED
@@ -0,0 +1 @@
+ tesseract-ocr
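packages.txt installs the system-level tesseract binary: LayoutLMv2Processor applies OCR by default through pytesseract, which shells out to tesseract. A quick sanity check (a sketch, assuming one of the bundled test receipts is available):

    import pytesseract
    from PIL import Image

    # Raises if the tesseract binary from packages.txt is not on PATH.
    print(pytesseract.get_tesseract_version())
    # Raw OCR text, roughly what the processor feeds into LayoutLMv2.
    print(pytesseract.image_to_string(Image.open("test0.jpeg")))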
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ -f https://download.pytorch.org/whl/torch_stable.html
+ -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html
+ gradio
+ transformers
+ pyyaml==5.1
+ torch==1.10.0+cu111
+ torchvision==0.11.0+cu111
+ detectron2
+ pytesseract
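The two -f lines point pip at extra wheel indexes so the CUDA 11.1 builds of torch/torchvision and the matching detectron2 wheel can resolve; detectron2 supplies the visual backbone that LayoutLMv2 requires. A sketch to confirm the pinned builds installed as expected:

    import torch
    import torchvision
    import detectron2

    print(torch.__version__)          # expected: 1.10.0+cu111
    print(torchvision.__version__)    # expected: 0.11.0+cu111
    print(detectron2.__version__)
    print(torch.cuda.is_available())  # may be False on CPU-only hardware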
test0.jpeg ADDED
test1.jpeg ADDED
test2.jpeg ADDED