bachpc committed on
Commit
17ae8b6
1 Parent(s): 87f7012

Fix bug and clean

Browse files
Files changed (1) hide show
  1. app.py +39 -41
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import streamlit as st
2
-
3
  import PIL
4
  import cv2
5
  import numpy as np
@@ -9,7 +8,6 @@ import torch
9
  # import json
10
  from collections import OrderedDict, defaultdict
11
  import xml.etree.ElementTree as ET
12
-
13
  from paddleocr import PaddleOCR
14
  import pytesseract
15
  from pytesseract import Output
@@ -29,13 +27,13 @@ structure_class_names = [
29
  ]
30
  structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
31
  structure_class_thresholds = {
32
- "table": 0.5,
33
- "table column": 0.5,
34
- "table row": 0.5,
35
- "table column header": 0.25,
36
- "table projected row header": 0.25,
37
- "table spanning cell": 0.25,
38
- "no object": 10
39
  }
40
 
41
 
@@ -150,7 +148,7 @@ def convert_stucture(page_tokens, pil_img, structure_result):
150
  try:
151
  table_bbox = list(table_class_objects[0]['bbox'])
152
  except:
153
- table_bbox = (0,0,1000,1000)
154
  # print('table_class_objects:', table_class_objects)
155
  # print('table_bbox:', table_bbox)
156
 
@@ -186,17 +184,17 @@ def visualize_structure(pil_img, structure_result):
186
  min_y = result[1]
187
  w = result[2]
188
  h = result[3]
189
-
190
  x1 = int((min_x - w / 2) * width)
191
  y1 = int((min_y - h / 2) * height)
192
  x2 = int((min_x + w / 2) * width)
193
  y2 = int((min_y + h / 2) * height)
194
  # print(x1, y1, x2, y2)
195
-
196
  if score >= structure_class_thresholds[structure_class_names[class_id]]:
197
  cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
198
  #cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
199
-
200
  return cv_to_PIL(image)
201
 
202
 
@@ -281,12 +279,12 @@ def cells_to_csv(cells):
281
  else:
282
  max_header_row = -1
283
 
284
- table_array = np.empty([num_rows, num_columns], dtype="object")
285
  if len(cells) > 0:
286
  for cell in cells:
287
  for row_num in cell['row_nums']:
288
  for column_num in cell['column_nums']:
289
- table_array[row_num, column_num] = cell["cell_text"]
290
 
291
  header = table_array[:max_header_row+1,:]
292
  flattened_header = []
@@ -301,7 +299,7 @@ def cells_to_html(cells):
301
  cells = sorted(cells, key=lambda k: min(k['column_nums']))
302
  cells = sorted(cells, key=lambda k: min(k['row_nums']))
303
 
304
- table = ET.Element("table")
305
  current_row = -1
306
 
307
  for cell in cells:
@@ -317,15 +315,15 @@ def cells_to_html(cells):
317
  if this_row > current_row:
318
  current_row = this_row
319
  if cell['header']:
320
- cell_tag = "th"
321
- row = ET.SubElement(table, "thead")
322
  else:
323
- cell_tag = "td"
324
- row = ET.SubElement(table, "tr")
325
  tcell = ET.SubElement(row, cell_tag, attrib=attrib)
326
  tcell.text = cell['cell_text']
327
 
328
- return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))
329
 
330
 
331
  # def cells_to_html(cells):
@@ -342,11 +340,11 @@ def cells_to_html(cells):
342
  # for cell in r_cells:
343
  # rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
344
  # colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
345
- # r_html += f'<td rowspan="{rowspan}" colspan="{colspan}">{escape(cell["text"])}</td>'
346
  # html_code += f'<tr>{r_html}</tr>'
347
  # html_code = '''<html>
348
  # <head>
349
- # <meta charset="UTF-8">
350
  # <style>
351
  # table, th, td {
352
  # border: 1px solid black;
@@ -355,7 +353,7 @@ def cells_to_html(cells):
355
  # </style>
