# File-viewer header: author rishabh062, commit 4195208
# ("Rename main.py to app.py"), file size 15.3 kB.
import io
import os
import boto3
import traceback
import re
import logging
import gradio as gr
from PIL import Image, ImageDraw
from docquery.document import load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import pipeline
# Avoid SSL errors.
# NOTE(review): this replaces the stdlib's default HTTPS context factory with
# one that skips certificate verification, for the whole process. Any stdlib
# HTTPS client relying on the default context (e.g. urllib / http.client) is
# affected. Prefer fixing the CA bundle over keeping this.
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

# Disable HuggingFace tokenizers' internal parallelism.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Root logger at DEBUG (this also lets third-party DEBUG records through).
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(__name__)
# Models are loaded once, at import time, and shared by every request.

# LayoutLM document-QA checkpoint behind a transformers pipeline.
layoutlm_pipeline = pipeline(
    task="document-question-answering",
    model="impira/layoutlm-document-qa",
)

# LiLT tokenizer and extractive-QA model, referenced only by run_lilt.
# NOTE(review): LILT is commented out of MODELS, so these are loaded at
# startup but unused by the UI. The tokenizer and model also come from
# different checkpoints; confirm compatibility before re-enabling LiLT.
lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
lilt_model = AutoModelForQuestionAnswering.from_pretrained(
    "nielsr/lilt-xlm-roberta-base"
)

# Donut fine-tuned on DocVQA; processor and model share one checkpoint.
_DONUT_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"
donut_processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
donut_model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
# Display names of the selectable models. They are used as MODELS keys and as
# the labels of the model radio buttons.
TEXTRACT, LAYOUTLM, DONUT, LILT = "Textract Query", "LayoutLM", "Donut", "LiLT"
def image_to_byte_array(image: "Image.Image") -> bytes:
    """Serialize an image to PNG-encoded bytes.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method,
            normally a ``PIL.Image.Image``.

    Returns:
        The raw PNG bytes (not base64-encoded).
    """
    # The annotation names the Image *class*. The previous bare ``Image``
    # referred to the PIL module. It is quoted so it is not evaluated at
    # definition time.
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return buffer.getvalue()
def run_textract(question, document):
    """Answer ``question`` about ``document`` with an AWS Textract query.

    Args:
        question: Query text sent as the single Textract QUERIES item, applied
            to all pages.
        document: Loaded document whose ``b`` attribute is PNG-encoded via
            ``image_to_byte_array``. NOTE(review): that requires ``b`` to be
            an image; confirm this holds for non-image (e.g. PDF) documents.

    Returns:
        dict with ``score`` (the block's Textract ``Confidence``, which
        Textract reports on a 0-100 scale) and ``answer`` (its ``Text``),
        taken from the first ``QUERY_RESULT`` block.

    Raises:
        RuntimeError: if the response contains no ``QUERY_RESULT`` block.
    """
    logger.info("Running Textract model.")
    # Raw PNG bytes; boto3 accepts bytes directly for Document.Bytes.
    image_bytes = image_to_byte_array(image=document.b)
    response = boto3.client("textract").analyze_document(
        Document={"Bytes": image_bytes},
        FeatureTypes=["QUERIES"],
        QueriesConfig={"Queries": [{"Text": question, "Pages": ["*"]}]},
    )
    logger.info("Output of Textract model %s.", response)
    for block in response["Blocks"]:
        if block["BlockType"] == "QUERY_RESULT":
            return {
                "score": block["Confidence"],
                "answer": block["Text"],
            }
    # Previously this exception was constructed but never raised. The
    # function then returned None and the caller failed on prediction["answer"].
    raise RuntimeError("No QUERY_RESULT found in the response from Textract.")
def run_layoutlm(question, document):
    """Answer ``question`` on the first page of ``document`` with LayoutLM.

    Returns:
        dict with ``score``, ``answer``, ``word_ids`` and ``page`` (always 0).
        ``word_ids`` lists every word index of the answer span on page 0;
        these are indices into that page's OCR (word, box) list.
    """
    logger.info("Running layoutlm model.")
    first_page = document.context["image"][0]
    page_image, word_boxes = first_page[0], first_page[1]
    # Hand the pipeline docquery's own OCR words/boxes. The start/end word
    # indices it returns then refer to the same list process_question uses for
    # highlighting, rather than to a second, independent OCR pass.
    result = layoutlm_pipeline(
        image=page_image, question=question, word_boxes=word_boxes
    )[0]
    logger.info("Output of layoutlm model %s.", result)
    # e.g. {'score': 0.99994, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}
    # start/end are inclusive word indices. Including every word in between
    # lets the highlight cover answers that wrap across lines, not only the
    # two endpoint words.
    return {
        "score": result["score"],
        "answer": result["answer"],
        "word_ids": list(range(result["start"], result["end"] + 1)),
        "page": 0,
    }
