import gradio as gr
import cv2
import numpy as np
from PIL import Image
import pickle
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
import easyocr
import torch
# ========== GPU Checks ==========
print("Torch GPU Available:", torch.cuda.is_available())
print("TensorFlow GPU Devices:", tf.config.list_physical_devices('GPU'))
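# Note: these checks are informational only; on a CPU-only host both report
# no GPU and the app still runs, with EasyOCR falling back to CPU below.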
# ========== Load Model and Label Encoder ==========
model_path = "MobileNetBest_Model.h5"
label_path = "MobileNet_Label_Encoder.pkl"
model = load_model(model_path)
print("✅ MobileNet model loaded.")
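# Assumption: the .h5 model expects a (1, 224, 224, 3) float batch scaled to
# [0, 1], which is how classify_text_region() prepares its input below.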
# Label encoder
try:
    with open(label_path, 'rb') as f:
        label_map = pickle.load(f)
    # The pickle is assumed to map label name -> index; invert it for lookups
    index_to_label = {v: k for k, v in label_map.items()}
    print("✅ Label encoder loaded:", index_to_label)
except Exception as e:
    # Fall back to the known two-class mapping if the encoder is missing or corrupt
    index_to_label = {0: "Handwritten", 1: "Computerized"}
    print(f"⚠️ Could not load label encoder ({e}); default labels used:", index_to_label)
# ========== Initialize EasyOCR (Force GPU) ==========
reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
print("✅ EasyOCR initialized with GPU:", torch.cuda.is_available())
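# reader.readtext() yields (bbox, text, confidence) triples, where bbox is a
# list of four corner points ordered clockwise from the top-left.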
# ========== Classify One Region ==========
def classify_text_region(region_img):
    try:
        # Match the model's expected input: 224x224, float32, scaled to [0, 1]
        region_img = cv2.resize(region_img, (224, 224))
        region_img = region_img.astype("float32") / 255.0
        region_img = img_to_array(region_img)
        region_img = np.expand_dims(region_img, axis=0)
        preds = model.predict(region_img)
        # A single-unit (sigmoid) head yields one score; otherwise argmax over classes
        if preds.shape[-1] == 1:
            return "Computerized" if preds[0][0] > 0.5 else "Handwritten"
        else:
            class_idx = int(np.argmax(preds[0]))
            return index_to_label.get(class_idx, "Unknown")
    except Exception as e:
        print("❌ Classification error:", e)
        return "Unknown"
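# Hypothetical usage sketch (not part of the app flow):
#   crop = cv2.imread("word_crop.png")   # any BGR text crop
#   classify_text_region(crop)           # -> "Handwritten", "Computerized", or "Unknown"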
# ========== OCR & Annotate ==========
def AnnotatedTextDetection_EasyOCR_from_array(img):
    results = reader.readtext(img)
    annotated_results = []
    for (bbox, text, conf) in results:
        # Skip low-confidence or empty detections
        if conf < 0.3 or text.strip() == "":
            continue
        # Use the top-left and bottom-right corners of EasyOCR's quadrilateral
        x1, y1 = map(int, bbox[0])
        x2, y2 = map(int, bbox[2])
        crop = img[y1:y2, x1:x2]
        if crop.size == 0:
            continue
        label = classify_text_region(crop)
        annotated_results.append(f"{text.strip()} → {label}")
        color = (0, 255, 0) if label == "Computerized" else (255, 0, 0)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), "\n".join(annotated_results)
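# The boxes and labels are drawn onto img in place; the return value is an
# RGB copy for display plus one "text → label" line per kept detection.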
# ========== Inference Function ==========
def infer(image):
    # Gradio supplies RGB; convert to BGR so OpenCV's drawing colors and the
    # final BGR2RGB conversion in the annotation step come out correct
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Downscale large images to keep OCR latency manageable
    max_dim = 1000
    if img.shape[0] > max_dim or img.shape[1] > max_dim:
        scale = max_dim / max(img.shape[0], img.shape[1])
        img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale)))
    annotated_img, result_text = AnnotatedTextDetection_EasyOCR_from_array(img)
    return Image.fromarray(annotated_img), result_text
# ========== Gradio UI ==========
with gr.Blocks(
    title="Text Type Classifier",
    css="""
    body {
        background-color: white !important;
        color: red !important;
    }
    h1, h2, h3, h4, h5, h6, label, .gr-box, .gr-button {
        color: red !important;
    }
    .outer-box {
        border: 8px solid black;
        border-radius: 16px;
        padding: 24px;
        background-color: white;
    }
    .gr-box {
        border: 6px solid #0288d1 !important;
        border-radius: 12px;
        padding: 16px;
        background-color: white;
        box-shadow: 0px 2px 10px rgba(0,0,0,0.1);
    }
    .gr-button {
        background-color: #0288d1 !important;
        color: white !important;
        font-weight: bold;
        border-radius: 8px;
        margin-top: 10px;
    }
    .gr-button:hover {
        background-color: #01579b !important;
    }
    """
) as demo:
    with gr.Column(elem_classes=["outer-box"]):
        gr.Markdown(
            """
            <div style="text-align: center;">
                <h1><strong>Handwritten vs Computerized Text Classifier</strong></h1>
            </div>
            """,
            elem_classes=["gr-box"]
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Image", type="numpy", elem_classes=["gr-box"])
                submit_btn = gr.Button("Process Image", elem_classes=["gr-box", "gr-button"])
                clear_btn = gr.Button("Clear", elem_classes=["gr-box", "gr-button"])
            with gr.Column():
                image_output = gr.Image(label="Annotated Output", type="numpy", elem_classes=["gr-box"])
                text_output = gr.Textbox(label="Detected Results", lines=10, elem_classes=["gr-box"])

        submit_btn.click(
            fn=infer,
            inputs=image_input,
            outputs=[image_output, text_output]
        )
        clear_btn.click(
            fn=lambda: (None, None, ""),
            inputs=[],
            outputs=[image_input, image_output, text_output]
        )

demo.launch()