356
  # </head>
357
  # <body>
358
- # <table frame="hsides" rules="groups" width="100%%">
359
  # %s
360
  # </table>
361
  # </body>
@@ -367,22 +365,22 @@ def cells_to_html(cells):
367
 
368
  def main():
369
 
370
- st.set_page_config(layout="wide")
371
- st.title("Table Structure Recognition Demo")
372
  st.write('\n')
373
 
374
  cols = st.columns((1, 1))
375
- cols[0].subheader("Input page")
376
- cols[1].subheader("Table(s) detected")
377
 
378
- st.sidebar.title("Image upload")
379
  st.set_option('deprecation.showfileUploaderEncoding', False)
380
- filename = st.sidebar.file_uploader("Upload files", type=['png', 'jpeg', 'jpg'])
381
 
382
- if st.sidebar.button("Analyze image"):
383
 
384
  if filename is None:
385
- st.sidebar.write("Please upload an image")
386
 
387
  else:
388
  print(filename)
@@ -394,31 +392,31 @@ def main():
394
  cols[1].image(vis_det_img)
395
 
396
  str_cols = st.columns((len(crop_images), ) * 5)
397
- str_cols[0].subheader("Table image")
398
- str_cols[1].subheader("OCR result")
399
- str_cols[2].subheader("Structure result")
400
- str_cols[3].subheader("Cells result")
401
- str_cols[4].subheader("CSV result")
402
 
403
- for img in crop_images:
404
  ocr_result = ocr(img)
405
  structure_result = table_structure(img)
406
  table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
407
  cells = extract_text_from_cells(cells)
408
  html_result = cells_to_html(cells)
409
  df, csv_result = cells_to_csv(cells)
410
- print(df)
411
 
412
  vis_ocr_img = visualize_ocr(img, ocr_result)
413
  vis_str_img = visualize_structure(img, structure_result)
414
  vis_cells_img = visualize_cells(img, cells)
415
-
416
  str_cols[0].image(img)
417
  str_cols[1].image(vis_ocr_img)
418
  str_cols[2].image(vis_str_img)
419
  str_cols[3].image(vis_cells_img)
420
  #str_cols[4].dataframe(df)
421
- str_cols[4].download_button("Download table", csv_result, "file.csv", "text/csv", key='download-csv')
422
 
423
  st.markdown(html_result, unsafe_allow_html=True)
424
 
 
1
  import streamlit as st
 
2
  import PIL
3
  import cv2
4
  import numpy as np
 
8
  # import json
9
  from collections import OrderedDict, defaultdict
10
  import xml.etree.ElementTree as ET
 
11
  from paddleocr import PaddleOCR
12
  import pytesseract
13
  from pytesseract import Output
 
27
  ]
28
  structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
29
  structure_class_thresholds = {
30
+ 'table': 0.5,
31
+ 'table column': 0.5,
32
+ 'table row': 0.5,
33
+ 'table column header': 0.25,
34
+ 'table projected row header': 0.25,
35
+ 'table spanning cell': 0.25,
36
+ 'no object': 10
37
  }
38
 
39
 
 
148
  try:
149
  table_bbox = list(table_class_objects[0]['bbox'])
150
  except:
151
+ table_bbox = (0, 0, 1000, 1000)
152
  # print('table_class_objects:', table_class_objects)
153
  # print('table_bbox:', table_bbox)
154
 
 
184
  min_y = result[1]
185
  w = result[2]
186
  h = result[3]
187
+
188
  x1 = int((min_x - w / 2) * width)
189
  y1 = int((min_y - h / 2) * height)
190
  x2 = int((min_x + w / 2) * width)
191
  y2 = int((min_y + h / 2) * height)
192
  # print(x1, y1, x2, y2)
193
+
194
  if score >= structure_class_thresholds[structure_class_names[class_id]]:
195
  cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
196
  #cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
197
+
198
  return cv_to_PIL(image)
199
 
200
 
 
279
  else:
