# Author: bachpc
# Commit c9a52b3: "Fix for streamlit 1.17.0" (file size 23 kB)
import streamlit as st
import PIL
import cv2
import numpy as np
import pandas as pd
import torch
import os
import io
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
from tempfile import TemporaryDirectory
import xlsxwriter
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
from paddleocr import PaddleOCR
# import pytesseract
# from pytesseract import Output
import postprocess
@st.experimental_singleton(ttl=3600)
def load_ocr_instance():
    """Create the PaddleOCR reader used for all OCR calls.

    The decorator caches the returned object (``ttl=3600`` seconds), so the
    reader is not rebuilt on every script rerun while the cache entry lives.
    """
    return PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
@st.experimental_singleton(ttl=3600)
def load_detection_model():
    """Load the custom YOLOv5 table-detection model via ``torch.hub``.

    Weights come from ``weights/detection_wts.pt``. ``force_reload=True``
    asks torch.hub to re-fetch the repo whenever this function actually runs
    (i.e. on a cache miss of the decorator, ``ttl=3600`` seconds).
    """
    weights_path = 'weights/detection_wts.pt'
    return torch.hub.load(
        'ultralytics/yolov5',
        'custom',
        weights_path,
        force_reload=True,
        skip_validation=True,
        trust_repo=True,
    )
@st.experimental_singleton(ttl=3600)
def load_structure_model():
    """Load the custom YOLOv5 table-structure model via ``torch.hub``.

    Weights come from ``weights/structure_wts.pt``. ``force_reload=True``
    asks torch.hub to re-fetch the repo whenever this function actually runs
    (i.e. on a cache miss of the decorator, ``ttl=3600`` seconds).
    """
    weights_path = 'weights/structure_wts.pt'
    return torch.hub.load(
        'ultralytics/yolov5',
        'custom',
        weights_path,
        force_reload=True,
        skip_validation=True,
        trust_repo=True,
    )
# Obtain the OCR engine and both YOLOv5 models once, at import time; the
# helper functions below read these module-level globals.
ocr_instance = load_ocr_instance()
detection_model = load_detection_model()
structure_model = load_structure_model()
# Class names in model output order: the integer class id in column 5 of a
# model prediction row is used as an index into these lists.
detection_class_names = ['table', 'table rotated', 'no object']
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]

# Reverse lookups: class name -> class id.
detection_class_map = {name: class_id for class_id, name in enumerate(detection_class_names)}
structure_class_map = {name: class_id for class_id, name in enumerate(structure_class_names)}

