File size: 3,977 Bytes
8a6a4ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import easyocr
import json
import re
import cv2
import easyocr
import google.generativeai as genai


# Maps the human-readable language name accepted by extract_data() to the
# EasyOCR language code passed to easyocr.Reader.
langs = dict(french='fr', english='en', arabic='ar')


import os

# The Gemini API key is taken from the GOOGLE_API_KEY environment variable
# instead of being hard-coded: a key committed to source is readable by anyone
# with access to the file and should be treated as compromised (rotate it).
# If the variable is unset, api_key=None is passed through to genai.configure.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)
# Module-level model instance used by get_output() for every request.
model = genai.GenerativeModel('gemini-pro')

def get_ocr(image, reader):
    """Run ``reader.readtext`` on *image* and hand back its raw output unchanged."""
    return reader.readtext(image)


def get_input(result, sep=" ", min_prob=0.3, row_tol=5):
    """Group OCR detections into text rows and join them into one string.

    Detections are bucketed by the y-coordinate of their first bbox corner
    (``bbox[0][1]``). A detection joins the existing row whose anchor y is
    closest to its own, provided the gap is at most ``row_tol``. Otherwise it
    starts a new row anchored at its own y. Rows are emitted in the order they
    were first created. Fragments within a row keep their input order.

    Args:
        result: iterable of ``(bbox, text, prob)`` triples, e.g. from get_ocr.
        sep: string placed between fragments that share a row.
        min_prob: detections with ``prob`` below this value are skipped.
        row_tol: maximum vertical distance between a detection and a row
            anchor for the detection to join that row.

    Returns:
        One line per row, each terminated by ``"\\n"``; ``""`` if no
        detection survives the ``min_prob`` filter.
    """
    rows = {}  # anchor y -> list of text fragments on that row
    for bbox, text, prob in result:
        if prob < min_prob:
            continue
        y = bbox[0][1]
        near = [anchor for anchor in rows if abs(y - anchor) <= row_tol]
        if near:
            # Attach to exactly one row. Appending to every row within range
            # duplicated the text when y fell between two anchors.
            closest = min(near, key=lambda anchor: abs(y - anchor))
            rows[closest].append(text)
        else:
            rows[y] = [text]
    return "".join(sep.join(parts) + "\n" for parts in rows.values())


def imp_ocr(result, image, reader, v_exp_rate=0.2, h_exp_rate=0.3,
            min_prob=0.1, retry_below=0.9):
    """Re-run OCR on mid-confidence detections using an enlarged crop.

    For each ``(bbox, text, prob)`` detection:

    * ``prob < min_prob``: dropped from the output.
    * ``min_prob <= prob < retry_below``: the bbox is grown by ``h_exp_rate``
      of its width and ``v_exp_rate`` of its height, split evenly on both
      sides and clipped to the image. That crop is OCR'd again. The
      highest-probability re-read replaces the text and prob (keeping the
      original bbox) if its prob is at least the original one.
    * ``prob >= retry_below``: kept unchanged.

    ``bbox[0]`` is treated as the top-left corner and ``bbox[2]`` as the
    bottom-right one (x, y). *image* is indexed ``image[y, x]`` with
    ``image.shape[:2] == (height, width)``, as for a numpy image.

    Returns:
        A new list of detections in input order, minus the dropped ones.
    """
    improved = []
    for detection in result:
        bbox, prob = detection[0], detection[2]
        if prob < min_prob:
            continue  # too unreliable to keep at all
        if prob >= retry_below:
            improved.append(detection)
            continue

        box_w = bbox[2][0] - bbox[0][0]
        box_h = bbox[2][1] - bbox[0][1]
        x0 = max(int(bbox[0][0] - box_w * h_exp_rate / 2), 0)
        x1 = min(int(bbox[2][0] + box_w * h_exp_rate / 2), image.shape[1])
        y0 = max(int(bbox[0][1] - box_h * v_exp_rate / 2), 0)
        y1 = min(int(bbox[2][1] + box_h * v_exp_rate / 2), image.shape[0])
        if x1 <= x0 or y1 <= y0:
            # Degenerate or out-of-image box: the crop would be empty, so
            # there is nothing to re-read; keep the original detection.
            improved.append(detection)
            continue

        retry = get_ocr(image[y0:y1, x0:x1], reader)
        if not retry:
            improved.append(detection)
            continue
        best = max(retry, key=lambda r: r[2])  # first one wins on ties
        if best[2] >= prob:
            improved.append((bbox, best[1], best[2]))
        else:
            improved.append(detection)
    return improved


