import cv2
from PIL import Image
from ultralyticsplus import YOLO
from transformers import pipeline
import pandas as pd
import numpy as np
import easyocr
from utils import *  # provides langs, extract_data, get_ocr and get_input used below

INVOICE = ["Numéro de facture", "Date", "Numéro de commande", "Echéance", "Total"]  # French invoice fields: invoice number, date, order number, due date, total

model = YOLO('keremberke/yolov8s-table-extraction')  # YOLOv8 model fine-tuned for table detection
model.overrides['conf'] = 0.25  # NMS confidence threshold
model.overrides['iou'] = 0.45  # NMS IoU threshold
model.overrides['agnostic_nms'] = False  # NMS class-agnostic
model.overrides['max_det'] = 1000  # maximum number of detections per image

# Table structure recognition: detects "table row" and "table column" boxes inside a cropped table
pipe = pipeline("object-detection", model="bilguun/table-transformer-structure-recognition")


def detect_tables(image):
  # image is an OpenCV image (np.ndarray); returns a list of cropped tables as PIL Images
  results = model.predict(image)
  result = results[0]
  xyxy = result.boxes.xyxy
  scores = result.boxes.conf
  tables = []
  for i in range(len(scores)):
    # Keep detections with confidence >= 0.5 and crop them out of the original image
    if scores[i] >= 0.5:
      table = image[int(xyxy[i,1]):int(xyxy[i,3]), int(xyxy[i,0]):int(xyxy[i,2])]
      table = Image.fromarray(table)
      tables.append(table)
  return tables

def insert(el, listt, pos):
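  # Insert box el into listt, keeping the list sorted in ascending order of the coordinate key pos ("xmin" for columns, "ymin" for rows)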
  if not listt:
      listt.append(el)
  else:
      inserted = False
      for i in range(len(listt)):
          if el[pos] <= listt[i][pos]:
              listt.insert(i, el)
              inserted = True
              break
      if not inserted:
          listt.append(el)

def rec_table(table, reader):
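  # Detect row and column boxes with the structure-recognition pipeline, sort them top-to-bottom / left-to-right, then OCR the cell at each row-column intersection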
  col_row = pipe(table)
  cols = []
  rows = []
  for el in col_row:
    if el["label"] == 'table column':
      insert(el["box"], cols, pos = "xmin")
    elif el["label"] == 'table row':
      insert(el["box"], rows, pos = "ymin")

  table = np.array(table)  # back to a numpy array so cells can be sliced by pixel coordinates

  csv = []
  for row in rows:
    temp = []
    for col in cols:
      box = intersection(row, col)
      if box is None:
        # Row and column boxes do not overlap; keep the cell empty so columns stay aligned
        temp.append("")
        continue
      cell = table[box['ymin']:box['ymax'], box['xmin']:box['xmax']]
      res = get_ocr(cell, reader)
      temp.append(get_input(res))
    csv.append(temp)

  df = pd.DataFrame(csv)
  return df


def intersection(box1, box2):
    # Extract coordinates of first bounding box
    x1min, y1min, x1max, y1max = box1['xmin'], box1['ymin'], box1['xmax'], box1['ymax']

    # Extract coordinates of second bounding box
    x2min, y2min, x2max, y2max = box2['xmin'], box2['ymin'], box2['xmax'], box2['ymax']

    # Calculate coordinates of intersection
    xmin = max(x1min, x2min)
    ymin = max(y1min, y2min)
    xmax = min(x1max, x2max)
    ymax = min(y1max, y2max)

    # Check if there is no intersection
    if xmin >= xmax or ymin >= ymax:
        return None

    # Return the coordinates of the intersection
    return {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}


# def extract_tables(lang, image):
#   reader = easyocr.Reader([langs[lang]])
#   tables = detect_tables(image)

#   for i in range(len(tables)):
#     df = rec_table(tables[i], reader)
#     df.to_excel(f'table_{i+1}.xlsx', index=False, header=False)

def extract_tables(lang, image):
  reader = easyocr.Reader([langs[lang]])
  tables = detect_tables(image)
  csvs = []
  for i in range(len(tables)):
    df = rec_table(tables[i], reader)
    csv = df.to_csv(index=False, header=False)
    csvs.append(csv)

  # Return only the first detected table as CSV text
  return csvs[0]

if __name__ == '__main__':
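    # Example run on a French invoice: extract_data (imported from utils) is expected to OCR the image and return the fields listed in INVOICE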
    lang = "french"
    to_be_extracted = INVOICE
    image_path = "./docs for ocr/invoices/facture.png"
    image = cv2.imread(image_path)
    print(image.shape)
    
    text_data = extract_data(lang, to_be_extracted, image)
    print(text_data)

    # extract_tables(lang, image)  # returns the first detected table as CSV text; the commented-out variant above saves each table as an Excel file in the current directory