Ravindi Gunarathna committed on
Commit
0efe8b2
·
1 Parent(s): 3476479

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +209 -0
  2. packages.txt +2 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import pdf2image
4
+ from transformers import (
5
+ LayoutLMv3FeatureExtractor,
6
+ LayoutLMv3TokenizerFast,
7
+ LayoutLMv3Processor,
8
+ LayoutLMv3ForSequenceClassification
9
+ )
10
+ import torch
11
+ import logging
12
+ import traceback
13
+ import os
14
+ import tempfile
15
+
16
# Emit INFO-and-above records through logging's default configuration, then
# create the logger that the rest of this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
18
+
19
class DocumentClassifier:
    """Classify a document page into one of five types with LayoutLMv3.

    The processor combines a ``LayoutLMv3FeatureExtractor`` with the
    ``microsoft/layoutlmv3-base`` fast tokenizer; the classification model is
    loaded from ``RavindiG/layoutlmv3-document-classification-v2``.

    Attributes:
        confidence_threshold: Value the largest raw logit must reach for a
            prediction to be reported. Below it the page is labelled
            ``"Other"``. This is compared against logits, not probabilities.
        id2label: Mapping from class index to human-readable label.
    """

    # Default cut-off for the largest raw logit (same default the original
    # lookup fell back to when the attribute was absent).
    confidence_threshold = 1.0

    def __init__(self, confidence_threshold=None):
        """Load the processor and the fine-tuned classification model.

        Args:
            confidence_threshold: Optional override for the logit cut-off
                below which a page is reported as ``"Other"``. ``None`` keeps
                the class default of 1.0.

        Raises:
            Exception: Whatever the underlying loaders raise; the error is
                logged with its traceback and re-raised unchanged.
        """
        try:
            if confidence_threshold is not None:
                self.confidence_threshold = confidence_threshold

            # NOTE(review): recent transformers releases appear to deprecate
            # LayoutLMv3FeatureExtractor in favour of LayoutLMv3ImageProcessor;
            # confirm before upgrading the dependency.
            self.feature_extractor = LayoutLMv3FeatureExtractor()
            self.tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
                "microsoft/layoutlmv3-base"
            )
            self.processor = LayoutLMv3Processor(self.feature_extractor, self.tokenizer)
            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                "RavindiG/layoutlmv3-document-classification-v2"
            )
            # Inference only: switch the model to evaluation mode.
            self.model.eval()

            self.id2label = {
                0: 'Financial Report',
                1: 'Invoice or Receipt',
                2: 'Legal Document',
                3: 'Medical Record',
                4: 'Research Paper',
            }

            logger.info("DocumentClassifier initialized successfully.")
        except Exception as e:
            # logger.exception records the message and the active traceback.
            logger.exception(f"Error during initialization: {e}")
            raise

    def predict_document_class(self, image):
        """Run the model on a single page image.

        Args:
            image: A PIL image. Images whose mode is not ``RGB`` are converted
                before being handed to the processor.

        Returns:
            A ``(label, scores)`` tuple. ``scores`` maps every label in
            ``id2label`` to its softmax probability, or is ``{"Other": 1.0}``
            (with label ``"Other"``) when the largest logit is below
            ``confidence_threshold``.

        Raises:
            Exception: Any processing or inference error, logged with its
                traceback and re-raised unchanged.
        """
        try:
            logger.info(f"Image size: {image.size}, mode: {image.mode}")

            if image.mode != 'RGB':
                logger.info(f"Converting image from {image.mode} to RGB")
                image = image.convert('RGB')

            logger.info("Processing image...")
            # Ask for truncation explicitly instead of relying on the
            # tokenizer's implicit fallback when only max_length is given.
            encoded_inputs = self.processor(
                image, truncation=True, max_length=512, return_tensors="pt"
            )

            # Move every input tensor to wherever the model's weights live.
            device = next(self.model.parameters()).device
            model_inputs = {
                name: tensor.to(device) for name, tensor in encoded_inputs.items()
            }

            logger.info("Running model inference...")
            with torch.no_grad():
                outputs = self.model(**model_inputs)

            logits = outputs.logits
            logger.info(f"Logits: {logits}")

            if logits.max().item() < self.confidence_threshold:
                logger.info("Max logit below threshold, returning 'Other'")
                predicted_label = "Other"
                confidence_dict = {"Other": 1.0}
            else:
                predicted_class_idx = logits.argmax(-1).item()
                predicted_label = self.id2label.get(predicted_class_idx, "Other")

                # Softmax over the class axis; [0] selects the only item in
                # the batch. One score is reported per known label.
                probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
                confidence_dict = {
                    label: probabilities[idx].item()
                    for idx, label in self.id2label.items()
                }

            logger.info(f"Predicted label: {predicted_label}")

            return predicted_label, confidence_dict

        except Exception as e:
            logger.exception(f"Error in predict_document_class: {e}")
            raise

    def _load_first_page(self, file_path):
        """Return the page to classify as an in-memory PIL image.

        PDFs (by ``.pdf`` suffix, case-insensitive) are rasterised at 120 dpi
        and only page 1 is used; any other file is opened as an image.
        """
        path = os.fspath(file_path)

        if path.lower().endswith('.pdf'):
            logger.info("Converting PDF to image...")
            # "RGB" is not a pdf2image output format; the original value only
            # worked through the library's fallback to PPM, so name it.
            images = pdf2image.convert_from_path(
                path,
                dpi=120,
                fmt="ppm",
                first_page=1,
                last_page=1,
            )
            if not images:
                raise ValueError("No images extracted from PDF")
            image = images[0]
            logger.info(f"PDF converted successfully, image size: {image.size}")
            return image

        logger.info("Opening image file...")
        # Image.open is lazy and keeps the file open; copy the pixels into
        # memory and release the handle before returning.
        with Image.open(path) as opened:
            image = opened.copy()
        logger.info(f"Image opened successfully, size: {image.size}")
        return image

    def classify_document(self, file_path):
        """Classify the document stored at ``file_path``.

        Args:
            file_path: Path to a PDF or image file.

        Returns:
            A ``(document type, scores, image)`` tuple, where ``image`` is the
            page that was classified. Any failure is logged and reported as
            ``("Error", {"Error": 1.0}, None)`` instead of being raised.
        """
        try:
            logger.info(f"Processing file: {file_path}")

            image = self._load_first_page(file_path)
            doc_type, confidence = self.predict_document_class(image)

            logger.info(f"Classification successful: {doc_type}")
            return doc_type, confidence, image

        except Exception as e:
            logger.exception(f"Error classifying document {file_path}: {e}")
            return "Error", {"Error": 1.0}, None