280
  max_header_row = -1
281
 
282
+ table_array = np.empty([num_rows, num_columns], dtype='object')
283
  if len(cells) > 0:
284
  for cell in cells:
285
  for row_num in cell['row_nums']:
286
  for column_num in cell['column_nums']:
287
+ table_array[row_num, column_num] = cell['cell_text']
288
 
289
  header = table_array[:max_header_row+1,:]
290
  flattened_header = []
 
299
  cells = sorted(cells, key=lambda k: min(k['column_nums']))
300
  cells = sorted(cells, key=lambda k: min(k['row_nums']))
301
 
302
+ table = ET.Element('table')
303
  current_row = -1
304
 
305
  for cell in cells:
 
315
  if this_row > current_row:
316
  current_row = this_row
317
  if cell['header']:
318
+ cell_tag = 'th'
319
+ row = ET.SubElement(table, 'thead')
320
  else:
321
+ cell_tag = 'td'
322
+ row = ET.SubElement(table, 'tr')
323
  tcell = ET.SubElement(row, cell_tag, attrib=attrib)
324
  tcell.text = cell['cell_text']
325
 
326
+ return str(ET.tostring(table, encoding='unicode', short_empty_elements=False))
327
 
328
 
329
  # def cells_to_html(cells):
 
340
  # for cell in r_cells:
341
  # rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
342
  # colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
343
+ # r_html += f'<td rowspan="{rowspan}" colspan="{colspan}">{escape(cell["text"])}</td>'
344
  # html_code += f'<tr>{r_html}</tr>'
345
  # html_code = '''<html>
346
  # <head>
347
+ # <meta charset='UTF-8'>
348
  # <style>
349
  # table, th, td {
350
  # border: 1px solid black;
 
353
  # </style>
354
  # </head>
355
  # <body>
356
+ # <table frame='hsides' rules='groups' width='100%%'>
357
  # %s
358
  # </table>
359
  # </body>
 
365
 
366
  def main():
367
 
368
+ st.set_page_config(layout='wide')
369
+ st.title('Table Extraction Demo')
370
  st.write('\n')
371
 
372
  cols = st.columns((1, 1))
373
+ cols[0].subheader('Input page')
374
+ cols[1].subheader('Table(s) detected')
375
 
376
+ st.sidebar.title('Image upload')
377
  st.set_option('deprecation.showfileUploaderEncoding', False)
378
+ filename = st.sidebar.file_uploader('Upload files', type=['png', 'jpeg', 'jpg'])
379
 
380
+ if st.sidebar.button('Analyze image'):
381
 
382
  if filename is None:
383
+ st.sidebar.write('Please upload an image')
384
 
385
  else:
386
  print(filename)
 
392
  cols[1].image(vis_det_img)
393
 
394
  str_cols = st.columns((len(crop_images), ) * 5)
395
+ str_cols[0].subheader('Table image')
396
+ str_cols[1].subheader('OCR result')
397
+ str_cols[2].subheader('Structure result')
398
+ str_cols[3].subheader('Cells result')
399
+ str_cols[4].subheader('CSV result')
400
 
401
+ for i, img in enumerate(crop_images):
402
  ocr_result = ocr(img)
403
  structure_result = table_structure(img)
404
  table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
405
  cells = extract_text_from_cells(cells)
406
  html_result = cells_to_html(cells)
407
  df, csv_result = cells_to_csv(cells)
408
+ #print(df)
409
 
410
  vis_ocr_img = visualize_ocr(img, ocr_result)
411
  vis_str_img = visualize_structure(img, structure_result)
412
  vis_cells_img = visualize_cells(img, cells)
413
+
414
  str_cols[0].image(img)
415
  str_cols[1].image(vis_ocr_img)
416
  str_cols[2].image(vis_str_img)
417
  str_cols[3].image(vis_cells_img)
418
  #str_cols[4].dataframe(df)
419
+ str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
420
 
421
  st.markdown(html_result, unsafe_allow_html=True)
422