Sathiyapramod committed on
Commit 285d260 · verified · 1 Parent(s): 881a1b7

Update app.py

Files changed (1)
  1. app.py +58 -106
app.py CHANGED
@@ -1,125 +1,77 @@
  import gradio as gr
  from PIL import Image
- import numpy as np
- import cv2
  import torch
-
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel

  # =========================
- # Model Loader (cached)
  # =========================
- processor = None
- model = None
  device = "cuda" if torch.cuda.is_available() else "cpu"

- def load_model():
-     global processor, model
-
-     if processor is None or model is None:
-         processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
-         model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
-         model.to(device)
-
-
- # =========================
- # Line Segmentation Logic
- # =========================
- def segment_lines(image: Image.Image):
-     """
-     Splits image into individual text lines using horizontal projection
-     """
-
-     # Convert to grayscale
-     gray = np.array(image.convert("L"))
-
-     # Apply thresholding
-     _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)
-
-     # Horizontal projection
-     horizontal_sum = np.sum(thresh, axis=1)
-
-     lines = []
-     start = None
-
-     for i, val in enumerate(horizontal_sum):
-         if val > 0 and start is None:
-             start = i
-         elif val == 0 and start is not None:
-             end = i
-             lines.append((start, end))
-             start = None
-
-     # Edge case: last line
-     if start is not None:
-         lines.append((start, len(horizontal_sum)))
-
-     # Crop line images
-     line_images = []
-     for (s, e) in lines:
-         # Add small padding
-         top = max(0, s - 5)
-         bottom = min(image.height, e + 5)
-
-         cropped = image.crop((0, top, image.width, bottom))
-
-         # Skip very small/noise regions
-         if bottom - top > 10:
-             line_images.append(cropped)
-
-     return line_images
-
-
- # =========================
- # OCR Prediction
- # =========================
- def predict(image):
-     load_model()

      if image is None:
          return "⚠️ Please upload an image."

-     try:
-         # Segment into lines
-         lines = segment_lines(image)
-
-         if not lines:
-             return "⚠️ No text detected. Try a clearer image."
-
-         results = []
-
-         for line_img in lines:
-             pixel_values = processor(
-                 images=line_img,
-                 return_tensors="pt"
-             ).pixel_values.to(device)
-
-             generated_ids = model.generate(pixel_values)
-             text = processor.batch_decode(
-                 generated_ids,
-                 skip_special_tokens=True
-             )[0]
-
-             results.append(text)
-
-         final_text = "\n".join(results)
-
-         return final_text if final_text.strip() else "⚠️ Could not extract text."
-
-     except Exception as e:
-         return f" Error occurred: {str(e)}"
-

  # =========================
  # Gradio UI
  # =========================
- demo = gr.Interface(
-     fn=predict,
-     inputs=gr.Image(type="pil", label="Upload Handwritten Image"),
-     outputs=gr.Textbox(label="Extracted Text"),
-     title="📝 Handwritten OCR (Multi-line)",
-     description="Upload a handwritten note image. The model will extract text line by line.",
- )

  if __name__ == "__main__":
      demo.launch()
 
  import gradio as gr
  from PIL import Image
  import torch
+ from transformers import AutoProcessor, AutoModelForCausalLM

  # =========================
+ # Model Setup
  # =========================
+ # Florence-2 is much more robust for full-page handwriting than TrOCR
+ model_id = 'microsoft/Florence-2-large'
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load model and processor with trust_remote_code=True for the Florence architecture
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

+ def run_ocr(image):
      if image is None:
          return "⚠️ Please upload an image."

+     # Florence-2 uses specific task prompts.
+     # <OCR_WITH_REGION> is best for messy handwriting and preserving layout.
+     prompt = "<OCR_WITH_REGION>"
+
+     # Preprocess image
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+
+     # Generate text
+     with torch.no_grad():
+         generated_ids = model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             do_sample=False,
+             num_beams=3
+         )
+
+     # Decode with special tokens kept: post_process_generation parses the <loc_*> region tokens
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+
+     # Post-process to clean up the Florence-specific tags
+     parsed_answer = processor.post_process_generation(
+         generated_text,
+         task=prompt,
+         image_size=(image.width, image.height)
+     )
+
+     # Extract the plain text from the parsed dictionary
+     result = parsed_answer.get(prompt, "Could not parse text.")
+
+     # If the result is a dict (region-based), extract just the labels/text
+     if isinstance(result, dict) and 'labels' in result:
+         return "\n".join(result['labels'])
+
+     return str(result)

  # =========================
  # Gradio UI
  # =========================
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🖋️ Advanced Handwritten Note Extractor")
+     gr.Markdown("Using **Florence-2-Large** for contextual OCR. Better for full letters and messy notes.")
+
+     with gr.Row():
+         input_img = gr.Image(type="pil", label="Upload Handwritten Letter")
+         output_text = gr.Textbox(label="Extracted Text", lines=15)
+
+     btn = gr.Button("Extract Text", variant="primary")
+     btn.click(fn=run_ocr, inputs=input_img, outputs=output_text)
+
+     gr.Examples(
+         examples=[],  # You can add paths to example images here
+         inputs=input_img
+     )

  if __name__ == "__main__":
      demo.launch()
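
For a quick sanity check of the new Florence-2 path outside the Gradio UI, a minimal sketch along these lines should work (importing from app and the "sample_note.png" path are illustrative assumptions, not part of this commit):

# Hypothetical smoke test for run_ocr; "sample_note.png" is a placeholder path.
from PIL import Image
from app import run_ocr

img = Image.open("sample_note.png").convert("RGB")
print(run_ocr(img))

For the <OCR_WITH_REGION> task, processor.post_process_generation returns a dict shaped like {'<OCR_WITH_REGION>': {'quad_boxes': [...], 'labels': [...]}}, which is why run_ocr joins the 'labels' entries into newline-separated text; the plain <OCR> task instead yields a single string, which falls through to the final str(result) return.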