intuitive262 committed on
Commit
10c178b
·
1 Parent(s): 3bc9acc

code files

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -1,17 +1,15 @@
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
- import cv2
5
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
6
- from huggingface_hub import hf_hub_download
7
  import torch
 
8
  import re
9
 
10
- # Download and load the GOT OCR model
11
- got_model_path = hf_hub_download(repo_id="junyeopkim/got_2.0_torch_script", filename="got_2.0_tiny.torchscript")
12
- got_model = torch.jit.load(got_model_path)
13
 
14
- # Load the Surya-OCR model
15
  surya_processor = TrOCRProcessor.from_pretrained("suryavarmaaddala/suryaocr")
16
  surya_model = VisionEncoderDecoderModel.from_pretrained("suryavarmaaddala/suryaocr")
17
 
@@ -19,19 +17,17 @@ def preprocess_image(image):
19
  if isinstance(image, str):
20
  image = Image.open(image).convert("RGB")
21
  elif isinstance(image, np.ndarray):
22
- image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
23
  return image
24
 
25
- def got_ocr(image):
26
  image = preprocess_image(image)
27
- image = image.resize((224, 224))
28
- input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
29
- input_tensor = input_tensor.unsqueeze(0)
30
 
31
- with torch.no_grad():
32
- output = got_model(input_tensor)
33
 
34
- return output[0].item()
35
 
36
  def surya_ocr(image):
37
  image = preprocess_image(image)
@@ -57,10 +53,10 @@ def search_text(text, query):
57
 
58
  def process_and_search(image, search_query):
59
  try:
60
- got_score = got_ocr(image)
61
  surya_text = surya_ocr(image)
62
 
63
- result = f"GOT OCR Score: {got_score:.4f}\n\nExtracted Text:\n{surya_text}"
64
  processed_text = post_process_text(result)
65
 
66
  search = None
 
1
  import gradio as gr
2
  import numpy as np
3
  from PIL import Image
 
 
 
4
  import torch
5
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
6
  import re
7
 
8
+ # Load the first OCR model (Microsoft's TrOCR)
9
+ ms_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
+ ms_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
 
12
+ # Load the second OCR model (Surya-OCR)
13
  surya_processor = TrOCRProcessor.from_pretrained("suryavarmaaddala/suryaocr")
14
  surya_model = VisionEncoderDecoderModel.from_pretrained("suryavarmaaddala/suryaocr")
15
 
 
17
  if isinstance(image, str):
18
  image = Image.open(image).convert("RGB")
19
  elif isinstance(image, np.ndarray):
20
+ image = Image.fromarray(image).convert("RGB")
21
  return image
22
 
23
+ def microsoft_ocr(image):
24
  image = preprocess_image(image)
25
+ pixel_values = ms_processor(image, return_tensors="pt").pixel_values
 
 
26
 
27
+ generated_ids = ms_model.generate(pixel_values)
28
+ generated_text = ms_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
29
 
30
+ return generated_text
31
 
32
  def surya_ocr(image):
33
  image = preprocess_image(image)
 
53
 
54
  def process_and_search(image, search_query):
55
  try:
56
+ ms_text = microsoft_ocr(image)
57
  surya_text = surya_ocr(image)
58
 
59
+ result = f"Microsoft OCR Result:\n{ms_text}\n\nSurya OCR Result:\n{surya_text}"
60
  processed_text = post_process_text(result)
61
 
62
  search = None