File size: 7,200 Bytes
0efe8b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import gradio as gr
from PIL import Image
import pdf2image
from transformers import (
    LayoutLMv3FeatureExtractor, 
    LayoutLMv3TokenizerFast, 
    LayoutLMv3Processor, 
    LayoutLMv3ForSequenceClassification
)
import torch
import logging
import traceback
import os
import tempfile

# Configure the root logger to emit INFO-level records and above, and create
# the module-level logger that the rest of this file logs through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DocumentClassifier:
    """Classify a document page into one of a fixed set of document types.

    Bundles a LayoutLMv3 processor (image feature extractor + tokenizer) with
    a fine-tuned LayoutLMv3 sequence-classification model and exposes helpers
    to classify either an in-memory image or a file on disk.
    """

    # Class-level fallback so the attribute always resolves, even on an
    # instance built without __init__. It is compared against the largest raw
    # logit, not against a softmax probability.
    confidence_threshold = 1.0

    def __init__(self, confidence_threshold=1.0):
        """Load the processor and the classification model.

        Args:
            confidence_threshold: Minimum value the largest raw logit must
                reach for a class to be reported; below it the document is
                labelled ``"Other"``. Defaults to ``1.0``.

        Raises:
            Exception: Any error raised while loading the tokenizer or model
                is logged with its traceback and re-raised unchanged.
        """
        self.confidence_threshold = confidence_threshold

        # Maps the model's output index to a human-readable label.
        self.id2label = {
            0: 'Financial Report',
            1: 'Invoice or Receipt',
            2: 'Legal Document',
            3: 'Medical Record',
            4: 'Research Paper',
        }

        try:
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
                "microsoft/layoutlmv3-base"
            )
            self.processor = LayoutLMv3Processor(self.feature_extractor, self.tokenizer)
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                "RavindiG/layoutlmv3-document-classification-v2"
            )
            # Inference only: switch off dropout and other training behaviour.
            self.model.eval()

            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            # logger.exception appends the active traceback to the record.
            logger.exception(f"Error during initialization: {e}")
            raise

    def predict_document_class(self, image):
        """Predict the document type of a single page image.

        Args:
            image: A PIL-style image exposing ``size``, ``mode`` and
                ``convert``. Non-RGB images are converted to RGB first.

        Returns:
            A ``(label, scores)`` tuple. ``label`` is one of the values in
            ``self.id2label`` or ``"Other"``. ``scores`` maps every known
            label to its softmax probability, or is ``{"Other": 1.0}`` when
            the largest logit is below ``self.confidence_threshold``.

        Raises:
            Exception: Any error from preprocessing or inference is logged
                with its traceback and re-raised.
        """
        try:
            logger.info(f"Image size: {image.size}, mode: {image.mode}")

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            logger.info("Processing image...")
            # Ask for truncation explicitly instead of relying on how the
            # tokenizer happens to treat a bare max_length.
            encoded_inputs = self.processor(
                image,
                max_length=512,
                truncation=True,
                return_tensors="pt",
            )

            # Run on whichever device the model's weights live on.
            device = next(self.model.parameters()).device
            model_inputs = {
                name: tensor.to(device) for name, tensor in encoded_inputs.items()
            }

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**model_inputs)

            logits = outputs.logits
            logger.info(f"Logits: {logits}")

            # Low-evidence guard: the raw logit (not a probability) is what
            # gets compared with the threshold.
            if logits.max().item() < self.confidence_threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                return "Other", {"Other": 1.0}

            # Probabilities are only needed once a real class is reported.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            predicted_class_idx = logits.argmax(-1).item()
            predicted_label = self.id2label.get(predicted_class_idx, "Other")

            confidence_dict = {
                label: probabilities[idx].item()
                for idx, label in self.id2label.items()
            }

            logger.info(f"Predicted label: {predicted_label}")
            return predicted_label, confidence_dict

        except Exception as e:
            logger.exception(f"Error in predict_document_class: {e}")
            raise

    def classify_document(self, file_path):
        """Classify the document stored at ``file_path``.

        PDFs (detected by a case-insensitive ``.pdf`` suffix) are rendered to
        an image of their first page; any other file is opened as an image.

        Args:
            file_path: Path of the PDF or image file to classify.

        Returns:
            ``(doc_type, confidence, image)`` on success, where ``image`` is
            the page that was classified. On any failure the error is logged
            and ``("Error", {"Error": 1.0}, None)`` is returned instead of
            raising.
        """
        try:
            logger.info(f"Processing file: {file_path}")

            if file_path.lower().endswith('.pdf'):
                logger.info("Converting PDF to image...")
                # NOTE(review): "RGB" is not one of pdf2image's named output
                # formats; kept unchanged - confirm the fallback it selects
                # is the intended one.
                images = pdf2image.convert_from_path(
                    file_path,
                    dpi=120,
                    fmt="RGB",
                    first_page=1,
                    last_page=1
                )
                if not images:
                    raise ValueError("No images extracted from PDF")
                image = images[0]
                logger.info(f"PDF converted successfully, image size: {image.size}")
            else:
                logger.info("Opening image file...")
                # Copy the pixels into memory and release the file handle
                # right away rather than leaving it open.
                with Image.open(file_path) as opened:
                    image = opened.copy()
                logger.info(f"Image opened successfully, size: {image.size}")

            doc_type, confidence = self.predict_document_class(image)

            logger.info(f"Classification successful: {doc_type}")
            return doc_type, confidence, image

        except Exception as e:
            # Deliberate best-effort: callers receive an "Error" result
            # rather than an exception.
            logger.exception(f"Error classifying document {file_path}: {e}")
            return "Error", {"Error": 1.0}, None