# Minimum confidence score required to keep a prediction of each class.
# 'no object' uses 10 — presumably scores are in [0, 1], so that class is
# effectively never kept.
detection_class_thresholds = {
    'table': 0.5,
    'table rotated': 0.5,
    'no object': 10,
}
structure_class_thresholds = {
    'table': 0.42,
    'table column': 0.56,
    'table row': 0.5,
    'table column header': 0.38,
    'table projected row header': 0.27,
    'table spanning cell': 0.4,
    'no object': 10,
}
def PIL_to_cv(pil_img):
    """Convert a PIL image into an OpenCV-style BGR array with three channels.

    Non-RGB PIL images (e.g. grayscale ``L``, palette ``P``, ``LA``, ``RGBA``)
    are converted to RGB with PIL first. Previously, single- and two-channel
    arrays reached ``cv2.cvtColor(..., COLOR_RGB2BGR)``, which rejects them,
    so grayscale or palette uploads crashed the pipeline.

    Args:
        pil_img: A ``PIL.Image.Image``. Objects without a ``mode`` attribute
            are passed through to ``np.array`` unchanged, as before.

    Returns:
        numpy.ndarray of shape (height, width, 3) in BGR channel order.
    """
    if getattr(pil_img, 'mode', 'RGB') != 'RGB':
        # For RGBA, PIL drops the alpha channel, which matches what
        # cv2's RGB2BGR conversion did with 4-channel input.
        pil_img = pil_img.convert('RGB')
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Convert an OpenCV BGR array back into an RGB PIL image."""
    rgb = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb)
def table_detection(pil_img, imgsz=640):
    """Run the table-detection model on a page image.

    Args:
        pil_img: Page image (PIL).
        imgsz: Inference size, forwarded to the model as ``size``.

    Returns:
        numpy array taken from the model's ``xywhn`` output for the first
        image of the batch. Downstream code reads columns 0-3 as normalized
        centre-x, centre-y, width and height, column 4 as the confidence and
        column 5 as the class id.
    """
    detections = detection_model(PIL_to_cv(pil_img), size=imgsz)
    first_image = detections.xywhn[0]
    return first_image.cpu().numpy()
def table_structure(pil_img, imgsz=640):
    """Run the table-structure model on a cropped table image.

    Args:
        pil_img: Table image (PIL).
        imgsz: Inference size, forwarded to the model as ``size``.

    Returns:
        numpy array taken from the model's ``xywhn`` output for the first
        image of the batch. Downstream code reads columns 0-3 as normalized
        centre-x, centre-y, width and height, column 4 as the confidence and
        column 5 as the class id.
    """
    detections = structure_model(PIL_to_cv(pil_img), size=imgsz)
    first_image = detections.xywhn[0]
    return first_image.cpu().numpy()
def crop_image(pil_img, detection_result, padding=30):
    """Crop every sufficiently confident table detection out of a page image.

    Args:
        pil_img: Page image (PIL).
        detection_result: Rows as returned by ``table_detection``:
            normalized centre-x, centre-y, width, height, confidence, class id.
        padding: Extra pixels kept around each box, clamped to the image.

    Returns:
        ``(crops, annotated)``:
        - ``crops`` is a list of PIL images, one per detection whose score
          reaches its class threshold. 'table rotated' crops are rotated by
          270 degrees (with ``expand=True``).
        - ``annotated`` is a copy of the page with each kept box and its
          score drawn on it.
    """
    image = PIL_to_cv(pil_img)
    # Annotate a separate copy. Drawing on ``image`` itself (as before) let
    # the boxes/scores of earlier tables appear inside the padded crop of a
    # neighbouring table.
    annotated = image.copy()
    height, width = image.shape[:2]
    crops = []
    for result in detection_result:
        class_id = int(result[5])
        score = float(result[4])
        class_name = detection_class_names[class_id]
        if score < detection_class_thresholds[class_name]:
            continue
        cx, cy, w, h = result[:4]
        # Normalized centre/size -> pixel corner coordinates.
        x1 = int((cx - w / 2) * width)
        y1 = int((cy - h / 2) * height)
        x2 = int((cx + w / 2) * width)
        y2 = int((cy + h / 2) * height)
        x1_pad = max(0, x1 - padding)
        y1_pad = max(0, y1 - padding)
        x2_pad = min(width, x2 + padding)
        y2_pad = min(height, y2 + padding)
        if x2_pad <= x1_pad or y2_pad <= y1_pad:
            # Box lies outside the image after clamping: there is nothing to
            # crop, and cv2 would fail on the empty array.
            continue
        crop = cv_to_PIL(image[y1_pad:y2_pad, x1_pad:x2_pad, :])
        if class_name == 'table rotated':
            crop = crop.rotate(270, expand=True)
        crops.append(crop)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)
        cv2.putText(annotated, f'{score:.2f}', (x1, y1), cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.5, color=(255, 0, 0))
    return crops, cv_to_PIL(annotated)
def ocr(pil_img):
    """Run PaddleOCR on an image and return one axis-aligned box per text line.

    Args:
        pil_img: Image to read (PIL).

    Returns:
        List of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts. The bbox is
        the min/max extent of each polygon PaddleOCR reports. The list is
        empty when nothing is recognised.
    """
    result = ocr_instance.ocr(PIL_to_cv(pil_img))
    # NOTE(review): for images without text, some PaddleOCR versions return
    # ``None`` / ``[None]`` instead of a list of lines. Iterating that raised
    # TypeError here before, so treat it as "no words".
    lines = result[0] if result else None
    if not lines:
        return []
    ocr_res = []
    for polygon, (text, _score) in lines:
        xs = [p[0] for p in polygon]
        ys = [p[1] for p in polygon]
        ocr_res.append({'bbox': [min(xs), min(ys), max(xs), max(ys)], 'text': text})
    return ocr_res
def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure-model predictions into table cells.

    Args:
        page_tokens: OCR words (``{'bbox': [x1, y1, x2, y2], 'text': ...}``)
            in the pixel coordinates of ``pil_img``.
        pil_img: The table image (PIL) the predictions refer to.
        structure_result: Rows as returned by ``table_structure``:
            normalized centre-x, centre-y, width, height, score, class id.

    Returns:
        ``(table_structures, cells, confidence_score)`` as produced by
        ``postprocess.objects_to_cells``.
    """
    # Only the pixel size is needed, so read it from PIL instead of
    # converting the whole image to a BGR array.
    width, height = pil_img.size

    table_objects = []
    for row in structure_result:
        cx, cy, w, h = row[:4]
        table_objects.append({
            'bbox': [
                int((cx - w / 2) * width),
                int((cy - h / 2) * height),
                int((cx + w / 2) * width),
                int((cy + h / 2) * height),
            ],
            'score': float(row[4]),
            'label': int(row[5]),
        })
    table = {'objects': table_objects, 'page_num': 0}

    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        # The highest-scoring 'table' box decides which tokens belong to the table.
        table_bbox = list(max(table_candidates, key=lambda obj: obj['score'])['bbox'])
    else:
        # No 'table' box predicted: treat the whole image as the table. The old
        # fixed (0, 0, 1000, 1000) fallback dropped tokens in larger crops.
        table_bbox = [0, 0, width, height]

    # Keep tokens whose box lies at least half inside the table box
    # (assumes postprocess.iob is intersection-over-first-box-area — TODO confirm).
    tokens_in_table = [token for token in page_tokens
                       if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
    return table_structures, cells, confidence_score
def visualize_image(pil_img):
    """Render ``pil_img`` through matplotlib and return the figure as a PIL image.

    The figure is resized to 10x10 inches with axes hidden, saved to an
    in-memory buffer at 150 dpi with a tight bounding box, and then closed.
    """
    plt.imshow(pil_img, interpolation='lanczos')
    fig, ax = plt.gcf(), plt.gca()
    fig.set_size_inches(10, 10)
    ax.axis('off')
    png = io.BytesIO()
    fig.savefig(png, bbox_inches='tight', dpi=150)
    plt.close(fig)
    return PIL.Image.open(png)
def visualize_ocr(pil_img, ocr_result):
    """Draw OCR word boxes and their text over ``pil_img``.

    Each entry of ``ocr_result`` (``{'bbox': [x1, y1, x2, y2], 'text': ...}``)
    is drawn as a red outline. Its text is written in blue, anchored
    bottom-left at the box's (x1, y1) corner. Returns the rendered figure as
    a PIL image.
    """
    plt.imshow(pil_img, interpolation='lanczos')
    fig = plt.gcf()
    fig.set_size_inches(20, 20)
    ax = plt.gca()
    for word in ocr_result:
        bbox = word['bbox']
        outline = patches.Rectangle(
            bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1],
            linewidth=2, edgecolor='red', facecolor='none', linestyle='-',
        )
        ax.add_patch(outline)
        ax.text(bbox[0], bbox[1], word['text'],
                horizontalalignment='left', verticalalignment='bottom',
                color='blue', fontsize=7)
    plt.xticks([], [])
    plt.yticks([], [])
    fig.set_size_inches(10, 10)
    plt.axis('off')
    png = io.BytesIO()
    plt.savefig(png, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(png)
def get_bbox_decorations(data_type, label):
    """Return the ``(color, alpha, linewidth, hatch)`` style for class ``label``.

    Label 0 is hatched only when ``data_type`` is ``'detection'``. Labels
    1-5 have fixed styles. Any other label gets gray with zero alpha and
    zero line width.
    """
    if label == 0:
        return ('brown', 0.05, 3, '//') if data_type == 'detection' else ('brown', 0, 3, None)
    styles = {
        1: ('red', 0.15, 2, None),
        2: ('blue', 0.15, 2, None),
        3: ('magenta', 0.2, 3, '//'),
        4: ('cyan', 0.2, 4, '//'),
        5: ('green', 0.2, 4, '\\\\'),
    }
    return styles.get(label, ('gray', 0, 0, None))
def visualize_structure(pil_img, structure_result):
    """Overlay structure-recognition boxes on ``pil_img``.

    Every prediction whose score reaches its class threshold in
    ``structure_class_thresholds`` is drawn as three patches styled by
    ``get_bbox_decorations``: a translucent fill, a hatched overlay and a
    dashed outline. A legend lists every entry of ``structure_class_names``
    except the last one ('no object').

    Args:
        pil_img: The table image (PIL) the predictions refer to.
        structure_result: Rows as returned by ``table_structure``:
            normalized centre-x, centre-y, width, height, score, class id.

    Returns:
        The rendered figure as a PIL image.
    """
    # Only the pixel size is needed; reading it from PIL avoids a full BGR
    # conversion of the image.
    width, height = pil_img.size
    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    ax = plt.gca()
    for result in structure_result:
        class_id = int(result[5])
        score = float(result[4])
        if score < structure_class_thresholds[structure_class_names[class_id]]:
            continue
        cx, cy, w, h = result[:4]
        x1 = int((cx - w / 2) * width)
        y1 = int((cy - h / 2) * height)
        x2 = int((cx + w / 2) * width)
        y2 = int((cy + h / 2) * height)
        color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
        # Fill, hatch and edge are separate patches so each gets its own
        # alpha and line style.
        patch_styles = (
            dict(linewidth=linewidth, alpha=alpha, edgecolor='none', facecolor=color, linestyle=None),
            dict(linewidth=1, alpha=0.4, edgecolor=color, facecolor='none', linestyle='--', hatch=hatch),
            dict(linewidth=linewidth, edgecolor=color, facecolor='none', linestyle='--'),
        )
        for style in patch_styles:
            ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1, **style))
    plt.xticks([], [])
    plt.yticks([], [])
    legend_elements = []
    for class_name in structure_class_names[:-1]:
        color, _alpha, _linewidth, hatch = get_bbox_decorations('recognition', structure_class_map[class_name])
        legend_elements.append(
            Patch(facecolor=color, edgecolor=color, linestyle='--', label=class_name, hatch=hatch)
        )
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
               fontsize=10, ncol=3)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    img_buf = io.BytesIO()
    plt.savefig(img_buf, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(img_buf)
def visualize_cells(pil_img, cells):
    """Overlay recognised cells on ``pil_img`` and return the figure as a PIL image.

    Cells with a truthy ``'header'`` flag, cells with a truthy ``'subheader'``
    flag, and all remaining (data) cells each get their own face color, edge
    color and hatch. Each cell is drawn as three patches: a faint fill
    (alpha 0.1), an outline (alpha 0.3) and a hatch overlay (alpha 0.2).
    A legend below the image explains the three styles.
    """
    # (facecolor, edgecolor, hatch) for each cell type.
    header_style = ((1, 0, 0.45), (1, 0, 0.45), '//////')
    subheader_style = ((0.95, 0.6, 0.1), (0.95, 0.6, 0.1), '//////')
    data_style = ((0.3, 0.74, 0.8), (0.3, 0.7, 0.6), '\\\\\\\\\\\\')
    outline_alpha = 0.3
    outline_width = 2

    plt.imshow(pil_img, interpolation='lanczos')
    plt.gcf().set_size_inches(20, 20)
    ax = plt.gca()
    for cell in cells:
        bbox = cell['bbox']
        if cell['header']:
            facecolor, edgecolor, hatch = header_style
        elif cell['subheader']:
            facecolor, edgecolor, hatch = subheader_style
        else:
            facecolor, edgecolor, hatch = data_style
        origin = bbox[:2]
        box_w = bbox[2] - bbox[0]
        box_h = bbox[3] - bbox[1]
        ax.add_patch(patches.Rectangle(origin, box_w, box_h, linewidth=outline_width,
                                       edgecolor='none', facecolor=facecolor, alpha=0.1))
        ax.add_patch(patches.Rectangle(origin, box_w, box_h, linewidth=outline_width,
                                       edgecolor=edgecolor, facecolor='none', linestyle='-',
                                       alpha=outline_alpha))
        ax.add_patch(patches.Rectangle(origin, box_w, box_h, linewidth=0,
                                       edgecolor=edgecolor, facecolor='none', linestyle='-',
                                       hatch=hatch, alpha=0.2))
    plt.xticks([], [])
    plt.yticks([], [])
    legend_entries = (
        ('Data cell', data_style),
        ('Column header cell', header_style),
        ('Projected row header cell', subheader_style),
    )
    legend_elements = [
        Patch(facecolor=face, edgecolor=edge, label=label, hatch=pattern, alpha=0.3)
        for label, (face, edge, pattern) in legend_entries
    ]
    plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
               fontsize=10, ncol=3)
    plt.gcf().set_size_inches(10, 10)
    plt.axis('off')
    png = io.BytesIO()
    plt.savefig(png, bbox_inches='tight', dpi=150)
    plt.close()
    return PIL.Image.open(png)
