# Spaces: Runtime error  (appears to be a hosting-page status banner captured with this file; not program code)
import ast
import json
import os
import re

import cv2
import easyocr
import google.generativeai as genai
# Maps the language names accepted by extract_data() to the language codes
# passed to easyocr.Reader.
langs = {'french': 'fr', 'english': 'en', 'arabic': 'ar'}

# SECURITY: the API key was previously hard-coded in this file. A key that has
# been committed or published must be treated as leaked and revoked; the key
# is now taken from the GOOGLE_API_KEY environment variable instead.
# When the variable is unset this is None and requests will fail to
# authenticate rather than silently using an embedded credential.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Shared generative model used by get_output() to structure the OCR text.
model = genai.GenerativeModel('gemini-pro')
def get_ocr(image, reader):
    """Run OCR on *image* and return the reader's detections unchanged.

    Args:
        image: Image to read, in whatever form ``reader.readtext`` accepts.
        reader: Object exposing a ``readtext(image)`` method.

    Returns:
        Exactly what ``reader.readtext(image)`` returns. Other functions in
        this file unpack each entry as ``(bbox, text, prob)``.
    """
    return reader.readtext(image)
def get_input(result, sep=" ", *, min_prob=0.3, row_tolerance=5):
    """Assemble OCR detections into text, one output line per visual row.

    Detections are grouped by the y coordinate of their first bounding-box
    corner (``bbox[0][1]``). A detection joins an existing row when its y is
    within ``row_tolerance`` of that row's key (the y of the row's first
    detection); otherwise it starts a new row.

    Args:
        result: Iterable of ``(bbox, text, prob)`` tuples.
        sep: String inserted between texts that share a row.
        min_prob: Detections with ``prob`` below this value are ignored.
        row_tolerance: Maximum y distance for a detection to join a row.

    Returns:
        The rows in the order they were first seen, each terminated by
        ``"\\n"``. An empty string when nothing passes the threshold.
    """
    rows = {}  # row key (y of the row's first detection) -> list of texts
    for bbox, text, prob in result:
        if prob < min_prob:
            continue
        top = bbox[0][1]
        nearby = [key for key in rows if abs(top - key) <= row_tolerance]
        if nearby:
            # Attach to the single closest row. Appending to every row within
            # tolerance duplicated the text whenever two row keys were both
            # in range of the same detection.
            closest = min(nearby, key=lambda key: abs(top - key))
            rows[closest].append(text)
        else:
            rows[top] = [text]
    return "".join(sep.join(parts) + "\n" for parts in rows.values())
def imp_ocr(result, image, reader):
    """Re-read uncertain detections on an enlarged crop and keep the better one.

    For each detection ``(bbox, text, prob)``:

    * ``prob < 0.1``: dropped.
    * ``prob >= 0.9``: kept unchanged.
    * otherwise: the box spanned by ``bbox[0]`` and ``bbox[2]`` is grown by
      30% horizontally and 20% vertically (half on each side), clamped to the
      image, and read again. If the most confident new reading is at least as
      confident as the original, its text and confidence replace the
      original's (the original ``bbox`` is kept); otherwise the original
      detection is kept.

    Args:
        result: Sequence of ``(bbox, text, prob)`` detections.
        image: Array-like with ``.shape`` as ``(height, width, ...)`` that
            supports 2-D slicing ``image[y0:y1, x0:x1]``.
        reader: Object exposing ``readtext(image)`` returning detections in
            the same ``(bbox, text, prob)`` form.

    Returns:
        A new list of detections, in input order minus the dropped ones.
    """
    v_exp_rate = 0.2  # total vertical growth, as a fraction of box height
    h_exp_rate = 0.3  # total horizontal growth, as a fraction of box width
    img_h, img_w = image.shape[0], image.shape[1]

    imp_result = []
    for detection in result:
        bbox = detection[0]
        prob = detection[2]
        if prob < 0.1:
            continue
        if prob >= 0.9:
            imp_result.append(detection)
            continue

        width = bbox[2][0] - bbox[0][0]
        height = bbox[2][1] - bbox[0][1]
        x = max(int(bbox[0][0] - width * h_exp_rate / 2), 0)
        x_h = min(int(bbox[2][0] + width * h_exp_rate / 2), img_w)
        y = max(int(bbox[0][1] - height * v_exp_rate / 2), 0)
        y_h = min(int(bbox[2][1] + height * v_exp_rate / 2), img_h)

        # A zero-size or inverted box yields an empty crop; there is nothing
        # to re-read, so keep the original instead of passing an empty array
        # to the OCR reader.
        if x_h <= x or y_h <= y:
            imp_result.append(detection)
            continue

        # Same call the file's get_ocr() helper makes.
        res = reader.readtext(image[y:y_h, x:x_h])
        if not res:
            imp_result.append(detection)
            continue

        best = max(res, key=lambda candidate: candidate[2])
        if best[2] >= prob:
            imp_result.append((bbox, best[1], best[2]))
        else:
            imp_result.append(detection)
    return imp_result
