# layoutlmv2_cord / LayoutLMv2Main_cord2_gOcr_folder.py
# Uploaded by doc2txt (commit 1eedb65).
# -*- coding: utf-8 -*-
"""inference with LayoutLMv2ForTokenClassification .ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1nhfx6XRncq2XsOBREZGJI7tRByt_TJIa
## Inference with LayoutLMv2ForTokenClassification + Gradio demo
In this notebook, we perform inference with `LayoutLMv2ForTokenClassification` on new document images for which no label information is available. At the end, we also build a [Gradio](https://gradio.app/) demo that turns the inference code into a web interface. (This script version instead runs over a folder of images using Google Vision OCR and shows the results in an OpenCV window; the Gradio import is commented out.)
## Install libraries
Let's first install the required libraries:
* HuggingFace Transformers + Detectron2 (for the model)
* HuggingFace Datasets (for getting the data)
* PyTesseract (for OCR; this script uses Google Vision OCR by default and keeps Tesseract as a commented-out alternative)
"""
# !pip install -q transformers
# !pip install -q gradio
# !pip install 'git+https://github.com/facebookresearch/detectron2.git'
# !pip install -q datasets
# !sudo apt install tesseract-ocr
# !pip install -q pytesseract
# pip install torchvision
# import gradio as gr
import os
import time
import numpy as np
from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
from datasets import load_dataset
import torch
from transformers import LayoutLMv2ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
import json
from GoogleVisionService import GoogleVisionService
from getTextHelper import cord_label_to_color, get_word_boxes_google, get_word_boxes_tesseract, getImg, getImgAndPath, normalize_bbox, unnormalize_box
from datasets import load_dataset
import pytesseract
import cv2
class labelCounter:
    """Shared state for the label currently highlighted in the viewer.

    The class is used as a plain namespace: in this file its attributes are
    only ever read and written directly on the class, never on an instance.
    """

    # Name of the label currently highlighted; None until one is chosen.
    lbl: "str | None" = None
    # Position in the label list of the current label.
    lbl_i: int = 0
# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# apply_ocr=False: words and boxes are supplied by the external OCR in the main
# loop; set it to True to let the processor run OCR itself.
apply_ocr = False
processor = LayoutLMv2Processor.from_pretrained(
    "microsoft/layoutlmv2-base-uncased", apply_ocr=apply_ocr)

# Only referenced by commented-out code in this file (pre-computed OCR JSON
# files per document number).
workingPath = "/Users/eliaweiss/Documents/doc2txt/lineCv/"
docNumberList = [6, 7, 10, 21, 25, 29, 48, 50]

# Load the fine-tuned model from the hub.
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "doc2txt/layoutlmv2-finetuned-cord")
model.to(device)

# All unique labels, plus dictionaries mapping class indices to label names and
# back; used to turn the model's argmax indices into label names.
# NOTE(review): the list order is assumed to match the model's class-index
# order -- TODO confirm against model.config.id2label. It was originally read
# from the dataset:
#   datasets = load_dataset("MarkusDressel/cord")
#   labels = datasets['train'].features['ner_tags'].feature.names
labels = ['I-menu.cnt', 'I-menu.discountprice', 'I-menu.etc', 'I-menu.itemsubtotal', 'I-menu.nm', 'I-menu.num', 'I-menu.price', 'I-menu.sub_cnt', 'I-menu.sub_etc', 'I-menu.sub_nm', 'I-menu.sub_price', 'I-menu.sub_unitprice', 'I-menu.unitprice', 'I-menu.vatyn', 'I-sub_total.discount_price', 'I-sub_total.etc',
          'I-sub_total.othersvc_price', 'I-sub_total.service_price', 'I-sub_total.subtotal_price', 'I-sub_total.tax_price', 'I-total.cashprice', 'I-total.changeprice', 'I-total.creditcardprice', 'I-total.emoneyprice', 'I-total.menuqty_cnt', 'I-total.menutype_cnt', 'I-total.total_etc', 'I-total.total_price', 'I-void_menu.nm', 'I-void_menu.price']
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in enumerate(labels)}

