# Spaces: Sleeping
# File size: 7,200 Bytes
# 0efe8b2
import gradio as gr
from PIL import Image
import pdf2image
from transformers import (
LayoutLMv3FeatureExtractor,
LayoutLMv3TokenizerFast,
LayoutLMv3Processor,
LayoutLMv3ForSequenceClassification
)
import torch
import logging
import traceback
import os
import tempfile
# Configure root logging at INFO level so the progress messages emitted
# throughout this module are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the classifier class and the Gradio callback.
logger = logging.getLogger(__name__)
class DocumentClassifier:
    """Classify a document page into one of five document types.

    Wraps a LayoutLMv3 processor (feature extractor + fast tokenizer) and a
    fine-tuned ``LayoutLMv3ForSequenceClassification`` checkpoint.

    An optional ``confidence_threshold`` attribute may be set on an instance.
    When the largest raw logit is below it (default ``1.0``) the page is
    reported as ``"Other"`` instead of one of the known labels.
    """

    def __init__(self):
        """Load the processor and model and define the label mapping.

        Raises:
            Exception: Any error raised while loading is logged (message and
                traceback) and then re-raised unchanged.
        """
        try:
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
                "microsoft/layoutlmv3-base"
            )
            self.processor = LayoutLMv3Processor(
                self.feature_extractor, self.tokenizer
            )
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                "RavindiG/layoutlmv3-document-classification-v2"
            )
            # Inference only: switch off training-mode layers such as dropout.
            self.model.eval()
            # Maps the model's output index to a human-readable label.
            self.id2label = {
                0: 'Financial Report',
                1: 'Invoice or Receipt',
                2: 'Legal Document',
                3: 'Medical Record',
                4: 'Research Paper',
            }
            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            logger.error(f"Error during initialization: {e}")
            logger.error(traceback.format_exc())
            raise

    def predict_document_class(self, image):
        """Predict the document type of a single page image.

        Args:
            image: Page image exposing ``size``, ``mode`` and ``convert``
                (a PIL image); it is converted to RGB when it is not already.

        Returns:
            tuple: ``(predicted_label, confidence_dict)``. ``confidence_dict``
            maps every known label to its softmax probability, or is
            ``{"Other": 1.0}`` when the largest logit is below the threshold.

        Raises:
            Exception: Any processing/inference error is logged and re-raised.
        """
        try:
            logger.info(f"Image size: {image.size}, mode: {image.mode}")
            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            logger.info("Processing image...")
            # Request truncation explicitly so pages that tokenize to more
            # than 512 tokens are cut to max_length instead of relying on the
            # tokenizer's implicit handling of a bare ``max_length``.
            encoded_inputs = self.processor(
                image, max_length=512, truncation=True, return_tensors="pt"
            )

            # Move every input tensor to wherever the model's weights live.
            device = next(self.model.parameters()).device
            encoded_inputs = {
                key: tensor.to(device) for key, tensor in encoded_inputs.items()
            }

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**encoded_inputs)
            logits = outputs.logits
            logger.info(f"Logits: {logits}")

            # The threshold is compared against the raw logit, not a
            # probability; instances may override it via an attribute.
            threshold = getattr(self, "confidence_threshold", 1.0)
            max_logit = logits.max().item()

            # Probabilities for all classes of the (single) batch item.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

            if max_logit < threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                predicted_label = "Other"
                confidence_dict = {"Other": 1.0}
            else:
                predicted_class_idx = logits.argmax(-1).item()
                predicted_label = self.id2label.get(predicted_class_idx, "Other")
                # Confidence for every known class, keyed by label.
                confidence_dict = {
                    label: probabilities[idx].item()
                    for idx, label in self.id2label.items()
                }

            logger.info(f"Predicted label: {predicted_label}")
            return predicted_label, confidence_dict
        except Exception as e:
            logger.error(f"Error in predict_document_class: {e}")
            logger.error(traceback.format_exc())
            raise

    def classify_document(self, file_path):
        """Classify the document stored at ``file_path``.

        PDFs (detected by a case-insensitive ``.pdf`` suffix) are rendered to
        an image of their first page only; any other path is opened as an
        image file.

        Args:
            file_path: Path string of the PDF or image to classify.

        Returns:
            tuple: ``(doc_type, confidence_dict, image)``. On any failure the
            error is logged and ``("Error", {"Error": 1.0}, None)`` is
            returned instead of raising.
        """
        try:
            logger.info(f"Processing file: {file_path}")
            if file_path.lower().endswith('.pdf'):
                logger.info("Converting PDF to image...")
                images = pdf2image.convert_from_path(
                    file_path,
                    dpi=120,
                    fmt="RGB",
                    first_page=1,
                    last_page=1,
                )
                if not images:
                    raise ValueError("No images extracted from PDF")
                image = images[0]
                logger.info(f"PDF converted successfully, image size: {image.size}")
            else:
                logger.info("Opening image file...")
                # Image.open reads lazily and keeps the file open. Copy the
                # pixel data while the file is open so the handle is released
                # deterministically, even if later processing fails.
                with Image.open(file_path) as opened:
                    image = opened.copy()
                logger.info(f"Image opened successfully, size: {image.size}")

            doc_type, confidence = self.predict_document_class(image)
            logger.info(f"Classification successful: {doc_type}")
            return doc_type, confidence, image
        except Exception as e:
            # Best-effort boundary: report the failure to the caller as data.
            logger.error(f"Error classifying document {file_path}: {e}")
            logger.error(traceback.format_exc())
            return "Error", {"Error": 1.0}, None
