|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import pdf2image |
|
|
from transformers import ( |
|
|
LayoutLMv3FeatureExtractor, |
|
|
LayoutLMv3TokenizerFast, |
|
|
LayoutLMv3Processor, |
|
|
LayoutLMv3ForSequenceClassification |
|
|
) |
|
|
import torch |
|
|
import logging |
|
|
import traceback |
|
|
import os |
|
|
import tempfile |
|
|
|
|
|
# Module-level logger used throughout this file, plus a basic root
# configuration at INFO so the logger.info progress messages are emitted.
# (Per the logging docs, basicConfig does nothing if the root logger already
# has handlers.)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
class DocumentClassifier:
    """Classify document page images by type using a fine-tuned LayoutLMv3 model.

    The processor pairs a ``LayoutLMv3FeatureExtractor`` built with library
    defaults with a LayoutLMv3 tokenizer. The sequence-classification head
    comes from a fine-tuned checkpoint, whose class indices are mapped to
    human-readable names by ``self.id2label``.

    Args:
        model_name: Hugging Face id (or local path) of the fine-tuned
            ``LayoutLMv3ForSequenceClassification`` checkpoint.
        tokenizer_name: Hugging Face id (or local path) of the tokenizer.
        confidence_threshold: Minimum value of the largest raw *logit*
            (not a softmax probability) required to accept the top class.
            Below it, the page is reported as ``"Other"``.

    Raises:
        Any error while loading the tokenizer/model is logged and re-raised.
    """

    # Class-level default so the attribute always resolves, even for
    # instances that bypass __init__ or set it after construction.
    confidence_threshold = 1.0

    def __init__(
        self,
        model_name="RavindiG/layoutlmv3-document-classification-v2",
        tokenizer_name="microsoft/layoutlmv3-base",
        confidence_threshold=1.0,
    ):
        self.confidence_threshold = confidence_threshold
        try:
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(tokenizer_name)
            self.processor = LayoutLMv3Processor(self.feature_extractor, self.tokenizer)
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(model_name)
            # Inference only: switch off training-mode behaviour such as dropout.
            self.model.eval()

            # NOTE(review): hard-coded rather than read from
            # self.model.config.id2label; must match the checkpoint's label order.
            self.id2label = {
                0: 'Financial Report',
                1: 'Invoice or Receipt',
                2: 'Legal Document',
                3: 'Medical Record',
                4: 'Research Paper',
            }
            num_labels = self.model.config.num_labels
            if num_labels != len(self.id2label):
                logger.warning(
                    f"Model reports {num_labels} labels but id2label defines "
                    f"{len(self.id2label)}; predictions may be mislabeled."
                )

            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            logger.exception(f"Error during initialization: {e}")
            raise

    def predict_document_class(self, image):
        """Classify one page image with the LayoutLMv3 model.

        Args:
            image: PIL image of the page. It is converted to RGB when it is in
                another mode.

        Returns:
            ``(predicted_label, confidence_dict)``.

            If the largest logit is at least ``self.confidence_threshold``:
            ``predicted_label`` is the argmax class name from
            ``self.id2label``, and ``confidence_dict`` maps every name in
            ``self.id2label`` to its softmax probability.

            Otherwise the result is ``("Other", {"Other": 1.0})``. That 1.0 is
            a fixed placeholder, not a model probability.

        Raises:
            Any preprocessing/inference error is logged and re-raised.
        """
        try:
            logger.info(f"Image size: {image.size}, mode: {image.mode}")

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            logger.info("Processing image...")
            # Explicit truncation=True: with only max_length set, the tokenizer
            # still truncates to 512 tokens but logs a warning on every call.
            encoded_inputs = self.processor(
                image, max_length=512, truncation=True, return_tensors="pt"
            )

            # Put every input tensor on the same device as the model weights.
            device = next(self.model.parameters()).device
            encoded_inputs = {
                name: tensor.to(device) for name, tensor in encoded_inputs.items()
            }

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**encoded_inputs)

            logits = outputs.logits
            logger.info(f"Logits: {logits}")

            # The threshold is compared against the raw (pre-softmax) logit.
            max_logit = logits.max().item()
            if max_logit < self.confidence_threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                predicted_label = "Other"
                confidence_dict = {"Other": 1.0}
            else:
                # [0]: a single image was encoded, so take the only batch row.
                probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
                predicted_class_idx = logits.argmax(-1).item()
                predicted_label = self.id2label.get(predicted_class_idx, "Other")
                confidence_dict = {
                    label: probabilities[idx].item()
                    for idx, label in self.id2label.items()
                }

            logger.info(f"Predicted label: {predicted_label}")
            return predicted_label, confidence_dict

        except Exception as e:
            logger.exception(f"Error in predict_document_class: {e}")
            raise

    def classify_document(self, file_path):
        """Classify the document stored at ``file_path`` (PDF or image file).

        Only the first page of a PDF is rasterized (at 120 dpi) and classified.

        Args:
            file_path: Path to the document, as ``str`` or ``os.PathLike``.

        Returns:
            ``(doc_type, confidence, image)``, where ``image`` is the PIL image
            that was classified. On any failure the error is logged and
            ``("Error", {"Error": 1.0}, None)`` is returned instead of raising,
            so the UI can display it.
        """
        try:
            logger.info(f"Processing file: {file_path}")
            # Accept pathlib.Path and other os.PathLike objects, not only str.
            path = os.fspath(file_path)

            if path.lower().endswith('.pdf'):
                logger.info("Converting PDF to image...")
                # pdf2image's ``fmt`` selects the intermediate file format
                # (ppm/jpeg/png/tiff). The previous value "RGB" is not one of
                # them and silently fell back to ppm, so request ppm explicitly.
                images = pdf2image.convert_from_path(
                    path,
                    dpi=120,
                    fmt="ppm",
                    first_page=1,
                    last_page=1,
                )
                if not images:
                    raise ValueError("No images extracted from PDF")
                image = images[0]
                logger.info(f"PDF converted successfully, image size: {image.size}")
            else:
                logger.info("Opening image file...")
                # Image.open is lazy and can keep the file handle open (for
                # multi-frame formats such as TIFF, until close()). copy()
                # loads the pixel data, and the context manager then releases
                # the file.
                with Image.open(path) as opened:
                    image = opened.copy()
                logger.info(f"Image opened successfully, size: {image.size}")

            doc_type, confidence = self.predict_document_class(image)

            logger.info(f"Classification successful: {doc_type}")
            return doc_type, confidence, image

        except Exception as e:
            logger.exception(f"Error classifying document {file_path}: {e}")
            return "Error", {"Error": 1.0}, None
