"""Streamlit demo: table detection (TD) and table structure recognition (TSR).

For every image in the samples directory the app detects tables, recognises
their row/column structure with the project-local ``TDTSR`` module, OCRs each
cell (after super-resolution) and displays the reconstructed table as a
DataFrame.
"""
import functools
import os
import re

import cv2
import matplotlib.pyplot as plt  # noqa: F401  (kept for ad-hoc debugging plots)
import numpy as np
import pandas as pd
import pytesseract
import streamlit as st
from cv2 import dnn_superres
from PIL import Image
from pytesseract import Output
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

import postprocess as pp  # noqa: F401
import TDTSR

# Default Windows install location; override with the TESSERACT_CMD env var.
pytesseract.pytesseract.tesseract_cmd = os.environ.get(
    'TESSERACT_CMD', r'C:\Program Files\Tesseract-OCR\tesseract.exe'
)

# Directory scanned for sample images; override with the TD_SAMPLES_DIR env var.
TD_SAMPLES_DIR = os.environ.get(
    'TD_SAMPLES_DIR', 'D:/Jupyter/Multi-Type-TD-TSR/TD_samples/'
)
# Super-resolution weights; file name must look like '<Algo>_x<scale>.pb'.
SUPER_RES_MODEL_PATH = "./LapSRN_x8.pb"
TROCR_MODEL_NAME = "SalML/trocr-base-printed"

# set_page_config must be the first Streamlit command executed by the script.
st.set_page_config(layout='wide')
st.set_option('deprecation.showPyplotGlobalUse', False)
st.title("Table Detection and Table Structure Recognition")
c1, c2, c3 = st.columns((1, 1, 1))