# Inference
# Font for the caption drawn on result images. The TTF path is specific to one
# machine; fall back to Pillow's built-in font elsewhere instead of crashing.
try:
    font = ImageFont.truetype(
        "/Users/eliaweiss/work/ocrPlus/ocrPlus/DejaVuSans.ttf", 50)
except OSError:
    font = ImageFont.load_default()
def getNextLabel(labels, true_predictions):
    """Advance ``labelCounter`` to the next label that occurs in the predictions.

    Starting just after ``labelCounter.lbl_i``, walk ``labels`` cyclically
    (wrapping around to index 0) and stop at the first label present in
    ``true_predictions``. At most ``len(labels)`` candidates are tried. If
    none matches, ``lbl_i`` ends up back at its starting position and that
    label is kept.

    Side effects: updates ``labelCounter.lbl_i`` and ``labelCounter.lbl``.

    Args:
        labels: ordered list of label names to cycle through.
        true_predictions: word-level predicted label names for the current image.

    Returns:
        The selected label (also stored in ``labelCounter.lbl``), or ``None``
        when ``labels`` is empty.
    """
    labelCounter.lbl = None
    if not labels:
        return None
    present = set(true_predictions)
    for _ in range(len(labels)):
        # Wrap at len(labels) (not len(labels) - 1) so the last label is reachable.
        labelCounter.lbl_i = (labelCounter.lbl_i + 1) % len(labels)
        labelCounter.lbl = labels[labelCounter.lbl_i]
        if labelCounter.lbl in present:
            break
    return labelCounter.lbl
def iob_to_label(label):
    """Drop the two-character IOB prefix (e.g. ``"I-"``) from ``label``.

    When nothing remains after the prefix (as for ``"O"``), ``'other'`` is
    returned instead.
    """
    return label[2:] or 'other'
def drawLabels(image, true_predictions, true_boxes):
    """Return a copy of ``image`` highlighting the words of the selected label.

    The currently selected label (``labelCounter.lbl``) is written near the
    top-left corner. Every word whose predicted label contains it is outlined
    in the colour ``cord_label_to_color`` gives for that prediction. The input
    image is not modified.

    Args:
        image: PIL image to annotate.
        true_predictions: word-level predicted label names.
        true_boxes: word-level boxes parallel to ``true_predictions``, in
            pixel coordinates (as produced by ``unnormalize_box`` -- presumably
            [x0, y0, x1, y1]; confirm in getTextHelper).

    Returns:
        The annotated copy, or an unannotated copy when no label is selected.
    """
    image_tmp = image.copy()
    selected = labelCounter.lbl
    if selected is None:
        # Nothing selected: `None in prediction` and text=None would raise.
        return image_tmp
    draw = ImageDraw.Draw(image_tmp)
    draw.text((10, 10), text=selected, fill=cord_label_to_color(selected), font=font)
    for prediction, box in zip(true_predictions, true_boxes):
        # Substring match, as in the original implementation.
        if selected not in prediction:
            continue
        draw.rectangle(box, outline=cord_label_to_color(prediction), width=5)
    return image_tmp
# Folder of document images to run inference on. Earlier candidates:
# folder = "/Users/eliaweiss/.cache/huggingface/datasets/downloads/extracted/87634c2ab68012df3def8353986bcb092170ef7341c69e1a9cd97be52e513079/CORD/test/image/"
# folder = "/Users/eliaweiss/Documents/doc2txt/en_invoice_printed"
folder = "/Users/eliaweiss/ai/ICDAR-2019-SROIE/data/img"

# Only files with these extensions are processed; anything else in the folder
# (e.g. macOS ".DS_Store") would make Image.open raise.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")

# Name of the OpenCV window used to display results.
WINDOW_NAME = 'Window Name '