# Initialize the classifier
# Built once at module import so the model is loaded a single time; the
# classify_upload callback below reads this module-level instance.
# DocumentClassifier() re-raises on load failure, so an error here aborts import.
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")
def classify_upload(file):
    """Gradio callback: classify an uploaded file and return the UI outputs.

    Args:
        file: Value delivered by the file input. May be ``None`` (nothing
            uploaded), a plain path string, or an object exposing the path
            through a ``name`` attribute.

    Returns:
        tuple: ``(label_text, confidence_dict, preview_image)``. When no file
        is given the label is ``"No file uploaded"``; when an unexpected error
        occurs the label is ``"Error: <message>"``. In both cases the
        confidences are empty and the preview is ``None``.
    """
    if file is None:
        return "No file uploaded", {}, None

    try:
        # Depending on the Gradio version/configuration the upload arrives
        # either as a path string or as a temp-file wrapper with ``.name``;
        # accept both instead of failing with AttributeError on a str.
        file_path = file if isinstance(file, str) else file.name

        # Classify the document
        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image
    except Exception as e:
        # UI boundary: surface the failure in the textbox rather than crash.
        logger.error(f"Error in classify_upload: {e}")
        logger.error(traceback.format_exc())
        return f"Error: {str(e)}", {}, None
# Create Gradio interface
# Layout: left column holds the upload widget and the button, right column
# holds the predicted label, the per-class confidences and a page preview.
with gr.Blocks(title="Document Classifier") as demo:
    # Blank lines separate heading, paragraph, list and note so the note is
    # not rendered as a continuation of the last list item.
    gr.Markdown(
        """
        # 📄 Document Type Classifier

        Upload a document image or PDF to classify it into one of the following categories:

        - Financial Report
        - Invoice or Receipt
        - Legal Document
        - Medical Record
        - Research Paper

        **Note:** Only the first page of PDFs will be analyzed.
        """
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
            )
            classify_btn = gr.Button("Classify Document", variant="primary")

        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            confidence_output = gr.Label(label="Confidence Scores", num_top_classes=5)
            image_output = gr.Image(label="Document Preview", type="pil")

    # Examples
    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # Both triggers run the same callback with the same inputs/outputs, so
    # the wiring is defined once and reused.
    classify_wiring = dict(
        fn=classify_upload,
        inputs=[file_input],
        outputs=[label_output, confidence_output, image_output],
    )

    # Connect the button to the function
    classify_btn.click(**classify_wiring)

    # Also trigger on file upload
    file_input.change(**classify_wiring)
# Launch the app
# Guarded so importing this module does not start the web server.
if __name__ == "__main__":
    demo.launch()