def extract_data(lang, to_be_extracted, image):
    """Run the full pipeline: OCR the image, then structure the text.

    Args:
        lang: Key of the module-level ``langs`` mapping; a missing key raises
            ``KeyError``.
        to_be_extracted: Description of the fields to extract, forwarded to
            ``get_output``.
        image: Image forwarded to the OCR helpers.

    Returns:
        Whatever ``get_output`` returns for the assembled OCR text.
    """
    # A new reader is built on every call, for the requested language only.
    ocr_reader = easyocr.Reader([langs[lang]])
    detections = imp_ocr(get_ocr(image, ocr_reader), image, ocr_reader)
    text_block = get_input(detections, sep=" ")
    return get_output(text_block, to_be_extracted, lang)
def get_output(inputt, to_be_extracted, lang):
    """Ask the generative model to structure OCR text and parse its reply.

    Args:
        inputt: OCR text to be structured.
        to_be_extracted: Keys the model is asked to produce; interpolated
            into the prompt as-is.
        lang: Language name mentioned in the prompt.

    Returns:
        The result of ``extract_json`` applied to the model's reply text.
    """
    llm_prompt = f"""
Bellow is the ouptut text of an OCR system. The text is in {lang}.
Your job is:
1. Format the output in a python dictionary format with the following keys :
{to_be_extracted}.
2. correct any mistakes such as date formating (should be dd/mm/yyyy) or spelling mistakes (this is important).
here is your input:
{inputt}
"""
    reply = model.generate_content(llm_prompt)
    return extract_json(reply.text)
def extract_json(text_response):
    """Return the first dictionary that can be parsed out of *text_response*.

    Every ``{`` in the text is tried in order as the start of an object. The
    balanced ``{...}`` span beginning there is parsed as JSON first and, if
    that fails, as a Python literal (single quotes, ``True``/``None``), so
    outer objects are preferred over the objects nested inside them.

    The text is never evaluated as code: it comes from a language model and
    must be treated as untrusted input.

    Args:
        text_response: Free-form text that may contain a dict/JSON object.

    Returns:
        The first span that parses to a ``dict``, or ``{}`` when none does.
    """
    pos = text_response.find("{")
    while pos != -1:
        candidate = extend_search(text_response, (pos, pos + 1))
        parsed = _parse_dict_literal(candidate)
        if parsed is not None:
            return parsed
        pos = text_response.find("{", pos + 1)
    return {}


def _parse_dict_literal(candidate):
    """Parse *candidate* as JSON or a Python literal; return the dict or None."""
    try:
        parsed = json.loads(candidate)
    except (ValueError, RecursionError):
        # Not JSON; models asked for a "python dictionary" often answer with
        # single-quoted literals, which ast.literal_eval parses without
        # executing anything.
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


def extend_search(text, span):
    """Return the brace-balanced substring of *text* starting at ``span[0]``.

    Scans forward from ``span[0]`` (expected to sit on a ``{``), counting
    nested braces and ignoring braces inside single- or double-quoted
    strings, and stops at the ``}`` that closes the first ``{``.

    Args:
        text: Text to scan.
        span: ``(start, end)`` indices; ``end`` is only used for the fallback.

    Returns:
        ``text[start:i + 1]`` where ``i`` is the index of the closing brace,
        or ``text[start:end]`` when the braces never balance.
    """
    start, end = span
    depth = 0
    quote = None      # quote character of the string being skipped, if any
    escaped = False   # True right after a backslash inside a string
    for i in range(start, len(text)):
        ch = text[i]
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch == '"' or ch == "'":
            quote = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return text[start:end]