# def pytess(cell_pil_img):
# return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()
# def resize(pil_img, size=1800):
# length_x, width_y = pil_img.size
# factor = max(1, size / length_x)
# size = int(factor * length_x), int(factor * width_y)
# pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
# return pil_img, factor
# def image_smoothening(img):
# ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
# ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# blur = cv2.GaussianBlur(th2, (1, 1), 0)
# ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# return th3
# def remove_noise_and_smooth(pil_img):
# img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
# filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
# kernel = np.ones((1, 1), np.uint8)
# opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
# closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
# img = image_smoothening(img)
# or_image = cv2.bitwise_or(img, closing)
# pil_img = PIL.Image.fromarray(or_image)
# return pil_img
# def extract_text_from_cells(pil_img, cells):
# pil_img, factor = resize(pil_img)
# #pil_img = remove_noise_and_smooth(pil_img)
# #display(pil_img)
# for cell in cells:
# bbox = [x * factor for x in cell['bbox']]
# cell_pil_img = pil_img.crop(bbox)
# #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
# #cell_pil_img = tess_prep(cell_pil_img)
# cell['cell text'] = pytess(cell_pil_img)
# return cells
def extract_text_from_cells(cells, sep=' '):
    """Set ``cell['cell_text']`` for every cell from the text of its spans.

    The ``'text'`` of each span that has one is joined with ``sep``. Spans
    without a ``'text'`` key are skipped, and a cell with no text spans gets
    ``''``. Cells are updated in place; the same list is returned for
    convenience.

    Args:
        cells: Cell dicts, each with a ``'spans'`` list of dicts.
        sep: Separator placed between consecutive span texts.

    Returns:
        ``cells`` (the same list object).
    """
    for cell in cells:
        # join() places ``sep`` only *between* texts. The previous ``+=`` loop
        # left a trailing separator on every non-empty cell, which then
        # showed up in the HTML and Excel output.
        cell['cell_text'] = sep.join(span['text'] for span in cell['spans'] if 'text' in span)
    return cells
