File size: 8,365 Bytes
1eedb65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
# -*- coding: utf-8 -*-
"""inference with LayoutLMv2ForTokenClassification .ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1nhfx6XRncq2XsOBREZGJI7tRByt_TJIa
## Inference with LayoutLMv2ForTokenClassification + Gradio demo
In this notebook, we are going to perform inference with `LayoutLMv2ForTokenClassification` on new document images, when no label information is accessible. At the end, we will also make a [Gradio](https://gradio.app/) demo that turns our inference code into a web interface.
## Install libraries
Let's first install the required libraries:
* HuggingFace Transformers + Detectron2 (for the model)
* HuggingFace Datasets (for getting the data)
* PyTesseract (for OCR)
"""
# !pip install -q transformers
# !pip install -q gradio
# !pip install 'git+https://github.com/facebookresearch/detectron2.git'
# !pip install -q datasets
# !sudo apt install tesseract-ocr
# !pip install -q pytesseract
# pip install torchvision
# import gradio as gr
import os
import time
import numpy as np
from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
from datasets import load_dataset
import torch
from transformers import LayoutLMv2ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
import json
from GoogleVisionService import GoogleVisionService
from getTextHelper import cord_label_to_color, get_word_boxes_google, get_word_boxes_tesseract, getImg, getImgAndPath, normalize_bbox, unnormalize_box
from datasets import load_dataset
import pytesseract
import cv2
class labelCounter:
    """Shared cursor over the label list used by the viewer code.

    The visible code reads and writes these attributes directly on the
    class, so they behave as module-wide mutable state.
    """

    # Position in the label list of the most recently selected label.
    lbl_i = 0
    # Name of the label currently highlighted; None until one is chosen.
    lbl = None
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# With apply_ocr False the processor is later called with words and boxes
# produced outside of it; with True it is called with the image only.
apply_ocr = False
# apply_ocr=True
processor = LayoutLMv2Processor.from_pretrained(
    "microsoft/layoutlmv2-base-uncased",
    apply_ocr=apply_ocr,
)

# Only referenced from commented-out code in the visible part of the file.
workingPath = "/Users/eliaweiss/Documents/doc2txt/lineCv/"
docNumberList = [6, 7, 10, 21, 25, 29, 48, 50]

# load the fine-tuned model from the hub
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "doc2txt/layoutlmv2-finetuned-cord",
)
model.to(device)
# datasets = load_dataset("MarkusDressel/cord")
# Let's create a list containing all unique labels, as well as dictionaries
# mapping integers to their label names and vice versa. This will be useful
# to convert the model's predictions to actual label names.
# labels = datasets['train'].features['ner_tags'].feature.names
labels = [
    'I-menu.cnt',
    'I-menu.discountprice',
    'I-menu.etc',
    'I-menu.itemsubtotal',
    'I-menu.nm',
    'I-menu.num',
    'I-menu.price',
    'I-menu.sub_cnt',
    'I-menu.sub_etc',
    'I-menu.sub_nm',
    'I-menu.sub_price',
    'I-menu.sub_unitprice',
    'I-menu.unitprice',
    'I-menu.vatyn',
    'I-sub_total.discount_price',
    'I-sub_total.etc',
    'I-sub_total.othersvc_price',
    'I-sub_total.service_price',
    'I-sub_total.subtotal_price',
    'I-sub_total.tax_price',
    'I-total.cashprice',
    'I-total.changeprice',
    'I-total.creditcardprice',
    'I-total.emoneyprice',
    'I-total.menuqty_cnt',
    'I-total.menutype_cnt',
    'I-total.total_etc',
    'I-total.total_price',
    'I-void_menu.nm',
    'I-void_menu.price',
]
# print(labels)
# A label's position in the list is its integer id.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in enumerate(labels)}
"""## Inference
"""
# The TrueType file sits at a machine-specific absolute path. When it cannot
# be read (ImageFont.truetype raises OSError), fall back to Pillow's built-in
# font so that the script still starts on other machines.
try:
    font = ImageFont.truetype(
        "/Users/eliaweiss/work/ocrPlus/ocrPlus/DejaVuSans.ttf", 50)
except OSError:
    font = ImageFont.load_default()
# print(example.keys())
# for nn in docNumberList:
def getNextLabel(labels, true_predictions):
    """Advance the shared label cursor to the next label that was predicted.

    Starting after ``labelCounter.lbl_i``, steps cyclically through
    ``labels`` (at most one full cycle) and stops at the first label that
    occurs in ``true_predictions``. The cursor position and the selected
    name are stored in ``labelCounter.lbl_i`` and ``labelCounter.lbl``.

    Args:
        labels: Sequence of label names to cycle through.
        true_predictions: Iterable of predicted label names.

    Returns:
        The selected label name. If no label occurs in ``true_predictions``
        the cursor ends where it started and that label is returned. If
        ``labels`` is empty, ``None`` is returned.
    """
    labelCounter.lbl = None
    predicted = set(true_predictions)
    total = len(labels)
    # Bounded to one pass so the loop ends even when nothing matches.
    for _ in range(total):
        # Modulo wrap-around visits every index, including the last one
        # (the previous ``>= len(labels) - 1`` test skipped it).
        labelCounter.lbl_i = (labelCounter.lbl_i + 1) % total
        labelCounter.lbl = labels[labelCounter.lbl_i]
        if labelCounter.lbl in predicted:
            break
    return labelCounter.lbl