# Key codes as returned by `cv2.waitKey(0) & 0xFF`.
KEY_IGNORED = 255     # -1 & 0xFF: no key reported; ignored
KEY_NEXT_IMAGE = 126  # '~' (shift + `): continue with the next image
KEY_NEXT_LABEL = 9    # tab: highlight the next label found in the predictions


def _show_rgb(image_rgb):
    """Show a PIL RGB image in the OpenCV window.

    cv2.imshow interprets arrays as BGR, so the channels are swapped first;
    passing the RGB array directly renders red and blue swapped.
    """
    cv2.imshow(WINDOW_NAME, cv2.cvtColor(np.array(image_rgb), cv2.COLOR_RGB2BGR))


for img_name in sorted(os.listdir(folder)):
    if not img_name.lower().endswith(IMAGE_EXTENSIONS):
        continue
    labelCounter.lbl_i = 0
    start_time = time.time()
    img_path = os.path.join(folder, img_name)
    # Decode inside `with` so the file handle is closed promptly.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    width, height = image.size

    # Prepare the page for the model using LayoutLMv2Processor.
    if not apply_ocr:
        # Words and boxes come from Google Vision OCR; boxes are normalized with
        # the project helper before being handed to the processor.
        gOcrJson = GoogleVisionService(img_path).googleOcr()
        words, boxes = get_word_boxes_google(gOcrJson)
        boxes = [normalize_bbox(box, width, height) for box in boxes]
        # Alternative: Tesseract OCR instead of Google Vision.
        # cv_image = cv2.imread(img_path)
        # gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        # data = pytesseract.image_to_data(gray, output_type=pytesseract.Output.DICT)
        # words, boxes = get_word_boxes_tesseract(data)
        # boxes = [normalize_bbox(box, width, height) for box in boxes]
        encoding = processor(image, words, boxes=boxes,
                             return_offsets_mapping=True, return_tensors="pt")
    else:
        # The processor performs OCR itself, so no external OCR call is needed.
        encoding = processor(
            image, return_offsets_mapping=True, return_tensors="pt")
    # offset_mapping is not a model input; keep it to detect sub-word tokens.
    offset_mapping = encoding.pop('offset_mapping')
    print(encoding.keys())

    # Move everything to the GPU, if available.
    for k, v in encoding.items():
        encoding[k] = v.to(device)

    # Forward pass; no_grad avoids building an autograd graph during inference.
    with torch.no_grad():
        outputs = model(**encoding)
    print("Time: " + str(time.time() - start_time))

    # Keep one prediction/box per word: only tokens at the start of a word are
    # used; a token whose offset does not start at 0 is a sub-word piece.
    # The batch holds one image, so index [0] rather than squeeze() (which
    # would also collapse a length-1 sequence dimension).
    predictions = outputs.logits.argmax(-1)[0].tolist()
    token_boxes = encoding["bbox"][0].tolist()
    is_subword = (offset_mapping[0, :, 0] != 0).tolist()
    true_predictions = [id2label[pred]
                        for pred, sub in zip(predictions, is_subword) if not sub]
    true_boxes = [unnormalize_box(box, width, height)
                  for box, sub in zip(token_boxes, is_subword) if not sub]

    # Visualize: start with the total price; TAB cycles labels, '~' moves on.
    labelCounter.lbl = "I-total.total_price"  # getNextLabel(labels, true_predictions)
    _show_rgb(drawLabels(image, true_predictions, true_boxes))
    while True:
        k = cv2.waitKey(0) & 0xFF
        if k == KEY_IGNORED:
            continue
        if k == KEY_NEXT_IMAGE:
            break
        print("k", k)
        if k == KEY_NEXT_LABEL:
            labelCounter.lbl = getNextLabel(labels, true_predictions)
            _show_rgb(drawLabels(image, true_predictions, true_boxes))

cv2.destroyAllWindows()