def cells_to_csv(cells):
    """Convert cells into a DataFrame plus its CSV text.

    The cell texts are laid out on a (rows x columns) grid; spanning cells
    fill every position they cover. All rows up to the last row that contains
    a ``'header'`` cell are collapsed into one label per column by joining
    that column's distinct texts, top to bottom, with ``' | '``. Grid
    positions no cell covers are skipped when building labels. The remaining
    rows become the data.

    Args:
        cells: Cell dicts with ``'row_nums'``, ``'column_nums'``, ``'header'``
            and ``'cell_text'``.

    Returns:
        ``(df, csv_text)``. For an empty ``cells`` list this is
        ``(pd.DataFrame(), '')``. It used to be ``None``, which broke callers
        that unpack the pair.
    """
    if not cells:
        return pd.DataFrame(), ''
    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1
    max_header_row = max((max(cell['row_nums']) for cell in cells if cell['header']), default=-1)

    table_array = np.empty([num_rows, num_columns], dtype='object')
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    flattened_header = []
    for col in table_array[:max_header_row + 1, :].T:
        # Uncovered grid positions are None; ' | '.join() raised TypeError on them.
        texts = [text for text in col if text is not None]
        flattened_header.append(' | '.join(OrderedDict.fromkeys(texts)))
    df = pd.DataFrame(table_array[max_header_row + 1:, :], columns=flattened_header)
    return df, df.to_csv(index=False)
