# Author: gaunernst
# Last commit: 90a9cee "Update app.py" (verified)
import cv2
import gradio as gr
import numpy as np
import torch
from paddleocr import PaddleOCR
from PIL import Image
from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
from transformers.pipelines.document_question_answering import apply_tesseract
# Choices offered in the OCR-engine selector; predict() compares against these exact strings.
PADDLE_OCR_LABEL = "PaddleOCR (en)"
TESSERACT_LABEL = "Tesseract (HF default)"

# LayoutLM question-answering checkpoint and its matching tokenizer.
model_tag = "impira/layoutlm-document-qa"
TOKENIZER = AutoTokenizer.from_pretrained(model_tag)
MODEL = LayoutLMForQuestionAnswering.from_pretrained(model_tag)
MODEL.eval()  # inference only: disables dropout

# English PaddleOCR pipeline. The very large det_limit_side_len presumably keeps
# high-resolution scans from being downscaled before text detection — confirm
# against the PaddleOCR parameter docs.
OCR = PaddleOCR(lang="en", det_limit_side_len=10_000, det_db_score_mode="slow")
def _paddle_ocr(image_np, width, height):
    """Run PaddleOCR on ``image_np`` and draw every detected text quad onto it.

    Args:
        image_np: RGB image array; modified in place (quads are drawn on it).
        width: Image width in pixels, used to normalise box coordinates.
        height: Image height in pixels, used to normalise box coordinates.

    Returns:
        ``(words, boxes)``: the recognised strings, and an ``(n_boxes, 4)`` int
        array of axis-aligned ``x1, y1, x2, y2`` boxes on a 0-1000 grid.
    """
    ocr_result = OCR.ocr(image_np, cls=False)[0]
    # A page with no detected text may come back as None or an empty list.
    if not ocr_result:
        return [], np.zeros((0, 4), dtype=int)

    words = [line[1][0] for line in ocr_result]
    quads = np.asarray([line[0] for line in ocr_result], dtype=float)  # (n_boxes, 4, 2)
    for quad in quads:
        # OpenCV drawing functions expect int32 point arrays.
        cv2.polylines(image_np, [quad.reshape(-1, 1, 2).astype(np.int32)], True, (0, 255, 255), 3)

    # Axis-aligned bounding box of each quad, scaled to LayoutLM's 0-1000 page grid.
    x1 = quads[:, :, 0].min(1) * 1000 / width
    y1 = quads[:, :, 1].min(1) * 1000 / height
    x2 = quads[:, :, 0].max(1) * 1000 / width
    y2 = quads[:, :, 1].max(1) * 1000 / height
    boxes = np.stack([x1, y1, x2, y2], axis=1)  # (n_boxes, 4) in xyxy format
    # Detected quads can extend slightly past the image edge; negative values are
    # not valid bbox-embedding indices, so clamp onto the 0-1000 grid.
    return words, np.clip(boxes, 0, 1000).astype(int)


def _tesseract_ocr(image, image_np):
    """Run Tesseract via the transformers pipeline helper and draw its boxes.

    Args:
        image: The PIL image to OCR.
        image_np: Array copy of ``image``; word rectangles are drawn on it in place.

    Returns:
        ``(words, boxes)`` exactly as returned by ``apply_tesseract``. The boxes
        are already on the 0-1000 grid, so they are rescaled to pixels only for
        drawing.
    """
    words, boxes = apply_tesseract(image, None, "")
    for x1, y1, x2, y2 in boxes:
        top_left = (int(x1 * image.width / 1000), int(y1 * image.height / 1000))
        bottom_right = (int(x2 * image.width / 1000), int(y2 * image.height / 1000))
        cv2.rectangle(image_np, top_left, bottom_right, (0, 255, 255), 3)
    return words, boxes


