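# NOTE: this file was captured from a web page (apparently a Hugging Face
# Space whose status read "Sleeping"); that page header is kept only as this comment.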
import logging
import os
import shutil
import tempfile
import traceback

from PIL import Image, ImageDraw
import torch
from docquery import pipeline
from docquery.document import load_bytes, load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from pdf2image import convert_from_path
# Turn off tokenizer parallelism; presumably to silence the Hugging Face
# tokenizers fork warning / avoid fork deadlocks — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Send every log record (DEBUG and above) to a file relative to the working directory.
logging.basicConfig(level=logging.DEBUG, filename="invoice_extraction.log")

# Display name -> checkpoint id handed to the question-answering pipeline.
CHECKPOINTS = dict(
    [
        ("LayoutLMv1 for Invoices 🧾", "impira/layoutlm-invoices"),
    ]
)

# Process-wide cache of constructed pipelines, keyed by CHECKPOINTS display name.
PIPELINES = dict()
class InvoiceKeyValuePair:
    """Extract key-value pairs from invoices using a LayoutLM question-answering model.

    Every field in ``self.fields`` is answered by asking its questions of the
    document through docquery's ``"document-question-answering"`` pipeline and
    keeping the best-scoring answer.
    """

    def __init__(self):
        # Field name -> questions asked for that field. When several questions
        # are listed, the best-scoring answer across all of them wins.
        self.fields = {
            "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
            "Vendor Address": ["Vendor Address?"],
            "Customer Name": ["Customer Name?"],
            "Customer Address": ["Customer Address?"],
            "Invoice Number": ["Invoice Number?"],
            "Invoice Date": ["Invoice Date?"],
            "Due Date": ["Due Date?"],
            "Subtotal": ["Subtotal?"],
            "Total Tax": ["Total Tax?"],
            "Invoice Total": ["Invoice Total?"],
            "Amount Due": ["Amount Due?"],
            "Payment Terms": ["Payment Terms?"],
            "Remit To Name": ["Remit To Name?"],
            "Remit To Address": ["Remit To Address?"],
        }
        # Display name of the checkpoint to use (a key of CHECKPOINTS).
        self.model = list(CHECKPOINTS.keys())[0]

    def ensure_list(self, x):
        """Return ``x`` unchanged if it is a list, otherwise wrap it as ``[x]``."""
        logging.info("Entering ensure_list with x=%s", x)
        return x if isinstance(x, list) else [x]

    def construct_pipeline(self, task, model):
        """Return a cached or newly built docquery pipeline.

        Args:
            task: Pipeline task name passed to ``pipeline``.
            model: Display-name key of ``CHECKPOINTS``.

        Returns:
            The pipeline, or ``None`` (after logging) if it could not be built.

        Successful builds are cached in the module-level ``PIPELINES`` dict keyed
        by ``model`` only, not by ``task``. Every call in this class uses
        ``"document-question-answering"``. Failed builds are not cached.
        """
        logging.info("Entering construct_pipeline with task=%s and model=%s", task, model)
        # PIPELINES is only mutated, never rebound, so no `global` is needed.
        if model in PIPELINES:
            return PIPELINES[model]
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            built = pipeline(task=task, model=CHECKPOINTS[model], device=device)
        except Exception:
            logging.exception("An error occurred:")
            return None
        PIPELINES[model] = built
        return built

    def run_pipeline(self, model, question, document, top_k):
        """Ask ``question`` of ``document`` and return the pipeline's raw result.

        ``process_fields`` treats the result as one answer dict or a list of
        them (it has ``"answer"``/``"score"`` keys). Returns ``None`` if the
        pipeline is unavailable or raises.
        """
        logging.info(
            "Entering run_pipeline with model=%s, question=%s, and document=%s",
            model, question, document,
        )
        # Named qa_pipeline so it does not shadow docquery's imported `pipeline`.
        qa_pipeline = self.construct_pipeline("document-question-answering", model)
        if qa_pipeline is None:
            logging.error("No pipeline available for model=%s", model)
            return None
        try:
            return qa_pipeline(question=question, **document.context, top_k=top_k)
        except Exception:
            logging.exception("An error occurred:")
            return None

    def lift_word_boxes(self, document, page):
        """Return the word boxes for ``page``, i.e. ``document.context["image"][page][1]``.

        Presumably each ``context["image"]`` entry is an ``(image, word_boxes)``
        pair; verify this against docquery. Returns ``[]`` if the lookup fails.
        """
        logging.info("Entering lift_word_boxes with document=%s and page=%s", document, page)
        try:
            return document.context["image"][page][1]
        except Exception:
            logging.exception("An error occurred:")
            return []

    def expand_bbox(self, word_boxes):
        """Return the smallest box enclosing every box in ``word_boxes``.

        Each element is indexed as ``(word, (x0, y0, x1, y1))``; only ``[1]`` is
        used. Returns ``[min x0, min y0, max x1, max y1]``, or ``None`` for
        empty input or on error.
        """
        logging.info("Entering expand_bbox with word_boxes=%s", word_boxes)
        try:
            if len(word_boxes) == 0:
                return None
            x0s, y0s, x1s, y1s = zip(*(wb[1] for wb in word_boxes))
            return [min(x0s), min(y0s), max(x1s), max(y1s)]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def normalize_bbox(self, box, width, height, padding=0.005):
        """Scale a box from a 0-1000 normalized space to pixel coordinates.

        ``box`` is ``[x0, y0, x1, y1]`` on a 0-1000 scale (presumably LayoutLM's
        convention). Dividing by 1000 gives page fractions. ``padding`` is then
        added on every side as a page fraction and clamped to [0, 1], before
        the box is scaled by ``width``/``height``. Returns ``None`` on error
        (e.g. when ``box`` is ``None``).
        """
        logging.info(
            "Entering normalize_bbox with box=%s, width=%s, height=%s, and padding=%s",
            box, width, height, padding,
        )
        try:
            min_x, min_y, max_x, max_y = [c / 1000 for c in box]
            if padding != 0:
                min_x = max(0, min_x - padding)
                min_y = max(0, min_y - padding)
                max_x = min(max_x + padding, 1)
                max_y = min(max_y + padding, 1)
            return [min_x * width, min_y * height, max_x * width, max_y * height]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def annotate_page(self, prediction, pages, document):
        """Draw a translucent green box over ``prediction``'s words, in place.

        Draws on ``pages[prediction["page"]]``. Does nothing when ``prediction``
        is ``None``, lacks ``"word_ids"``, or selects no words. Errors are
        logged, not raised.
        """
        logging.info(
            "Entering annotate_page with prediction=%s, pages=%s, and document=%s",
            prediction, pages, document,
        )
        if prediction is None or "word_ids" not in prediction:
            return
        try:
            page = prediction["page"]
            image = pages[page]
            word_boxes = self.lift_word_boxes(document, page)
            bbox = self.expand_bbox([word_boxes[i] for i in prediction["word_ids"]])
            if bbox is None:
                return  # nothing to highlight
            x1, y1, x2, y2 = self.normalize_bbox(bbox, image.width, image.height)
            draw = ImageDraw.Draw(image, "RGBA")
            draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
        except Exception:
            logging.exception("An error occurred:")

    def process_fields(self, document, fields, model=None, min_score=0.5):
        """Answer each field's questions against ``document`` and tabulate the best answers.

        Args:
            document: Document exposing ``preview`` (page images) and ``context``.
            fields: Mapping of field name to a list of questions for that field.
            model: ``CHECKPOINTS`` key to use; defaults to ``self.model``.
            min_score: Answers scoring at or below this value are discarded.
                Answers without a ``"score"`` key are kept.

        Returns:
            ``[[field_name, answer_text_or_None], ...]`` in ``fields`` order, or
            ``[]`` if processing fails.
        """
        try:
            if model is None:
                model = self.model
            logging.info(
                "Entering process_fields with document=%s, fields=%s, and model=%s",
                document, fields, model,
            )
            # RGB copies that annotate_page draws on. They are not returned from here.
            pages = [p.copy().convert("RGB") for p in document.preview]
            table = []
            for field_name, questions in fields.items():
                candidates = [
                    answer
                    for question in questions
                    for answer in self.ensure_list(
                        self.run_pipeline(model, question, document, top_k=1)
                    )
                    # run_pipeline returns None on failure. Skip it so one failed
                    # question does not abort every field.
                    if answer is not None and answer.get("score", 1) > min_score
                ]
                # First highest-scoring candidate (ties keep question order).
                top = max(candidates, key=lambda a: a.get("score", 0), default=None)
                self.annotate_page(top, pages, document)
                table.append([field_name, top.get("answer") if top is not None else None])
            return table
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_document(self, document, fields, model, error=None):
        """Extract ``fields`` from ``document`` unless loading failed.

        Returns ``[]`` when ``document`` is ``None`` or ``error`` is set, so
        callers can always iterate the result.
        """
        logging.info(
            "Entering process_document with document=%s, fields=%s, model=%s, and error=%s",
            document, fields, model, error,
        )
        if document is None or error is not None:
            return []
        try:
            return self.process_fields(document, fields, model)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_path(self, path, fields, model):
        """Load the document at ``path`` and extract ``fields`` from it.

        An empty ``path`` or a load failure (logged) yields ``[]``.
        """
        logging.info(
            "Entering process_path with path=%s, fields=%s, and model=%s", path, fields, model
        )
        document, error = None, None
        if path:
            try:
                document = load_document(path)
            except Exception as e:
                logging.exception("An error occurred:")
                error = str(e)
        return self.process_document(document, fields, model, error)

    def pdf_to_image(self, file_path, output_dir=None):
        """Render every page of a PDF to PNG and return the first page's path.

        Pages are saved as ``page_<n>.png`` (1-based) inside ``output_dir``, or
        relative to the working directory when ``output_dir`` is ``None``.
        Returns ``[]`` (falsy) on failure or when the PDF has no pages; callers
        treat that as "no file".
        """
        logging.info("Entering pdf_to_image with file_path=%s", file_path)
        try:
            saved = []
            for i, image in enumerate(convert_from_path(file_path), start=1):
                image_path = f"page_{i}.png"
                if output_dir is not None:
                    image_path = os.path.join(output_dir, image_path)
                image.save(image_path)
                saved.append(image_path)
            return saved[0] if saved else []
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_upload(self, file):
        """Extract ``self.fields`` from the first page of the PDF at path ``file``.

        Pages are rendered into a fresh temporary directory per call, so
        concurrent uploads cannot overwrite each other's page images. The
        directory is removed afterwards on a best-effort basis. Returns ``[]``
        on failure.
        """
        logging.info("Entering process_upload with file=%s", file)
        tmp_dir = None
        try:
            tmp_dir = tempfile.mkdtemp(prefix="invoice_pages_")
            first_page = self.pdf_to_image(file, output_dir=tmp_dir)
            return self.process_path(first_page or None, self.fields, self.model)
        except Exception:
            logging.exception("An error occurred:")
            return []
        finally:
            if tmp_dir is not None:
                # ignore_errors: a still-open page file must not turn success into failure.
                shutil.rmtree(tmp_dir, ignore_errors=True)

    def extract_key_value_pair(self, invoice_file):
        """Extract all invoice fields and format them as one ``"key: value"`` per line.

        Args:
            invoice_file: A path (``str`` or ``os.PathLike``), or an object whose
                ``.name`` attribute holds the path (e.g. an upload wrapper;
                confirm this against the caller).

        Returns:
            The newline-joined pairs (``""`` if nothing was extracted), or
            ``None`` if an unexpected error occurs.
        """
        logging.info("Entering extract_key_value_pair with invoice_file=%s", invoice_file)
        try:
            if isinstance(invoice_file, (str, os.PathLike)):
                path = invoice_file
            else:
                path = invoice_file.name
            data = self.process_upload(path)
            return "\n".join(f"{key}: {value}" for key, value in data)
        except Exception:
            logging.exception("An error occurred:")
            return None