def cells_to_html(cells):
    """Render cells as an HTML ``<table>`` string.

    Cells are emitted in row-major order (by first row, then first column).
    A new ``<tr>`` starts whenever a cell's first row exceeds the previous
    one. The tag for that whole row (``th`` or ``td``) is chosen from the
    ``'header'`` flag of the row's first cell. Cells covering several
    columns/rows get ``colspan``/``rowspan`` attributes. Cell text is
    escaped by ElementTree.
    """
    ordered = sorted(cells, key=lambda c: (min(c['row_nums']), min(c['column_nums'])))
    table = ET.Element('table')
    current_row = -1
    row = None
    cell_tag = None
    for cell in ordered:
        first_row = min(cell['row_nums'])
        n_cols = len(cell['column_nums'])
        n_rows = len(cell['row_nums'])
        attrib = {}
        if n_cols > 1:
            attrib['colspan'] = str(n_cols)
        if n_rows > 1:
            attrib['rowspan'] = str(n_rows)
        if first_row > current_row:
            current_row = first_row
            cell_tag = 'th' if cell['header'] else 'td'
            row = ET.SubElement(table, 'tr')
        ET.SubElement(row, cell_tag, attrib=attrib).text = cell['cell_text']
    return ET.tostring(table, encoding='unicode', short_empty_elements=False)