def run_lilt(question, document):
    """Answer ``question`` on page 0 of ``document`` with the LiLT QA model.

    Not reachable from the UI at the moment: LILT is commented out of MODELS.

    Returns:
        dict with ``score`` set to ``"n/a"`` (no confidence is computed) and
        ``answer``, the decoded token span.
    """
    logger.info(f"Running lilt model.")
    page_entries = document.context["image"][0][1]
    encoding = lilt_tokenizer(
        text=question,
        text_pair=[entry[0] for entry in page_entries],
        boxes=[entry[1] for entry in page_entries],
        add_special_tokens=True,
        return_tensors="pt",
    )
    outputs = lilt_model(**encoding)
    logger.info(f"Output for lilt model {outputs}.")
    # Greedy span: the most likely start and end tokens, picked independently.
    # NOTE(review): if end < start the slice is empty and the answer is "".
    first_token = outputs.start_logits.argmax()
    last_token = outputs.end_logits.argmax()
    span_ids = encoding.input_ids[0, first_token : last_token + 1]
    answer_text = lilt_tokenizer.decode(span_ids, skip_special_tokens=True)
    return {"score": "n/a", "answer": answer_text}
def run_donut(question, document):
    """Answer ``question`` on page 0 of ``document`` with Donut (OCR-free).

    Returns:
        dict with ``score`` set to ``"n/a"`` and ``answer``, the value parsed
        from the generated ``<s_answer>`` field. If that field is missing, it
        falls back to the raw generated text, or "" if there is none.
    """
    logger.info("Running donut model.")
    tokenizer = donut_processor.tokenizer
    # Encoder input: pixel values of the first page image.
    pixel_values = donut_processor(
        document.context["image"][0][0], return_tensors="pt"
    ).pixel_values
    # Decoder input: the DocVQA task prompt with the question filled in.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    decoder_input_ids = tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    # Greedy generation of the answer continuation.
    outputs = donut_model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=donut_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    logger.info("Output for donut %s", outputs)
    sequence = donut_processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(
        tokenizer.pad_token, ""
    )
    # Remove the first tag (the task start token).
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    result = donut_processor.token2json(sequence)
    # token2json only yields "answer" when the model produced a well-formed
    # <s_answer>...</s_answer> pair. Otherwise it returns e.g.
    # {"text_sequence": ...}. Fall back instead of raising KeyError.
    answer = result.get("answer", result.get("text_sequence", ""))
    return {
        "score": "n/a",
        "answer": answer,
    }
def process_path(path):
    """Load the document at ``path`` and build the upload-related UI updates.

    Returns:
        A 6-tuple ordered as (document, preview gallery, file-clear button,
        JSON output, answer textbox, error slot).
        - On success: the document is returned and its preview is shown.
        - If ``path`` is empty or loading fails: everything is hidden, and
          the last slot carries the error message (None when there was no
          error).
    """
    error = None
    if path:
        try:
            document = load_document(path)
            return (
                document,
                gr.update(visible=True, value=document.preview),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                None,
            )
        except Exception as e:
            # UI boundary: log the failure and report it rather than crashing
            # the Gradio callback.
            logger.exception("Failed to load document %s", path)
            error = str(e)
    # This path previously returned 7 values (a stale trailing None) while the
    # success path returned 6. Both now have the same arity.
    # NOTE(review): upload.change wires only the first five values, so the
    # error slot is currently never displayed.
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=error) if error is not None else None,
    )
def process_upload(file):
    """Gradio upload handler: load the uploaded file, or reset the UI.

    A truthy ``file`` is loaded via ``process_path(file.name)``. A falsy value
    (nothing uploaded) yields the "nothing loaded" 6-tuple of updates:
    (document, gallery, clear button, JSON output, answer textbox, error slot).
    """
    if not file:
        # Nothing to load: hide the preview, the clear button and all outputs.
        return (
            None,
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
            None,
        )
    return process_path(file.name)
def lift_word_boxes(document, page):
    """Return the OCR (word, box) entries stored for ``page`` of ``document``.

    Reads the second item of ``document.context["image"][page]``. The boxes
    appear to be normalized to 0-1000, since callers later divide them by 1000.
    """
    page_entry = document.context["image"][page]
    return page_entry[1]
