DeepDiveDev committed
Commit d1bb7e2 · verified · 1 Parent(s): 30abd6a

Update app.py

Files changed (1):
  1. app.py +12 -12
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoProcessor, AutoModelForVision2Seq
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 from PIL import Image
 import numpy as np
 import torch
@@ -8,23 +8,23 @@ import torch
 processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
 model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")
 
-# Load the fallback model (allenai/olmOCR-7B-0225-preview)
-model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+# Load the fallback model (microsoft/trocr-base-handwritten)
 processor2 = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
-
+model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
 
 # Function to extract text using both models
 def extract_text(image):
     try:
-        # Convert input to PIL Image
-        if isinstance(image, np.ndarray):
+        # Ensure the input is a PIL image
+        if isinstance(image, np.ndarray):
+            if len(image.shape) == 2:  # Grayscale (H, W), convert to RGB
+                image = np.stack([image] * 3, axis=-1)
             image = Image.fromarray(image)
         else:
-            image = Image.open(image).convert("RGB")
+            image = Image.open(image).convert("RGB")  # Ensure RGB mode
 
-        # Preprocessing
-        image = image.convert("L")  # Convert to grayscale for better OCR
-        image = image.resize((640, 640))  # Resize to improve accuracy
+        # Resize for better accuracy
+        image = image.resize((640, 640))
 
         # Process with the primary model
         pixel_values = processor1(images=image, return_tensors="pt").pixel_values
@@ -32,7 +32,7 @@ def extract_text(image):
         extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
         # If output seems incorrect, use the fallback model
-        if len(extracted_text.strip()) < 2:  # If output is too short, retry with second model
+        if len(extracted_text.strip()) < 2:
             inputs = processor2(images=image, return_tensors="pt").pixel_values
             generated_ids = model2.generate(inputs)
             extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -51,4 +51,4 @@ iface = gr.Interface(
     description="Upload a handwritten document and get the extracted text.",
 )
 
-iface.launch()
+iface.launch()
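
For reference, a minimal standalone sketch (not part of the commit) of the grayscale handling the updated extract_text now performs: a 2-D numpy array is stacked into three channels before Image.fromarray, so the TrOCR processors always receive an RGB image. The array below is synthetic and only illustrates the shape handling.

import numpy as np
from PIL import Image

# Synthetic stand-in for a grayscale page: shape (H, W), dtype uint8
gray = (np.random.rand(64, 64) * 255).astype(np.uint8)

if len(gray.shape) == 2:                 # same check as in the new app.py
    rgb = np.stack([gray] * 3, axis=-1)  # (H, W) -> (H, W, 3)
else:
    rgb = gray

image = Image.fromarray(rgb).resize((640, 640))
print(image.mode, image.size)            # expected: RGB (640, 640)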