# def cells_to_html(cells):
# for cell in cells:
# cell['column_nums'].sort()
# cell['row_nums'].sort()
# n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
# n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
# html_code = ''
# for r in range(n_rows):
# r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
# r_cells.sort(key=lambda x: x['column_nums'][0])
# r_html = ''
# for cell in r_cells:
# rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
# colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
# r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>'
# html_code += f'<tr>{r_html}</tr>'
# html_code = '''<html>
# <head>
# <meta charset='UTF-8'>
# <style>
# table, th, td {
# border: 1px solid black;
# font-size: 10px;
# }
# </style>
# </head>
# <body>
# <table frame='hsides' rules='groups' width='100%%'>
# %s
# </table>
# </body>
# </html>''' % html_code
# soup = bs(html_code)
# html_code = soup.prettify()
# return html_code
def cells_to_excel(cells, file_path):
    """Write cells to an ``.xlsx`` workbook with a single worksheet named 'Table'.

    Each cell's text is written at its first (row, column). Cells spanning
    several rows or columns become a merged range. All cells use a
    centred/vertically-centred format.

    Args:
        cells: Cell dicts with ``'row_nums'``, ``'column_nums'`` and
            ``'cell_text'``.
        file_path: Destination path of the workbook.
    """
    ordered = sorted(cells, key=lambda c: (min(c['row_nums']), min(c['column_nums'])))
    # The context manager closes (and thereby writes) the workbook even if a
    # write below raises; previously the workbook was left unclosed then.
    with xlsxwriter.Workbook(file_path) as workbook:
        cell_format = workbook.add_format({'align': 'center', 'valign': 'vcenter'})
        worksheet = workbook.add_worksheet(name='Table')
        for cell in ordered:
            first_row, last_row = min(cell['row_nums']), max(cell['row_nums'])
            first_col, last_col = min(cell['column_nums']), max(cell['column_nums'])
            text = cell['cell_text']
            if (first_row, first_col) != (last_row, last_col):
                # Numeric (row, col) arguments replace the hand-built A1 column
                # letters, which were wrong beyond column 'ZZ'.
                worksheet.merge_range(first_row, first_col, last_row, last_col, '', cell_format)
            # write_string keeps OCR text literal. Plain write()/merge_range()
            # would turn text starting with '=' into a formula.
            if text:
                worksheet.write_string(first_row, first_col, text, cell_format)
            else:
                worksheet.write_blank(first_row, first_col, None, cell_format)
