"""Gradio app that classifies an uploaded document (image, or the first page of
a PDF) with a fine-tuned LayoutLMv3 sequence-classification model."""

import logging
import os
import tempfile
import traceback

import gradio as gr
import pdf2image
import torch
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3TokenizerFast,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Token budget handed to the processor; longer OCR output is truncated to it.
MAX_SEQ_LENGTH = 512


class DocumentClassifier:
    """Class to classify document types based on layout and text content.

    Combines a LayoutLMv3 processor (image feature extractor + fast tokenizer)
    with a fine-tuned LayoutLMv3 sequence-classification checkpoint.
    """

    def __init__(self, confidence_threshold=1.0):
        """Load the processor and the classification model.

        Args:
            confidence_threshold: Minimum value the largest *raw logit* (not a
                probability) must reach for a real class to be reported; below
                it the prediction is "Other". Defaults to 1.0, the value that
                was previously hard-coded.

        Raises:
            Exception: Any error while loading the tokenizer/model is logged
                with its traceback and re-raised.
        """
        self.confidence_threshold = confidence_threshold
        try:
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
                "microsoft/layoutlmv3-base"
            )
            self.processor = LayoutLMv3Processor(
                self.feature_extractor, self.tokenizer
            )
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                "RavindiG/layoutlmv3-document-classification-v2"
            )
            self.model.eval()
            # Display names indexed by the model's output position.
            self.id2label = {
                0: 'Financial Report',
                1: 'Invoice or Receipt',
                2: 'Legal Document',
                3: 'Medical Record',
                4: 'Research Paper',
            }
            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            logger.exception("Error during initialization: %s", e)
            raise

    def predict_document_class(self, image):
        """Predict the document class of a single PIL image.

        Args:
            image: PIL image; converted to RGB first if it is in another mode.

        Returns:
            (label, confidences). When the largest logit is below the
            confidence threshold, this is ("Other", {"Other": 1.0}).
            Otherwise label is the argmax class name and confidences maps
            every name in ``self.id2label`` to its softmax probability.

        Raises:
            Exception: Any processing/inference error is logged and re-raised.
        """
        try:
            logger.info("Image size: %s, mode: %s", image.size, image.mode)
            if image.mode != 'RGB':
                logger.info("Converting image from %s to RGB", image.mode)
                image = image.convert('RGB')

            logger.info("Processing image...")
            # truncation=True makes the intended truncation explicit instead of
            # relying on the tokenizer's implicit fallback (which also warns).
            encoded_inputs = self.processor(
                image,
                max_length=MAX_SEQ_LENGTH,
                truncation=True,
                return_tensors="pt",
            )
            device = next(self.model.parameters()).device
            encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**encoded_inputs)
            logits = outputs.logits
            logger.info("Logits: %s", logits)

            # getattr keeps instances created via __new__ (no __init__) working.
            threshold = getattr(self, "confidence_threshold", 1.0)
            max_logit = logits.max().item()

            if max_logit < threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                predicted_label = "Other"
                confidence_dict = {"Other": 1.0}
            else:
                predicted_class_idx = logits.argmax(-1).item()
                predicted_label = self.id2label.get(predicted_class_idx, "Other")
                # Probabilities for all known classes (single-image batch -> row 0).
                probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
                confidence_dict = {
                    label: probabilities[idx].item()
                    for idx, label in self.id2label.items()
                }

            logger.info("Predicted label: %s", predicted_label)
            return predicted_label, confidence_dict

        except Exception as e:
            logger.exception("Error in predict_document_class: %s", e)
            raise

    @staticmethod
    def _load_pdf_first_page(file_path):
        """Rasterize the first page of a PDF into a PIL image (120 dpi)."""
        logger.info("Converting PDF to image...")
        # pdf2image's `fmt` picks the intermediate file format (ppm/jpeg/png/
        # tiff), not a colour mode; the former fmt="RGB" was unrecognized and
        # fell back to the default, so it is simply omitted.
        images = pdf2image.convert_from_path(
            file_path, dpi=120, first_page=1, last_page=1
        )
        if not images:
            raise ValueError("No images extracted from PDF")
        image = images[0]
        logger.info("PDF converted successfully, image size: %s", image.size)
        return image

    @staticmethod
    def _load_image(file_path):
        """Read an image file fully into memory and close its file handle."""
        logger.info("Opening image file...")
        # Image.open is lazy and holds the file open until pixels are read
        # (multi-frame files such as TIFF stay open until close()); copy()
        # forces the load so the handle is released deterministically.
        with Image.open(file_path) as opened:
            image = opened.copy()
        logger.info("Image opened successfully, size: %s", image.size)
        return image

    def classify_document(self, file_path):
        """Classify document type based on layout and extracted text.

        Args:
            file_path: Path (str or os.PathLike) to a PDF or an image file.
                Only the first page of a PDF is used.

        Returns:
            (doc_type, confidences, image), where image is the PIL image that
            was classified. Errors are not raised: they are logged and
            ("Error", {"Error": 1.0}, None) is returned.
        """
        try:
            logger.info("Processing file: %s", file_path)
            file_path = os.fspath(file_path)

            if file_path.lower().endswith('.pdf'):
                image = self._load_pdf_first_page(file_path)
            else:
                image = self._load_image(file_path)

            doc_type, confidence = self.predict_document_class(image)
            logger.info("Classification successful: %s", doc_type)
            return doc_type, confidence, image

        except Exception as e:
            logger.exception("Error classifying document %s: %s", file_path, e)
            return "Error", {"Error": 1.0}, None


# Initialize the classifier (loads the model once at import time).
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")


def classify_upload(file):
    """Gradio interface function for file upload.

    Args:
        file: Value of the ``gr.File`` component: an object exposing the
            uploaded file's path as ``.name``, or the path string itself.
            None when nothing is uploaded (e.g. after the file is cleared).

    Returns:
        (label text, confidence dict for ``gr.Label``, preview image or None).
    """
    if file is None:
        return "No file uploaded", {}, None

    try:
        # Accept both file-like objects (path in .name) and plain path strings.
        file_path = getattr(file, "name", file)
        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image
    except Exception as e:
        logger.exception("Error in classify_upload: %s", e)
        return f"Error: {str(e)}", {}, None


# Create Gradio interface
with gr.Blocks(title="Document Classifier") as demo:
    gr.Markdown(
        """
        # 📄 Document Type Classifier

        Upload a document image or PDF to classify it into one of the following categories:
        - Financial Report
        - Invoice or Receipt
        - Legal Document
        - Medical Record
        - Research Paper

        **Note:** Only the first page of PDFs will be analyzed.
        """
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
            )
            classify_btn = gr.Button("Classify Document", variant="primary")

        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            confidence_output = gr.Label(label="Confidence Scores", num_top_classes=5)
            image_output = gr.Image(label="Document Preview", type="pil")

    # Examples
    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # Connect the button to the function
    classify_btn.click(
        fn=classify_upload,
        inputs=[file_input],
        outputs=[label_output, confidence_output, image_output],
    )

    # Also trigger on file upload (and on clear, which passes None)
    file_input.change(
        fn=classify_upload,
        inputs=[file_input],
        outputs=[label_output, confidence_output, image_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()