def PIL_to_cv(pil_img):
    """Convert an RGB PIL image to an OpenCV BGR ndarray."""
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    """Convert a 3-channel OpenCV BGR ndarray to an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def pytess(cell_pil_img):
    """OCR a cell image with Tesseract; return its words joined by single spaces."""
    data = pytesseract.image_to_data(
        cell_pil_img,
        output_type=Output.DICT,
        # '-c VAR=VALUE' sets a Tesseract variable; a bare word would instead be
        # looked up as a config *file* name.
        config='-c preserve_interword_spaces=1',
    )
    # image_to_data reports '' for page/block/paragraph/line entries; drop them
    # so the joined text has no runs of blank separators.
    return ' '.join(word for word in data['text'] if word.strip()).strip()


@functools.lru_cache(maxsize=None)
def _load_trocr(model_name):
    """Load the TrOCR processor and model once per script run."""
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    return processor, model


def TrOCR(cell_pil_img, model_name=TROCR_MODEL_NAME):
    """OCR a cell image with a TrOCR model and return the decoded text.

    Args:
        cell_pil_img: cell crop as a PIL image.
        model_name: Hugging Face model id; defaults to the printed-text model.
    """
    processor, model = _load_trocr(model_name)
    pixel_values = processor(images=cell_pil_img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _parse_sr_model_path(model_path):
    """Derive (algorithm name, integer scale) from e.g. './LapSRN_x8.pb'.

    Raises:
        ValueError: if the file name does not end in '_x<digits>'.
    """
    stem = os.path.splitext(os.path.basename(model_path))[0]
    name, _, scale = stem.partition('_')
    match = re.fullmatch(r'x(\d+)', scale)
    if match is None:
        raise ValueError(
            'cannot parse scale from super-resolution model path: ' + model_path
        )
    return name.lower(), int(match.group(1))


@functools.lru_cache(maxsize=None)
def _load_super_res(model_path):
    """Create and load a dnn_superres model once per script run (not per cell)."""
    # requires opencv-contrib-python installed without the opencv-python
    sr = dnn_superres.DnnSuperResImpl_create()
    model_name, model_scale = _parse_sr_model_path(model_path)
    sr.readModel(model_path)
    sr.setModel(model_name, model_scale)
    return sr


def super_res(pil_img, model_path=SUPER_RES_MODEL_PATH):
    """Upscale a PIL image with the OpenCV super-resolution model at model_path."""
    upscaled = _load_super_res(model_path).upsample(PIL_to_cv(pil_img))
    return cv_to_PIL(upscaled)


def sharpen_image(pil_img):
    """Sharpen a PIL image with a 3x3 high-boost kernel (weights sum to 1).

    A milder alternative kernel is [[0, -1, 0], [-1, 5, -1], [0, -1, 0]].
    """
    img = PIL_to_cv(pil_img)
    sharpen_kernel = np.array([[-1, -1, -1],
                               [-1, 9, -1],
                               [-1, -1, -1]])
    sharpen = cv2.filter2D(img, -1, sharpen_kernel)
    return cv_to_PIL(sharpen)


def preprocess_magic(pil_img):
    """Binarise an image with Otsu so it ends up as dark text on a white background.

    If black pixels outnumber white ones after thresholding, the image is
    inverted. The result is returned as an RGB PIL image.
    """
    cv_img = PIL_to_cv(pil_img)
    grayscale_image = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
    _, binary_image = cv2.threshold(grayscale_image, 0, 255, cv2.THRESH_OTSU)
    count_white = np.sum(binary_image > 0)
    count_black = np.sum(binary_image == 0)
    if count_black > count_white:
        binary_image = 255 - binary_image
    # cv_to_PIL expects 3 channels (BGR2RGB fails on a single-channel image),
    # so expand the binary image to BGR first.
    return cv_to_PIL(cv2.cvtColor(binary_image, cv2.COLOR_GRAY2BGR))


def ocr_cell(cell_pil_img):
    """Super-resolve a cell crop and OCR it with Tesseract."""
    return pytess(super_res(cell_pil_img))


def _unique_headers(names):
    """Return names with duplicates suffixed ('_1', '_2', ...) so each is unique."""
    seen = set()
    unique = []
    for name in names:
        candidate, suffix = name, 1
        while candidate in seen:
            candidate = name + '_' + str(suffix)
            suffix += 1
        seen.add(candidate)
        unique.append(candidate)
    return unique


def cells_to_dataframe(cells_img, n_rows, n_cols):
    """OCR the cell images of one table into a DataFrame.

    Args:
        cells_img: mapping (as returned by TDTSR.object_to_cellsv2) whose values
            are per-row sequences of cell PIL images; the first row is taken as
            the header.
        n_rows: number of rows, header row included.
        n_cols: number of columns to fill per data row.

    Returns:
        DataFrame with n_rows - 1 rows; empty header cells are named
        'empty_col<idx>' and duplicate header texts get a numeric suffix.
    """
    headers = []
    cells_list = []
    for n, row_images in enumerate(cells_img.values()):
        if n == 0:
            for idx, header_img in enumerate(row_images):
                headers.append(ocr_cell(header_img) or 'empty_col' + str(idx))
        else:
            cells_list.extend(ocr_cell(cell) for cell in row_images)

    # The header row becomes the column names, so only n_rows - 1 data rows remain.
    n_data_rows = max(n_rows - 1, 0)
    df = pd.DataFrame("", index=range(n_data_rows), columns=_unique_headers(headers))
    for r in range(n_data_rows):
        for c in range(n_cols):
            df.iat[r, c] = cells_list[r * n_cols + c]
    return df


def process_table(td_sample, unpadded_table):
    """Recognise the structure of one cropped table and show its OCR'd contents."""
    table = TDTSR.add_margin(unpadded_table)
    model, image, probas, bboxes_scaled = TDTSR.table_struct_recog(
        table, THRESHOLD_PROBA=0.6
    )
    # Plot the header row and simple rows, then order and crop them.
    try:
        rows, cols = TDTSR.plot_structure(
            c3, model, image, probas, bboxes_scaled, class_to_show=0
        )
        rows, cols = TDTSR.sort_table_featuresv2(rows, cols)
        # rows/cols: per the original notes, ordered dicts whose value tuples
        # carry the cropped PIL image as 5th element (defined in TDTSR).
        rows, cols = TDTSR.individual_table_featuresv2(table, rows, cols)
    except Exception as printableException:
        st.write(td_sample, ' terminated with exception:', printableException)
        # rows/cols are undefined (or stale from a previous table) at this
        # point, so building a DataFrame from them would crash or mislead.
        return

    master_row = rows
    cells_img = TDTSR.object_to_cellsv2(master_row, cols)
    c3.dataframe(cells_to_dataframe(cells_img, len(master_row), len(cols)))


def main():
    """Run detection, structure recognition and OCR on every sample image."""
    for td_sample in sorted(os.listdir(TD_SAMPLES_DIR)):
        sample_path = os.path.join(TD_SAMPLES_DIR, td_sample)
        if not os.path.isfile(sample_path):
            continue
        try:
            # Context manager closes the file handle; convert() returns a copy.
            with Image.open(sample_path) as raw_img:
                image = raw_img.convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError
            st.write(td_sample, ' could not be opened as an image:', err)
            continue

        model, image, probas, bboxes_scaled = TDTSR.table_detector(
            image, THRESHOLD_PROBA=0.6
        )
        TDTSR.plot_results_detection(c1, model, image, probas, bboxes_scaled)
        cropped_img_list = TDTSR.plot_table_detection(
            c2, model, image, probas, bboxes_scaled
        )
        for unpadded_table in cropped_img_list:
            process_table(td_sample, unpadded_table)


if __name__ == "__main__":  # `streamlit run` executes the script as __main__
    main()