# FactureOCR / utils.py
# Soufiane
# added table extraction
# 8a6a4ae
import ast
import functools
import json
import os
import re

import cv2
import easyocr
import google.generativeai as genai
# Supported language names -> EasyOCR language codes (looked up by extract_data).
langs = {'french': 'fr', 'english': 'en', 'arabic': 'ar'}

# SECURITY: a Gemini API key used to be hard-coded here and is therefore in the
# version-control history; that key must be revoked. The key is now read from
# the GOOGLE_API_KEY environment variable. If the variable is unset,
# ``api_key=None`` is passed to ``genai.configure``.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Shared Gemini model used by get_output().
model = genai.GenerativeModel('gemini-pro')
def get_ocr(image, reader):
    """Run ``reader.readtext`` on *image* and hand back its detections untouched.

    Args:
        image: the image (or crop) to read; passed straight to the reader.
        reader: any object exposing ``readtext(image)``, e.g. an
            ``easyocr.Reader``.

    Returns:
        Exactly what ``reader.readtext(image)`` returns.
    """
    return reader.readtext(image)
def get_input(result, sep=" ", *, row_tolerance=5, min_prob=0.3):
    """Join OCR detections into lines of text, grouping boxes by height.

    Detections are ``(bbox, text, prob)`` triples. Those with
    ``prob < min_prob`` are ignored. A detection joins an existing line when
    the y coordinate of its first bbox point (``bbox[0][1]``) is within
    ``row_tolerance`` of the y that started that line. If several lines
    qualify, the closest one wins, and ties go to the earliest line.
    Otherwise the detection starts a new line.

    Lines appear in the order they were started. Words within a line appear
    in detection order and are joined with *sep*. Every line, including the
    last, ends with a newline.

    Args:
        result: iterable of ``(bbox, text, prob)`` detections.
        sep: separator placed between words on the same line.
        row_tolerance: max vertical distance (same units as the bbox) for a
            box to join an existing line.
        min_prob: detections below this confidence are dropped.

    Returns:
        The assembled text ("" when nothing survives filtering).
    """
    rows = {}  # y of the box that started the line -> list of its words
    for bbox, text, prob in result:
        if prob < min_prob:
            continue
        y = bbox[0][1]
        matches = [row_y for row_y in rows if abs(y - row_y) <= row_tolerance]
        if matches:
            # Attach to exactly one line; appending to every match would
            # duplicate the word across neighbouring lines.
            closest = min(matches, key=lambda row_y: abs(y - row_y))
            rows[closest].append(text)
        else:
            rows[y] = [text]
    return "".join(sep.join(words) + "\n" for words in rows.values())
def imp_ocr(result, image, reader, *, drop_below=0.1, retry_below=0.9,
            v_exp_rate=0.2, h_exp_rate=0.3):
    """Improve low-confidence OCR detections by re-reading an enlarged crop.

    Each detection is a ``(bbox, text, prob)`` triple.

    * ``prob < drop_below``: the detection is discarded.
    * ``prob >= retry_below``: the detection is kept unchanged.
    * Otherwise the box is enlarged and OCR is run again on that region.
      The width grows by ``h_exp_rate`` and the height by ``v_exp_rate``
      (fractions of the box size), split evenly between both sides and
      clipped to the image. The most confident re-read replaces the text and
      prob only if it is at least as confident as the original. The original
      bbox is kept either way.

    Args:
        result: detections to improve.
        image: image the detections came from. It is indexed as
            ``image[rows, cols]``, and ``image.shape[0]`` / ``image.shape[1]``
            are used as height / width.
        reader: OCR reader passed through to ``get_ocr``.
        drop_below, retry_below, v_exp_rate, h_exp_rate: tuning knobs whose
            defaults are the previously hard-coded values.

    Returns:
        A new list of ``(bbox, text, prob)`` detections.
    """
    img_h, img_w = image.shape[0], image.shape[1]
    improved = []
    for detection in result:
        bbox, prob = detection[0], detection[2]
        if prob < drop_below:
            continue  # too unreliable to keep at all
        if prob >= retry_below:
            improved.append(detection)
            continue

        # Axis-aligned bounds of the box's points. For an upright box this is
        # bbox[0] (top-left) and bbox[2] (bottom-right). For a rotated
        # quadrilateral those two corners need not be the extremes.
        xs = [pt[0] for pt in bbox]
        ys = [pt[1] for pt in bbox]
        x0, x1 = min(xs), max(xs)
        y0, y1 = min(ys), max(ys)
        pad_x = (x1 - x0) * h_exp_rate / 2
        pad_y = (y1 - y0) * v_exp_rate / 2
        left = max(int(x0 - pad_x), 0)
        right = min(int(x1 + pad_x), img_w)
        top = max(int(y0 - pad_y), 0)
        bottom = min(int(y1 + pad_y), img_h)
        if right <= left or bottom <= top:
            # Empty crop (zero-size box or box outside the image): nothing to re-read.
            improved.append(detection)
            continue

        rereads = get_ocr(image[top:bottom, left:right], reader)
        if not rereads:
            improved.append(detection)
            continue
        # max() keeps the first of several equally confident re-reads.
        best = max(rereads, key=lambda r: r[2])
        if best[2] >= prob:
            improved.append((bbox, best[1], best[2]))
        else:
            improved.append(detection)
    return improved