_READER_CACHE = {}  # EasyOCR language code -> easyocr.Reader


def _get_reader(lang_code):
    """Return a cached ``easyocr.Reader`` for *lang_code*, building it on first use."""
    reader = _READER_CACHE.get(lang_code)
    if reader is None:
        reader = easyocr.Reader([lang_code])
        _READER_CACHE[lang_code] = reader
    return reader


def extract_data(lang, to_be_extracted, image):
    """OCR *image* and have the LLM structure the text into a dict.

    Args:
        lang: language name, a key of ``langs`` (e.g. ``'french'``). An
            unknown name raises KeyError from the ``langs`` lookup.
        to_be_extracted: the keys the LLM is asked to fill; interpolated
            verbatim into the prompt by get_output.
        image: image passed to the OCR reader and cropped by imp_ocr.

    Returns:
        Whatever get_output returns for the OCR'd text.
    """
    # Reuse one Reader per language instead of constructing a new one (and
    # reloading its models) on every call.
    # NOTE(review): the cached reader is shared across calls; confirm it is
    # safe if extract_data is ever called concurrently.
    reader = _get_reader(langs[lang])
    ocr_result = get_ocr(image, reader)
    imp_result = imp_ocr(ocr_result, image, reader)
    inputt = get_input(imp_result, sep="     ")
    return get_output(inputt, to_be_extracted, lang)

def get_output(inputt, to_be_extracted, lang):
    """Send the OCR text to the Gemini model and parse its reply.

    The prompt asks the model to lay out *inputt* as a Python-dictionary-style
    mapping over the keys in *to_be_extracted* and to fix spelling and date
    formatting. The reply text is handed to extract_json, whose result is
    returned.
    """
    prompt = f"""
    Bellow is the ouptut text of an OCR system. The text is in {lang}.
    Your job is:
    1. Format the output in a python dictionary format with the following keys :
    {to_be_extracted}.
    2. correct any mistakes such as date formating (should be dd/mm/yyyy) or spelling mistakes (this is important).

    here is your input:
    {inputt}
    """
    return extract_json(model.generate_content(prompt).text)

def extract_json(text_response):
    """Return the first dict literal found in an LLM text response.

    The model is prompted for a Python-dictionary-formatted answer, possibly
    surrounded by prose or a code fence. Every ``{`` in the text is tried
    left to right, so an enclosing object is tried before the objects nested
    inside it:

    1. ``json.JSONDecoder.raw_decode`` parses strict JSON starting there.
       Nesting and braces inside strings are handled.
    2. Otherwise the brace-balanced span from extend_search is parsed with
       ``ast.literal_eval``. This accepts Python literal syntax (single
       quotes, True/False/None) but, unlike ``eval``, never executes code.
       That matters because the text comes from a remote model.

    Returns:
        The first candidate that parses to a ``dict``, or ``{}`` if none does.
    """
    import ast  # local import: literal_eval is only needed here

    decoder = json.JSONDecoder()
    for start, char in enumerate(text_response):
        if char != "{":
            continue
        try:
            obj, _ = decoder.raw_decode(text_response, start)
        except json.JSONDecodeError:
            candidate = extend_search(text_response, (start, start + 1))
            try:
                obj = ast.literal_eval(candidate)
            except (ValueError, SyntaxError, TypeError, MemoryError,
                    RecursionError):
                continue  # not a valid literal starting here
        if isinstance(obj, dict):  # e.g. "{'a'}" is a set, not a dict
            return obj
    return {}


def extend_search(text, span):
    """Return the brace-balanced substring of *text* that starts at ``span[0]``.

    Scans forward from ``span[0]`` counting ``{`` and ``}``. Braces inside
    single- or double-quoted string literals (honouring backslash escapes)
    are ignored, so a value such as ``"a}b"`` does not close the object early.

    Args:
        text: the text to scan.
        span: ``(start, end)`` pair; ``start`` is normally the index of a ``{``.

    Returns:
        ``text[start:i + 1]`` where ``i`` is the ``}`` at which the depth
        first returns to zero. Returns the unchanged ``text[start:end]`` if
        the braces never balance.
    """
    start, end = span
    depth = 0
    quote = None     # active string delimiter, or None outside strings
    escaped = False  # previous character was a backslash inside a string
    for i in range(start, len(text)):
        ch = text[i]
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in "\"'":
            quote = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return text[start:end]