|
|
|
|
|
|
|
|
|
|
|
# Build the single shared classifier at import time. classify_upload (the
# Gradio callback) reads this module-level instance on every request.
# DocumentClassifier.__init__ loads the tokenizer and model via
# from_pretrained and re-raises on failure, so an initialization error
# stops the module from importing.
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")
|
|
|
|
|
|
|
|
def classify_upload(file):
    """Gradio callback: classify an uploaded document.

    Args:
        file: The value produced by the ``gr.File`` component (an object
            exposing the uploaded file's path as ``.name``), or a plain
            ``str`` / ``os.PathLike`` path when called programmatically.
            ``None`` or an empty path means nothing was uploaded.

    Returns:
        ``(label_text, confidence_dict, preview_image)``, matching the three
        output components wired up in the UI.
    """
    if not file:
        return "No file uploaded", {}, None

    try:
        # Path objects also have a ``.name`` (the basename), so convert them
        # first. Otherwise prefer ``.name`` and fall back to the value itself
        # for plain string paths.
        if isinstance(file, os.PathLike):
            file_path = os.fspath(file)
        else:
            file_path = getattr(file, "name", file)

        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image

    except Exception as e:
        logger.exception(f"Error in classify_upload: {e}")
        return f"Error: {e}", {}, None
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Document Classifier") as demo:
    # The category list is generated from the classifier's label map so the
    # description stays in sync with the class names the model can return.
    gr.Markdown(
        "# 📄 Document Type Classifier\n\n"
        "Upload a document image or PDF to classify it into one of the "
        "following categories:\n"
        + "".join(f"- {label}\n" for label in classifier.id2label.values())
        + "\n**Note:** Only the first page of PDFs will be analyzed."
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
            )
            classify_btn = gr.Button("Classify Document", variant="primary")

        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            # Show every class the model can score.
            confidence_output = gr.Label(
                label="Confidence Scores",
                num_top_classes=len(classifier.id2label),
            )
            image_output = gr.Image(label="Document Preview", type="pil")

    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # The same callback handles an explicit click and any change of the
    # uploaded file. A cleared upload presumably arrives as None, which
    # classify_upload reports as "No file uploaded".
    result_outputs = [label_output, confidence_output, image_output]
    classify_btn.click(
        fn=classify_upload,
        inputs=[file_input],
        outputs=result_outputs,
    )
    file_input.change(
        fn=classify_upload,
        inputs=[file_input],
        outputs=result_outputs,
    )
|
|
|
|
|
|
|
|
# Start the Gradio server only when run as a script. Importing this module
# still builds `classifier` and the `demo` UI, but does not launch it.
if __name__ == "__main__":
    demo.launch()