def _answer_question(question, words, boxes):
    """Select the most probable answer span among the OCR word tokens.

    The model input is laid out as:
    1. the tokenised question, including the tokenizer's special tokens;
    2. one extra separator;
    3. the sub-word tokens of every OCR word;
    4. a closing separator.

    Args:
        question: The question text.
        words: OCR words, in reading order.
        boxes: One 0-1000 ``x1, y1, x2, y2`` box per word.

    Returns:
        ``(answer, score)``: the decoded span, and its probability as a float
        (start probability times end probability). Returns ``("", 0.0)`` when
        there is no document text to answer from.
    """
    token_ids = TOKENIZER(question)["input_ids"]
    # Question tokens get an all-zero box. The question's final special token and
    # the separators get the full-page box.
    token_boxes = [[0, 0, 0, 0] for _ in range(len(token_ids) - 1)] + [[1000] * 4]
    token_ids.append(TOKENIZER.sep_token_id)
    token_boxes.append([1000] * 4)
    context_start = len(token_ids)

    context_ids = []
    context_boxes = []
    for word, box in zip(words, boxes):
        word_ids = TOKENIZER(word, add_special_tokens=False)["input_ids"]
        context_ids.extend(word_ids)
        # Every sub-word token inherits the box of the word it came from.
        context_boxes.extend([[int(v) for v in box]] * len(word_ids))

    # The model has a fixed number of position embeddings. Trailing OCR tokens
    # that would overflow it are dropped instead of crashing on long documents.
    # One slot is reserved for the closing separator.
    max_len = min(TOKENIZER.model_max_length, MODEL.config.max_position_embeddings)
    budget = max(0, max_len - context_start - 1)
    token_ids.extend(context_ids[:budget])
    token_boxes.extend(context_boxes[:budget])
    context_end = len(token_ids)
    token_ids.append(TOKENIZER.sep_token_id)
    token_boxes.append([1000] * 4)

    if context_end == context_start:
        return "", 0.0

    with torch.inference_mode():
        outputs = MODEL(
            input_ids=torch.tensor(token_ids).unsqueeze(0),
            bbox=torch.tensor(token_boxes).unsqueeze(0),
        )

    # The softmax runs over the whole sequence. The scores are then restricted to
    # document tokens, so question and separator tokens can never be the answer.
    start_scores = outputs.start_logits.squeeze(0).softmax(-1)[context_start:context_end]
    end_scores = outputs.end_logits.squeeze(0).softmax(-1)[context_start:context_end]
    span_scores = start_scores.view(-1, 1) * end_scores.view(1, -1)
    span_scores = torch.triu(span_scores)  # zero out spans whose end precedes their start
    score, flat_index = span_scores.flatten().max(-1)
    start_offset, end_offset = divmod(flat_index.item(), span_scores.shape[1])
    answer = TOKENIZER.decode(
        token_ids[context_start + start_offset : context_start + end_offset + 1]
    )
    return answer, score.item()


def predict(image: Image.Image, question: str, ocr_engine: str):
    """Answer ``question`` about a document image using the selected OCR engine.

    Args:
        image: The document image (may be None if the user submitted without one).
        question: Natural-language question about the document.
        ocr_engine: ``PADDLE_OCR_LABEL`` or ``TESSERACT_LABEL``.

    Returns:
        ``(answer, score, annotated)``:
        - ``answer``: the decoded answer span.
        - ``score``: its probability, as a float.
        - ``annotated``: an RGB array copy of the image with the OCR boxes drawn on it.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If ``ocr_engine`` is not a supported label.
    """
    if image is None:
        raise gr.Error("Please provide a document image.")
    # Normalise RGBA/grayscale/palette inputs to 3-channel RGB for OCR and drawing.
    image = image.convert("RGB")
    image_np = np.array(image)

    if ocr_engine == PADDLE_OCR_LABEL:
        words, boxes = _paddle_ocr(image_np, image.width, image.height)
    elif ocr_engine == TESSERACT_LABEL:
        words, boxes = _tesseract_ocr(image, image_np)
    else:
        raise ValueError(f"Unsupported ocr_engine={ocr_engine}")

    answer, score = _answer_question(question, words, boxes)
    return answer, score, image_np
# Web UI: a document image, a question and an OCR-engine choice go in; the answer,
# its probability and the image with OCR boxes drawn on it come out.
# Bound to a module-level name so tools that look for a `demo` object can find it.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Document"),
        gr.Textbox(label="Question"),
        # Default selection: predict() raises ValueError when no engine is chosen.
        gr.Radio(
            [PADDLE_OCR_LABEL, TESSERACT_LABEL],
            value=PADDLE_OCR_LABEL,
            label="OCR engine",
        ),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Number(label="Score"),
        gr.Image(label="OCR results"),
    ],
    examples=[
        ["example_01.jpg", "When did the sample take place?", PADDLE_OCR_LABEL],
        ["example_02.jpg", "What is the ID number?", PADDLE_OCR_LABEL],
    ],
)

if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), not on import.
    demo.launch(server_name="0.0.0.0", server_port=7860)