imperiusrex commited on
Commit
d0a0585
·
verified ·
1 Parent(s): 40c9220

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Setup ---
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
+ import cv2
8
+ from paddleocr import TextDetection
9
+ from huggingface_hub import spaces
10
+ import time
11
+
12
+ # Request H200 GPU
13
+ spaces.GPU.require("H200")
14
+
15
+ # --- Model Load ---
16
+ MODEL_HUB_ID = "imperiusrex/Handwritten_model"
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
20
+ model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
21
+ model.to(device)
22
+ model.eval()
23
+ ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")
24
+
25
+ # --- Core OCR Function ---
26
+ def recognize_handwritten_text_from_npimg(np_img):
27
+ pil_img = Image.fromarray(np_img.astype(np.uint8)).convert("RGB")
28
+ image_np = np.array(pil_img)
29
+ detection_results = ocr_det_model.predict(image_np, batch_size=1)
30
+
31
+ detected_polys = []
32
+ for res in detection_results:
33
+ polys = res.get('dt_polys', [])
34
+ if polys is not None:
35
+ detected_polys.extend(polys.tolist())
36
+
37
+ cropped_images = []
38
+ if detected_polys:
39
+ for box in detected_polys:
40
+ box = np.array(box, dtype=np.float32)
41
+ width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
42
+ height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
43
+ dst_rect = np.array([
44
+ [0, 0],
45
+ [width - 1, 0],
46
+ [width - 1, height - 1],
47
+ [0, height - 1]
48
+ ], dtype=np.float32)
49
+ M = cv2.getPerspectiveTransform(box, dst_rect)
50
+ warped = cv2.warpPerspective(image_np, M, (width, height))
51
+ cropped_images.append(Image.fromarray(warped).convert("RGB"))
52
+ cropped_images.reverse()
53
+
54
+ recognized_texts = []
55
+ if cropped_images:
56
+ for crop_img in cropped_images:
57
+ pixel_values = processor(images=crop_img, return_tensors="pt").pixel_values.to(device)
58
+ with torch.no_grad():
59
+ generated_ids = model.generate(pixel_values, max_new_tokens=64)
60
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
61
+ recognized_texts.append(generated_text)
62
+ else:
63
+ pixel_values = processor(images=pil_img, return_tensors="pt").pixel_values.to(device)
64
+ with torch.no_grad():
65
+ generated_ids = model.generate(pixel_values, max_new_tokens=64)
66
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
67
+ recognized_texts.append("No text boxes detected. Full image OCR:\n" + generated_text)
68
+
69
+ return "\n".join(recognized_texts)
70
+
71
+
72
+ # --- Interface Function ---
73
+ def ocr_from_canvas(img):
74
+ if img is None:
75
+ return "Draw something to see OCR output."
76
+ np_img = np.array(img)
77
+ try:
78
+ result = recognize_handwritten_text_from_npimg(np_img)
79
+ except Exception as e:
80
+ result = f"[OCR error: {e}]"
81
+ return result
82
+
83
+
84
+ # --- UI Layout ---
85
+ with gr.Blocks(css=".gr-textbox textarea { font-family: monospace; font-size: 16px; }") as demo:
86
+ gr.Markdown("<h1>📝 Real-Time Handwriting OCR Canvas</h1>")
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ canvas = gr.ImageEditor(
91
+ label="Draw here (freehand, line, shapes)",
92
+ type="numpy",
93
+ tool="freedraw",
94
+ width=600,
95
+ height=400,
96
+ brush=gr.Brush(color="#000000", size=3),
97
+ background="#FFFFFF"
98
+ )
99
+ gr.Markdown(
100
+ """
101
+ - Use the canvas tools to draw freely, lines, rectangles, etc.
102
+ - You can adjust stroke width, brush color, and background color.
103
+ - The OCR will trigger every 4 seconds or when you draw.
104
+ """
105
+ )
106
+
107
+ with gr.Column():
108
+ output_text = gr.Textbox(
109
+ label="🧠 OCR Output",
110
+ lines=12,
111
+ max_lines=20,
112
+ interactive=False,
113
+ )
114
+
115
+ # Trigger OCR on change
116
+ canvas.change(fn=ocr_from_canvas, inputs=canvas, outputs=output_text)
117
+
118
+ demo.launch()