def expand_bbox(word_boxes):
    """Return the smallest box enclosing every box in ``word_boxes``.

    Args:
        word_boxes: Sequence of (word, (x0, y0, x1, y1)) entries.

    Returns:
        ``[min x0, min y0, max x1, max y1]``, or None when ``word_boxes`` is empty.
    """
    if len(word_boxes) == 0:
        return None
    # Transpose the boxes into one tuple per coordinate.
    per_axis = zip(*(entry[1] for entry in word_boxes))
    lefts, tops, rights, bottoms = per_axis
    return [min(lefts), min(tops), max(rights), max(bottoms)]
# LayoutLM boxes are normalized to the 0-1000 range.
def normalize_bbox(box, width, height, padding=0.005):
    """Map a 0-1000 normalized box to pixel coordinates of a width x height image.

    Args:
        box: (x0, y0, x1, y1) in the 0-1000 coordinate space.
        width: Target image width in pixels.
        height: Target image height in pixels.
        padding: Margin added on every side, as a fraction of the image size.
            When non-zero, the padded box is clamped to the image bounds.

    Returns:
        ``[x0, y0, x1, y1]`` in pixels.
    """
    x0, y0, x1, y1 = (coord / 1000 for coord in box)
    if padding != 0:
        x0, y0 = max(0, x0 - padding), max(0, y0 - padding)
        x1, y1 = min(x1 + padding, 1), min(y1 + padding, 1)
    return [x0 * width, y0 * height, x1 * width, y1 * height]
# Registry of selectable models: display name -> runner(question=, document=).
# Insertion order matters: the first entry is the default model in the UI.
# LiLT (run_lilt) is currently disabled, i.e. not registered here.
MODELS = dict(
    (
        (LAYOUTLM, run_layoutlm),
        (DONUT, run_donut),
        (TEXTRACT, run_textract),
    )
)
def process_question(question, document, model=list(MODELS.keys())[0]):
    """Answer ``question`` about ``document`` with ``model`` and build UI updates.

    Args:
        question: Question text; a falsy value means there is nothing to do.
        document: The loaded document, or None.
        model: A key of MODELS; defaults to the first registered model.

    Returns:
        A 3-tuple of updates, or (None, None, None) when there is no question
        or no document:
        - gallery update with the page previews, the answer highlighted when
          the prediction carries ``word_ids``;
        - JSON update with the raw prediction;
        - textbox update with the answer text.
    """
    if not question or document is None:
        return None, None, None

    logger.info("Running for model %s", model)
    prediction = MODELS[model](question=question, document=document)
    logger.info("Got prediction %s", prediction)

    # RGB copies of the previews, so drawing never touches document.preview.
    pages = [page.copy().convert("RGB") for page in document.preview]
    text_value = prediction["answer"]
    if "word_ids" in prediction:
        logger.info("Setting bounding boxes.")
        page_index = prediction["page"]
        image = pages[page_index]
        word_boxes = lift_word_boxes(document, page_index)
        word_ids = prediction["word_ids"]
        # Only highlight when every id indexes a real word on this page.
        # An empty span would make expand_bbox return None, which crashes
        # normalize_bbox. Ids beyond the word list would raise IndexError.
        if word_ids and all(0 <= i < len(word_boxes) for i in word_ids):
            x1, y1, x2, y2 = normalize_bbox(
                expand_bbox([word_boxes[i] for i in word_ids]),
                image.width,
                image.height,
            )
            draw = ImageDraw.Draw(image, "RGBA")
            # Translucent green (40% alpha) over the answer span.
            draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
        else:
            logger.warning(
                "Skipping answer highlight: word ids %s do not fit the %d "
                "word boxes on page %s.",
                word_ids,
                len(word_boxes),
                page_index,
            )
    return (
        gr.update(visible=True, value=pages),
        gr.update(visible=True, value=prediction),
        gr.update(
            visible=True,
            value=text_value,
        ),
    )
def load_example_document(img, question, model):
    """Handle a change of the hidden example image: build a document and answer.

    Args:
        img: Array value of the hidden image component, or None.
        question: The example's question text.
        model: Selected key of MODELS.

    Returns:
        A 6-tuple for (document, question, gallery, file-clear button,
        JSON output, answer textbox).
    """
    if img is None:
        return None, None, None, gr.update(visible=False), None, None
    example_doc = ImageDocument(Image.fromarray(img), get_ocr_reader())
    gallery, answer_json, answer_box = process_question(question, example_doc, model)
    return example_doc, question, gallery, gr.update(visible=True), answer_json, answer_box