def _render_structure_tab(crop_images):
    """Show the OCR, structure and cell visualisations for each table crop; return each table's cells."""
    st.header('Table Structure Recognition')
    str_cols = st.columns(4)
    str_cols[0].subheader('Table image')
    str_cols[1].subheader('OCR result')
    str_cols[2].subheader('Structure result')
    str_cols[3].subheader('Cells result')
    all_cells = []
    for img in crop_images:
        str_cols = st.columns(4)
        str_cols[0].image(visualize_image(img))
        ocr_result = ocr(img)
        str_cols[1].image(visualize_ocr(img, ocr_result))
        structure_result = table_structure(img)
        str_cols[2].image(visualize_structure(img, structure_result))
        _table_structures, cells, _confidence = convert_stucture(ocr_result, img, structure_result)
        cells = extract_text_from_cells(cells)
        str_cols[3].image(visualize_cells(img, cells))
        all_cells.append(cells)
    return all_cells


def _render_extracted_tables(all_cells):
    """Show each table as a DataFrame with an Excel download button, then as HTML."""
    st.header('Extracted Table(s)')
    if not all_cells:
        # st.columns(0) raises, so report the empty case instead of crashing.
        st.write('No tables were detected in the image.')
        return
    for idx, col in enumerate(st.columns(len(all_cells))):
        with col:
            if len(all_cells) > 1:
                st.header(f'Table {idx + 1}')
            with TemporaryDirectory() as temp_dir_path:
                xlsx_path = os.path.join(temp_dir_path, f'debug_{idx}.xlsx')
                cells_to_excel(all_cells[idx], xlsx_path)
                with open(xlsx_path, 'rb') as ref:
                    xlsx_bytes = ref.read()
            # Read the bytes once and reuse them. Previously one file handle
            # was consumed by read_excel and then handed to download_button.
            st.dataframe(pd.read_excel(io.BytesIO(xlsx_bytes)))
            st.download_button(
                'Download Excel File',
                xlsx_bytes,
                file_name=f'output_{idx}.xlsx',
            )
    for idx, cells in enumerate(all_cells):
        html_result = cells_to_html(cells)
        st.subheader(f'HTML Table {idx + 1}')
        st.markdown(html_result, unsafe_allow_html=True)


def main():
    """Streamlit entry point for the table-extraction demo.

    After an image is uploaded and 'Analyze image' is pressed, tables are
    detected on the page and cropped. Each crop is OCR'd and run through
    structure recognition. Results are shown in three tabs:
    - the page with detected tables;
    - per-table OCR, structure and cell visualisations;
    - the extracted tables as DataFrames with Excel downloads, plus HTML.
    """
    st.set_page_config(layout='wide')
    st.title('Table Extraction Demo')
    filename = st.file_uploader('Upload image', type=['png', 'jpeg', 'jpg'])
    if not st.button('Analyze image'):
        return
    if filename is None:
        st.write('Please upload an image')
        return
    tabs = st.tabs(
        ['Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
    )
    pil_img = PIL.Image.open(filename)
    detection_result = table_detection(pil_img)
    crop_images, vis_det_img = crop_image(pil_img, detection_result)
    with tabs[0]:
        st.header('Table Detection')
        st.image(vis_det_img)
    with tabs[1]:
        all_cells = _render_structure_tab(crop_images)
    with tabs[2]:
        _render_extracted_tables(all_cells)
if __name__ == '__main__':
    # Launch the app only when this file is executed directly, not on import.
    main()