import os
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

from hezar.models import Model
from hezar.utils import load_image, draw_boxes
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Load models on CPU (the default hardware on Hugging Face Spaces)
craft_model = Model.load("hezarai/CRAFT", device="cpu")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
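
# Note: device="cpu" matches free Spaces hardware. As a sketch, on a GPU Space
# (assuming CUDA is available) you could load CRAFT with device="cuda", move
# TrOCR with trocr_model.to("cuda"), and send pixel_values to the same device
# before calling generate().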

def recognize_handwritten_text(image):
    tmp_path = None
    try:
        # Normalize the input to an RGB PIL image regardless of upload format
        # (JPEG cannot store an alpha channel, so always convert before saving)
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        image = image.convert("RGB")
        # Save the upload to a temporary JPEG so hezar can load it from a path
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            image.save(tmp_file.name, format="JPEG")
            tmp_path = tmp_file.name
        # Load the image with hezar's utility from the file path
        processed_image = load_image(tmp_path)
        # Fall back to NumPy if hezar returns an unexpected type
        if not isinstance(processed_image, np.ndarray):
            processed_image = np.array(Image.open(tmp_path))
        # Detect text regions with CRAFT
        outputs = craft_model.predict(processed_image)
        if not outputs or "boxes" not in outputs[0]:
            return Image.fromarray(processed_image), "No text detected"
        boxes = outputs[0]["boxes"]
        print(f"Debug: Boxes structure = {boxes}")  # Log the exact structure
        pil_image = Image.fromarray(processed_image)
        texts = []
        # Normalize each detected box to (x_min, y_min, x_max, y_max). CRAFT
        # output formats vary: a box may be [x, y, width, height], a pair of
        # corner points, or a 4-point polygon, so handle all three.
        for box in boxes:
            box_arr = np.asarray(box)
            if box_arr.ndim == 1 and box_arr.size == 4:  # [x, y, width, height]
                x, y, width, height = box_arr
                x_min, y_min = x, y
                x_max, y_max = x + width, y + height
            elif box_arr.ndim == 2 and box_arr.shape[1] == 2:  # corner points or polygon
                x_min, y_min = box_arr.min(axis=0)
                x_max, y_max = box_arr.max(axis=0)
            else:
                print(f"Debug: Skipping invalid box {box}")  # Log invalid boxes
                continue
            # Crop the detected region and recognize it with TrOCR
            crop = pil_image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
            pixel_values = processor(images=crop, return_tensors="pt").pixel_values
            generated_ids = trocr_model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            texts.append(text)
        # Draw the detected boxes on the image
        result_image = draw_boxes(processed_image, boxes)
        result_pil = Image.fromarray(result_image)
        # Join the recognized texts
        text_data = " ".join(texts) if texts else "No text recognized"
        return result_pil, f"Recognized text: {text_data}"
    except Exception as e:
        return Image.fromarray(np.array(image)), f"Error: {e}"
    finally:
        # Clean up the temporary file if one was created
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
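
# Minimal local smoke test (hypothetical file name; not needed on Spaces,
# where the Gradio interface below drives the function):
#   img = Image.open("sample_handwriting.jpg")
#   annotated, text = recognize_handwritten_text(img)
#   print(text)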

# Create the Gradio interface
interface = gr.Interface(
    fn=recognize_handwritten_text,
    inputs=gr.Image(type="pil", label="Upload any image format"),
    outputs=[
        gr.Image(type="pil", label="Detected Text Image"),
        gr.Text(label="Recognized Text"),
    ],
    title="Handwritten Text Detection and Recognition",
    description="Upload an image in any format (JPEG, PNG, BMP, etc.) to detect and recognize handwritten text.",
)

# Launch the app
interface.launch()