# Custom stylesheet passed to gr.Blocks(css=CSS).
# Several id selectors target elem_id values set on components in the UI:
# #answer, #short-upload-box, #file-clear, #submit-button, #select-a-file.
# NOTE(review): #question, #url-textbox and #url-error match no elem_id in
# this file (the question Textbox sets none, and there are no URL
# components), so those rules appear to have no effect. Confirm before
# removing them.
CSS = """
#question input {
font-size: 16px;
}
#url-textbox {
padding: 0 !important;
}
#short-upload-box .w-full {
min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
* the table
*/
/*
.gr-samples-table tr {
display: inline;
}
.gr-samples-table .p-2 {
width: 100px;
}
*/
#select-a-file {
width: 100%;
}
#file-clear {
padding-top: 2px !important;
padding-bottom: 2px !important;
padding-left: 8px !important;
padding-right: 8px !important;
margin-top: 10px;
}
.gradio-container .gr-button-primary {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700;
}
.gradio-container.dark button#submit-button {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700
}
table.gr-samples-table tr td {
border: none;
outline: none;
}
table.gr-samples-table tr td:first-of-type {
width: 0%;
}
div#short-upload-box div.absolute {
display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
gap: 0px;
}
gradio-app h2, .gradio-app h2 {
padding-top: 10px;
}
#answer {
overflow-y: scroll;
color: white;
background: #666;
border-color: #666;
font-size: 20px;
font-weight: bold;
}
#answer span {
color: white;
}
#answer textarea {
color:white;
background: #777;
border-color: #777;
font-size: 18px;
}
#url-error input {
color: red;
}
"""
# Example rows [image file, question] offered through gr.Examples. The image
# paths are relative, presumably to the app's working directory.
examples = [
    [filename, prompt]
    for filename, prompt in (
        ("scenario-1.png", "What is the final consignee?"),
        ("scenario-1.png", "What are the payment terms?"),
        ("scenario-2.png", "What is the actual manufacturer?"),
        ("scenario-3.png", 'What is the "ship to" destination?'),
        ("scenario-4.png", "What is the color?"),
        ("scenario-5.png", 'What is the "said to contain"?'),
        ("scenario-5.png", 'What is the "Net Weight"?'),
        ("scenario-5.png", 'What is the "Freight Collect"?'),
        ("bill_of_lading_1.png", "What is the shipper?"),
        ("japanese-invoice.png", "What is the total amount?"),
    )
]
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Document Question Answer Comparator")
    gr.Markdown("""
This space compares some of the latest models that can be used commercially.
- [LayoutLM](https://huggingface.co/impira/layoutlm-document-qa) uses text/layout and images. Uses tesseract for OCR.
- [Donut](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa) OCR free document understanding. Uses vision encoder for OCR and a text decoder for providing the answer.
- [Textract Query](https://docs.aws.amazon.com/textract/latest/dg/what-is.html) OCR + document understanding solution of AWS.
""")

    # Per-session holder for the loaded document (gr.Variable is a
    # deprecated alias of gr.State).
    document = gr.State()
    # Hidden components filled by gr.Examples. example_image.change then
    # loads the example via load_example_document.
    example_question = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)

    with gr.Row(equal_height=True):
        # Left column: file selection, preview gallery and examples.
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=examples,
                inputs=[example_image, example_question],
            )

        # Right column: question, model choice, actions and answers.
        with gr.Column():
            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
                max_lines=1,
            )
            model = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )

            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output_text = gr.Textbox(
                    label="Top Answer", visible=False, elem_id="answer"
                )
                output = gr.JSON(label="Output", visible=False)

    def _clear_all():
        """Reset preview, document, outputs, upload, example and question.

        Returns exactly one value per output wired below. The previous lambda
        returned 10 values for 8 outputs, a leftover from removed URL inputs.
        """
        return (
            gr.update(visible=False, value=None),  # image
            None,  # document
            gr.update(visible=False, value=None),  # output
            gr.update(visible=False, value=None),  # output_text
            gr.update(visible=False),  # img_clear_button
            None,  # example_image
            None,  # upload
            None,  # question
        )

    for cb in [img_clear_button, clear_button]:
        cb.click(
            _clear_all,
            outputs=[
                image,
                document,
                output,
                output_text,
                img_clear_button,
                example_image,
                upload,
                question,
            ],
        )

    # NOTE(review): process_upload returns one more value (the error slot)
    # than the five outputs wired here, so load errors are not displayed.
    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, output_text],
    )

    # Re-run the selected model when the question is submitted, Submit is
    # clicked, or another model is chosen.
    for trigger in (question.submit, submit_button.click, model.change):
        trigger(
            process_question,
            inputs=[question, document, model],
            outputs=[image, output, output_text],
        )

    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_question, model],
        outputs=[document, question, image, img_clear_button, output, output_text],
    )
if __name__ == "__main__":
    # Launch the Gradio app with request queueing turned off.
    # NOTE(review): ``enable_queue`` is a legacy launch() option; confirm the
    # installed Gradio version still accepts it.
    launch_kwargs = {"enable_queue": False}
    demo.launch(**launch_kwargs)