# Build the single module-level classifier at import time. classify_upload
# below reuses this instance for every request, so the tokenizer and model
# are loaded once rather than per upload.
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")


def classify_upload(file):
    """Gradio callback: classify an uploaded file and return the UI outputs.

    Args:
        file: Value delivered by the file-upload component. May be ``None``
            (nothing uploaded), a filesystem path (``str`` or
            ``os.PathLike``), or an object exposing the path as ``.name``.

    Returns:
        ``(document_type, confidence_scores, preview_image)``. When no file
        is given the result is ``("No file uploaded", {}, None)``; when an
        unexpected error occurs it is ``("Error: <message>", {}, None)``.
    """
    if file is None:
        return "No file uploaded", {}, None

    try:
        # Accept both upload shapes: a plain path, or a wrapper object that
        # carries the path in its ``name`` attribute.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name

        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image

    except Exception as e:
        # Top-level UI boundary: log with traceback and show the message in
        # the interface instead of crashing the callback.
        logger.exception(f"Error in classify_upload: {e}")
        return f"Error: {str(e)}", {}, None


# Build the Gradio interface: upload controls in the left column, the
# classification results and a page preview in the right column.
with gr.Blocks(title="Document Classifier") as demo:
    # Heading and the list of categories the classifier can report.
    gr.Markdown(
        """
        # 📄 Document Type Classifier
        
        Upload a document image or PDF to classify it into one of the following categories:
        - Financial Report
        - Invoice or Receipt
        - Legal Document
        - Medical Record
        - Research Paper
        
        **Note:** Only the first page of PDFs will be analyzed.
        """
    )

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
            )
            classify_btn = gr.Button("Classify Document", variant="primary")

        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            confidence_output = gr.Label(label="Confidence Scores", num_top_classes=5)
            image_output = gr.Image(label="Document Preview", type="pil")

    # Examples
    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # Run the classification when the button is pressed ...
    classify_btn.click(
        fn=classify_upload,
        inputs=[file_input],
        outputs=[label_output, confidence_output, image_output]
    )

    # ... and also whenever the selected file changes, so results refresh
    # without an extra click.
    file_input.change(
        fn=classify_upload,
        inputs=[file_input],
        outputs=[label_output, confidence_output, image_output]
    )

# Launch the Gradio app only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()