File size: 3,305 Bytes
b419737
 
 
 
 
 
 
c1deaa6
c160de9
ab94877
 
c1deaa6
 
 
b419737
 
c1deaa6
 
ab94877
b419737
 
 
c1deaa6
 
 
b419737
 
 
 
 
 
c1deaa6
b419737
 
 
ab94877
b419737
 
 
 
 
ab94877
b419737
ab94877
b419737
 
ab94877
 
b419737
 
 
 
 
 
 
 
 
 
ab94877
b419737
ab94877
b419737
 
 
 
 
 
 
 
 
 
 
 
 
 
ab94877
b419737
 
 
ab94877
 
 
 
 
 
e9605f9
ab94877
 
 
e9605f9
ab94877
c1deaa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import cv2
from paddleocr import TextDetection
from spaces import GPU  # ✅ Required for ZeroGPU

# Hub repository id from which both the TrOCR processor and the
# encoder-decoder weights are loaded below.
MODEL_HUB_ID = "imperiusrex/Handwritten_model"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("🔄 Loading models...")

# Recognition stage: image processor/tokenizer pair plus the seq2seq model.
processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
model.to(device)
model.eval()  # evaluation mode: this app only runs inference

# Detection stage: locates text regions that are then cropped and recognized.
ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")

print("✅ Models loaded successfully.")

def _collect_polygons(detection_results):
    """Gather every polygon stored under 'dt_polys' across the detection results."""
    polygons = []
    for res in detection_results:
        polys = res.get('dt_polys')
        if polys is None:
            continue
        # Arrays are converted to nested lists; list payloads are taken as-is
        # because a plain list has no .tolist().
        polygons.extend(polys.tolist() if isinstance(polys, np.ndarray) else polys)
    return polygons


def _warp_quad(img_np, box):
    """Return the perspective-rectified crop of one 4-point box, or None if unusable."""
    quad = np.asarray(box, dtype=np.float32)
    if quad.shape != (4, 2):
        # cv2.getPerspectiveTransform needs exactly four (x, y) corners.
        return None

    # Target size = longest of each pair of opposite sides.
    # NOTE(review): assumes corners are ordered TL, TR, BR, BL — confirm
    # against the detector's output convention.
    width = int(max(np.linalg.norm(quad[0] - quad[1]), np.linalg.norm(quad[2] - quad[3])))
    height = int(max(np.linalg.norm(quad[0] - quad[3]), np.linalg.norm(quad[1] - quad[2])))
    if width < 2 or height < 2:
        # Degenerate box: warping to an empty or 1-pixel target would fail
        # or produce a crop with nothing to recognize.
        return None

    dst_rect = np.array(
        [
            [0, 0],
            [width - 1, 0],
            [width - 1, height - 1],
            [0, height - 1],
        ],
        dtype=np.float32,
    )
    transform = cv2.getPerspectiveTransform(quad, dst_rect)
    warped = cv2.warpPerspective(img_np, transform, (width, height))
    return Image.fromarray(warped).convert("RGB")


def _ocr_image(image):
    """Run the TrOCR processor and model on one PIL image and return the decoded text."""
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_new_tokens=64)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


@GPU  # Marks this function as needing a GPU when running on ZeroGPU
def recognize_handwritten_text(image_input):
    """Detect text regions in an image and recognize the handwriting in each.

    Args:
        image_input: Image as a numpy array (the Gradio input uses
            ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        One recognized string per detected region, joined with newlines.
        If no usable region is found, the whole image is recognized instead
        and the result is prefixed with a notice. If ``image_input`` is
        ``None``, a prompt asking for an upload is returned.
    """
    if image_input is None:
        return "Please upload an image."

    # Normalise to 3-channel RGB once and use the same pixels for both
    # detection and cropping.
    image_pil = Image.fromarray(image_input).convert("RGB")
    img_np = np.array(image_pil)

    # NOTE(review): the array is passed in RGB channel order; confirm whether
    # the detector expects BGR for ndarray input.
    detection_results = ocr_det_model.predict(img_np, batch_size=1)

    cropped_images = []
    for box in _collect_polygons(detection_results):
        crop = _warp_quad(img_np, box)
        if crop is not None:
            cropped_images.append(crop)

    # Kept from the original: regions are processed in reverse detection
    # order (presumably the detector lists boxes bottom-to-top — TODO confirm).
    cropped_images.reverse()

    if not cropped_images:
        return "No text boxes detected. Full image OCR:\n" + _ocr_image(image_pil)

    return "\n".join(_ocr_image(crop) for crop in cropped_images)

# --- Gradio Interface ---
def build_interface():
    """Create the Gradio interface wired to ``recognize_handwritten_text``.

    Returns:
        A ``gr.Interface`` that takes one uploaded image (delivered to the
        callback as a numpy array) and shows the recognized text.
    """
    return gr.Interface(
        fn=recognize_handwritten_text,
        inputs=gr.Image(type="numpy", label="Upload Handwritten Image"),
        outputs="text",
        title="✍️ Handwritten Text Recognition",
        # The camera emoji was previously stored as mis-decoded bytes and
        # rendered as garbage in the UI.
        description="📷 Upload a handwritten image. Uses PaddleOCR (detection) + TrOCR (recognition).",
    )

# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    iface = build_interface()
    iface.launch()