def iob_to_label(label):
    """Drop the first two characters of ``label``.

    Returns what is left of ``label`` after its first two characters, or
    the string 'other' when nothing is left.
    """
    return label[2:] or 'other'
def drawLabels(image, true_predictions, true_boxes):
    """Return a copy of ``image`` annotated for the currently selected label.

    The selected label name (``labelCounter.lbl``) is written at (10, 10)
    and a 5-pixel-wide rectangle is drawn around every box whose prediction
    contains that name.

    Args:
        image: Image to annotate; it is copied, not drawn on.
        true_predictions: Predicted label names, one per box.
        true_boxes: Boxes in the form accepted by ``ImageDraw.rectangle``,
            paired positionally with ``true_predictions``.

    Returns:
        The annotated copy of ``image``.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    selected = labelCounter.lbl
    draw.text((10, 10), text=selected,
              fill=cord_label_to_color(selected), font=font)
    for prediction, box in zip(true_predictions, true_boxes):
        # Substring match, as before: skip boxes of other labels before
        # doing any per-box work.
        if selected not in prediction:
            continue
        draw.rectangle(box, outline=cord_label_to_color(prediction), width=5)
    return annotated
# folder = "/Users/eliaweiss/.cache/huggingface/datasets/downloads/extracted/87634c2ab68012df3def8353986bcb092170ef7341c69e1a9cd97be52e513079/CORD/test/image/"
# folder = "/Users/eliaweiss/Documents/doc2txt/en_invoice_printed"
folder = "/Users/eliaweiss/ai/ICDAR-2019-SROIE/data/img"


def _show_labels(image, true_predictions, true_boxes):
    """Draw the currently selected label on ``image`` and display the result."""
    annotated = np.array(drawLabels(image, true_predictions, true_boxes))
    # The image is RGB (see convert("RGB") below) while cv2.imshow treats
    # arrays as BGR, so swap the channels to show the intended colours.
    cv2.imshow('Window Name ', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))


# Interactive viewer: for every image in `folder`, run OCR and the model, then
# show the boxes of one label at a time. Tab moves to the next predicted
# label, "~" moves on to the next image.
for img_name in sorted(os.listdir(folder)):
    labelCounter.lbl_i = 0
    start_time = time.time()
    img_path = os.path.join(folder, img_name)

    # Directory listings can contain entries that are not readable images;
    # skip those before calling the OCR service. The context manager closes
    # the file once the RGB copy has been made.
    try:
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
    except OSError as err:
        print("Skipping " + img_path + ": " + str(err))
        continue
    width, height = image.size

    # pathOcr = workingPath + docNumber+".json"
    # with open(pathOcr, encoding="utf-8") as f:
    #     gOcrJson = json.load(f)
    gOcr = GoogleVisionService(img_path)
    gOcrJson = gOcr.googleOcr()

    # We prepare it for the model using `LayoutLMv2Processor`.
    # Extract words and bounding boxes
    words, boxes = get_word_boxes_google(gOcrJson)
    boxes = [normalize_bbox(box, width, height) for box in boxes]

    # Alternative OCR path using pytesseract:
    # cv_image = cv2.imread(img_path)
    # gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
    # data = pytesseract.image_to_data(gray, output_type=pytesseract.Output.DICT)
    # words, boxes = get_word_boxes_tesseract(data)
    # boxes = [normalize_bbox(box, width, height) for box in boxes]

    if not apply_ocr:
        encoding = processor(image, words, boxes=boxes,
                             return_offsets_mapping=True, return_tensors="pt")
    else:
        encoding = processor(
            image, return_offsets_mapping=True, return_tensors="pt")
    # The offsets are only needed for post-processing, not by the model.
    offset_mapping = encoding.pop('offset_mapping')
    print(encoding.keys())

    # Next, let's move everything to the GPU, if it's available.
    for name, tensor in encoding.items():
        encoding[name] = tensor.to(device)

    # forward pass; no gradients are needed for inference
    with torch.no_grad():
        outputs = model(**encoding)
    # print(outputs.logits.shape)
    print("Time: " + str(time.time() - start_time))

    # Let's create the true predictions as well as the true boxes. "True"
    # means only taking into account tokens that are at the start of a given
    # word; a token whose offset does not start at 0 is a subword.
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
    true_predictions = [id2label[pred]
                        for pred, sub in zip(predictions, is_subword)
                        if not sub]
    true_boxes = [unnormalize_box(box, width, height)
                  for box, sub in zip(token_boxes, is_subword)
                  if not sub]
    # print(true_predictions)
    # print(true_boxes)

    # Let's visualize the result!
    labelCounter.lbl = "I-total.total_price"  # getNextLabel(labels, true_predictions)
    _show_labels(image, true_predictions, true_boxes)

    while True:
        key = cv2.waitKey(0) & 0xFF
        if key == 255:
            continue
        if key == 126:  # shift + `
            break
        print("k", key)
        if key == 9:  # tab - show the next label
            labelCounter.lbl = getNextLabel(labels, true_predictions)
            _show_labels(image, true_predictions, true_boxes)

cv2.destroyAllWindows()
|