@functools.lru_cache(maxsize=None)  # keys are limited to the codes in `langs`
def _get_reader(lang_code):
    """Return an ``easyocr.Reader`` for *lang_code*, building it only once."""
    return easyocr.Reader([lang_code])


def extract_data(lang, to_be_extracted, image):
    """Run the full OCR -> text -> LLM extraction pipeline on one image.

    Constructing an ``easyocr.Reader`` loads its models, so one reader per
    language code is cached and reused across calls instead of being rebuilt
    every time.

    Args:
        lang: a key of ``langs`` ('french', 'english', 'arabic'). Unknown
            names raise ``KeyError``.
        to_be_extracted: the keys to ask the model for (inserted into the
            prompt by ``get_output``).
        image: image to read. It is sliced by row/column in ``imp_ocr``,
            presumably a numpy image array (TODO confirm against callers).

    Returns:
        The result of ``get_output`` for the OCR text.
    """
    reader = _get_reader(langs[lang])
    ocr_result = get_ocr(image, reader)
    imp_result = imp_ocr(ocr_result, image, reader)
    inputt = get_input(imp_result, sep=" ")
    return get_output(inputt, to_be_extracted, lang)
def get_output(inputt, to_be_extracted, lang):
    """Have the Gemini model turn raw OCR text into structured fields.

    The prompt asks the model to answer with a Python dictionary whose keys
    are *to_be_extracted*. It also asks the model to normalise dates to
    dd/mm/yyyy and fix spelling. The model's reply text is handed to
    ``extract_json``.

    Args:
        inputt: the OCR text, presumably produced by ``get_input``.
        to_be_extracted: requested keys, interpolated verbatim into the prompt.
        lang: language name, interpolated verbatim into the prompt.

    Returns:
        Whatever ``extract_json`` yields for the reply.
    """
    prompt = f"""
Bellow is the ouptut text of an OCR system. The text is in {lang}.
Your job is:
1. Format the output in a python dictionary format with the following keys :
{to_be_extracted}.
2. correct any mistakes such as date formating (should be dd/mm/yyyy) or spelling mistakes (this is important).
here is your input:
{inputt}
"""
    reply_text = model.generate_content(prompt).text
    return extract_json(reply_text)
def extract_json(text_response):
    """Return the first dictionary literal found in an LLM reply, or ``{}``.

    Each ``{`` in *text_response* is tried in order as the start of a
    candidate. The balanced ``{...}`` span beginning there (see
    ``extend_search``) is parsed first as JSON and then as a Python literal.
    The first candidate that parses to a ``dict`` is returned, so a nested
    object comes back whole rather than as its innermost ``{...}``.

    Only ``json.loads`` and ``ast.literal_eval`` are used. The reply is model
    output derived from OCR of user-supplied documents, so it is untrusted
    and must never reach ``eval``.

    Args:
        text_response: raw text of the model's reply (may be empty or None).

    Returns:
        The parsed ``dict``, or ``{}`` if no candidate parses to one.
    """
    if not text_response:
        return {}
    start = text_response.find("{")
    while start != -1:
        candidate = extend_search(text_response, (start, start + 1))
        parsed = _parse_dict_literal(candidate)
        if parsed is not None:
            return parsed
        start = text_response.find("{", start + 1)
    return {}


def _parse_dict_literal(candidate):
    """Parse *candidate* as a JSON object or a Python dict literal, else None."""
    try:
        obj = json.loads(candidate)  # accepts true/false/null
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        try:
            obj = ast.literal_eval(candidate)  # accepts 'single quotes', True/None
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
    return obj if isinstance(obj, dict) else None


def extend_search(text, span):
    """Return the balanced ``{...}`` substring of *text* starting at ``span[0]``.

    Braces are counted from ``span[0]`` onward. Braces inside single- or
    double-quoted strings (honouring backslash escapes) are ignored, so a
    value like ``'use } here'`` does not end the object early. If the braces
    never balance, ``text[span[0]:span[1]]`` is returned unchanged.
    """
    start, end = span
    depth = 0
    quote = None  # quote character of the string currently being skipped
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch in "\"'":
            quote = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return text[start:end]