131
+
132
+
133
+ # Initialize the classifier
134
# Build the one shared classifier when the module is imported, so the
# upload handler below reuses the same loaded model for every request.
logger.info("Initializing DocumentClassifier...")
classifier = DocumentClassifier()
logger.info("Classifier ready!")
137
+
138
+
139
def classify_upload(file):
    """Gradio callback: classify the uploaded file and return the UI outputs.

    Args:
        file: Value delivered by the file input. May be ``None`` (nothing
            uploaded), an object exposing its path as ``.name``, or a plain
            path string / ``os.PathLike``.

    Returns:
        A ``(document type, confidence scores, preview image)`` tuple for the
        three output components. When nothing was uploaded the first element
        is ``"No file uploaded"``; on an unexpected failure it is an
        ``"Error: ..."`` message. In both cases the scores are empty and the
        image is ``None``.
    """
    if file is None:
        return "No file uploaded", {}, None

    try:
        # Use the ``.name`` attribute when the upload object has one,
        # otherwise treat the value itself as the path.
        file_path = os.fspath(getattr(file, "name", file))

        doc_type, confidence, image = classifier.classify_document(file_path)
        return doc_type, confidence, image

    except Exception as e:
        # logger.exception records the message and the active traceback.
        logger.exception(f"Error in classify_upload: {e}")
        return f"Error: {str(e)}", {}, None
157
+
158
+
159
+ # Create Gradio interface
160
# Page layout: an intro, an upload column next to a results column, and the
# event wiring that feeds uploads to classify_upload.
with gr.Blocks(title="Document Classifier") as demo:
    gr.Markdown(
        """
        # 📄 Document Type Classifier

        Upload a document image or PDF to classify it into one of the following categories:
        - Financial Report
        - Invoice or Receipt
        - Legal Document
        - Medical Record
        - Research Paper

        **Note:** Only the first page of PDFs will be analyzed.
        """
    )

    with gr.Row():
        # Left column: the upload widget and the manual trigger.
        with gr.Column():
            file_input = gr.File(
                label="Upload Document (Image or PDF)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
            )
            classify_btn = gr.Button("Classify Document", variant="primary")

        # Right column: predicted label, per-class scores and page preview.
        with gr.Column():
            label_output = gr.Textbox(label="Document Type", lines=2)
            confidence_output = gr.Label(label="Confidence Scores", num_top_classes=5)
            image_output = gr.Image(label="Document Preview", type="pil")

    gr.Markdown("### Examples")
    gr.Markdown("Upload your document using the file uploader above.")

    # The button click and a change of the uploaded file run the same
    # handler with the same components, so register both in one place.
    for register_event in (classify_btn.click, file_input.change):
        register_event(
            fn=classify_upload,
            inputs=[file_input],
            outputs=[label_output, confidence_output, image_output],
        )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ torchvision
5
+ Pillow
6
+ pdf2image
7
+ pytesseract