Sathiyapramod commited on
Commit
bb94232
·
verified ·
1 Parent(s): 427e409

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -1,31 +1,38 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
4
- from transformers import AutoProcessor, AutoModelForCausalLM
5
 
6
  # =========================
7
- # Model Setup
8
  # =========================
9
- # Florence-2 is much more robust for full-page handwriting than TrOCR
10
  model_id = 'microsoft/Florence-2-large'
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Load model and processor with trust_remote_code=True for Florence architecture
14
- model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
 
 
 
 
 
 
 
 
 
 
15
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
16
 
17
  def run_ocr(image):
18
  if image is None:
19
  return "⚠️ Please upload an image."
20
 
21
- # Florence-2 uses specific task prompts.
22
- # <OCR_WITH_REGION> is best for messy handwriting and preserving layout.
23
- prompt = "<OCR_WITH_REGION>"
24
 
25
- # Preprocess image
26
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
27
 
28
- # Generate text
29
  with torch.no_grad():
30
  generated_ids = model.generate(
31
  input_ids=inputs["input_ids"],
@@ -35,43 +42,29 @@ def run_ocr(image):
35
  num_beams=3
36
  )
37
 
38
- # Decode result
39
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
40
 
41
- # Post-process to clean up the Florence-specific tags
42
  parsed_answer = processor.post_process_generation(
43
  generated_text,
44
  task=prompt,
45
  image_size=(image.width, image.height)
46
  )
47
 
48
- # Extract the plain text from the parsed dictionary
49
- result = parsed_answer.get(prompt, "Could not parse text.")
50
-
51
- # If the result is a dict (region based), we extract just the labels/text
52
- if isinstance(result, dict) and 'labels' in result:
53
- return "\n".join(result['labels'])
54
-
55
- return str(result)
56
 
57
  # =========================
58
  # Gradio UI
59
  # =========================
60
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
61
- gr.Markdown("# 🖋️ Advanced Handwritten Note Extractor")
62
- gr.Markdown("Using **Florence-2-Large** for contextual OCR. Better for full letters and messy notes.")
63
 
64
  with gr.Row():
65
- input_img = gr.Image(type="pil", label="Upload Handwritten Letter")
66
- output_text = gr.Textbox(label="Extracted Text", lines=15)
67
 
68
- btn = gr.Button("Extract Text", variant="primary")
69
  btn.click(fn=run_ocr, inputs=input_img, outputs=output_text)
70
 
71
- gr.Examples(
72
- examples=[], # You can add paths to example images here
73
- inputs=input_img
74
- )
75
-
76
  if __name__ == "__main__":
77
  demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import torch
4
+ from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
5
 
6
  # =========================
7
+ # Model Setup & Patch
8
  # =========================
 
9
  model_id = 'microsoft/Florence-2-large'
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # PATCH: Explicitly handle the Florence2 configuration bug
13
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
14
+ if not hasattr(config, 'forced_bos_token_id'):
15
+ config.forced_bos_token_id = None
16
+
17
+ # Load model and processor
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ model_id,
20
+ config=config,
21
+ trust_remote_code=True
22
+ ).to(device).eval()
23
+
24
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
25
 
26
  def run_ocr(image):
27
  if image is None:
28
  return "⚠️ Please upload an image."
29
 
30
+ # Using <DETAILED_CAPTION> or <OCR> task for better text flow
31
+ # Florence-2 works best with these specific task tags
32
+ prompt = "<OCR>"
33
 
 
34
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
35
 
 
36
  with torch.no_grad():
37
  generated_ids = model.generate(
38
  input_ids=inputs["input_ids"],
 
42
  num_beams=3
43
  )
44
 
 
45
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
46
 
47
+ # Clean up the output
48
  parsed_answer = processor.post_process_generation(
49
  generated_text,
50
  task=prompt,
51
  image_size=(image.width, image.height)
52
  )
53
 
54
+ return parsed_answer[prompt]
 
 
 
 
 
 
 
55
 
56
  # =========================
57
  # Gradio UI
58
  # =========================
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("## 🖋️ Handwritten Note to Text (Florence-2)")
 
61
 
62
  with gr.Row():
63
+ input_img = gr.Image(type="pil")
64
+ output_text = gr.Textbox(label="Extracted Text", lines=10)
65
 
66
+ btn = gr.Button("Convert to Text", variant="primary")
67
  btn.click(fn=run_ocr, inputs=input_img, outputs=output_text)
68
 
 
 
 
 
 
69
  if __name__ == "__main__":
70
  demo.launch()