"""
Copyright (C) 2023 Microsoft Corporation

Script to process, edit, filter, and canonicalize SciTSR to align it with PubTables-1M.

We still need to verify that this script works correctly.

If you use this code in your published work, we request that you cite our papers
and table-transformer GitHub repo.
"""
import json
import os
from collections import defaultdict
import traceback
from difflib import SequenceMatcher
import argparse

import fitz
from fitz import Rect
from PIL import Image
import xml.etree.ElementTree as ET
from xml.dom import minidom
import editdistance
import numpy as np
from tqdm import tqdm


def table_to_text(table_dict):
    """Return the stripped text of every cell in ``table_dict['cells']`` joined by single spaces."""
    return ' '.join(cell['text_content'].strip() for cell in table_dict['cells'])


def align(page_string="", xml_string="",
          page_character_rewards=None, xml_character_rewards=None,
          match_reward=2, space_match_reward=3, lowercase_match_reward=2,
          mismatch_penalty=-5,
          page_new_gap_penalty=-2, xml_new_gap_penalty=-5,
          page_continue_gap_penalty=-0.01, xml_continue_gap_penalty=-0.1,
          page_boundary_gap_reward=0.01,
          gap_not_after_space_penalty=-1,
          score_only=False, gap_character='_'):
    '''
    Dynamic programming sequence alignment between two text strings; the
    first text string is considered to come from the PDF document; the
    second text string is considered to come from the XML document.

    Traceback convention: -1 = up, 1 = left, 0 = diag up-left

    Rows of the score table index characters of ``page_string`` and columns
    index characters of ``xml_string``.

    Parameters:
        page_string, xml_string: the two strings to align.
        page_character_rewards, xml_character_rewards: optional per-character
            multipliers (indexed like the corresponding string) applied to the
            reward of a matching character pair.
        match_reward, space_match_reward, lowercase_match_reward: reward for an
            exact match, a match on a space, and a match that only holds after
            lower-casing, respectively.
        mismatch_penalty: added when the two characters differ.
        page_new_gap_penalty, page_continue_gap_penalty: cost of opening /
            extending a "left" move (an XML character paired with a gap).
        xml_new_gap_penalty, xml_continue_gap_penalty: cost of opening /
            extending an "up" move (a page character paired with a gap).
        page_boundary_gap_reward: added for each "up" move taken in the first
            or the last column, i.e. page text before/after all of the XML text.
        gap_not_after_space_penalty: extra cost for a newly opened gap whose
            consumed character is not a space.
        score_only: if True, return only the final score.
        gap_character: character written into the aligned strings for a gap.

    Returns:
        ``score`` when ``score_only`` is True, otherwise
        ``([aligned_page_string, aligned_xml_string], score)``.
    '''
    scores = np.zeros((len(page_string) + 1, len(xml_string) + 1))
    pointers = np.zeros((len(page_string) + 1, len(xml_string) + 1))

    # Initialize first column
    for row_idx in range(1, len(page_string) + 1):
        scores[row_idx, 0] = scores[row_idx - 1, 0] + page_boundary_gap_reward
        pointers[row_idx, 0] = -1

    # Initialize first row
    for col_idx in range(1, len(xml_string) + 1):
        #scores[0, col_idx] = scores[0, col_idx - 1] + 0
        pointers[0, col_idx] = 1

    for row_idx in range(1, len(page_string) + 1):
        for col_idx in range(1, len(xml_string) + 1):
            # Score if matching the characters
            if page_string[row_idx - 1].lower() == xml_string[col_idx - 1].lower():
                if page_string[row_idx - 1] == ' ':
                    reward = space_match_reward
                elif page_string[row_idx - 1] == xml_string[col_idx - 1]:
                    reward = match_reward
                else:
                    reward = lowercase_match_reward
                if page_character_rewards is not None:
                    reward *= page_character_rewards[row_idx - 1]
                if xml_character_rewards is not None:
                    reward *= xml_character_rewards[col_idx - 1]
                diag_score = scores[row_idx - 1, col_idx - 1] + reward
            else:
                diag_score = scores[row_idx - 1, col_idx - 1] + mismatch_penalty

            # Score of a "left" move (consume an XML character only)
            if pointers[row_idx, col_idx - 1] == 1:
                same_row_score = scores[row_idx, col_idx - 1] + page_continue_gap_penalty
            else:
                same_row_score = scores[row_idx, col_idx - 1] + page_new_gap_penalty
                # NOTE(review): applied to newly opened gaps only -- confirm nesting
                if not xml_string[col_idx - 1] == ' ':
                    same_row_score += gap_not_after_space_penalty

            # Score of an "up" move (consume a page character only)
            if col_idx == len(xml_string):
                same_col_score = scores[row_idx - 1, col_idx] + page_boundary_gap_reward
            elif pointers[row_idx - 1, col_idx] == -1:
                same_col_score = scores[row_idx - 1, col_idx] + xml_continue_gap_penalty
            else:
                same_col_score = scores[row_idx - 1, col_idx] + xml_new_gap_penalty
                # NOTE(review): applied to newly opened gaps only -- confirm nesting
                if not page_string[row_idx - 1] == ' ':
                    same_col_score += gap_not_after_space_penalty

            # Ties are broken in the order: diagonal, up, left
            max_score = max(diag_score, same_col_score, same_row_score)
            scores[row_idx, col_idx] = max_score
            if diag_score == max_score:
                pointers[row_idx, col_idx] = 0
            elif same_col_score == max_score:
                pointers[row_idx, col_idx] = -1
            else:
                pointers[row_idx, col_idx] = 1

    score = scores[len(page_string), len(xml_string)]

    if score_only:
        return score

    # Trace back from the bottom-right corner; characters are collected in
    # reverse order and flipped once at the end.
    cur_row = len(page_string)
    cur_col = len(xml_string)
    aligned_page_chars = []
    aligned_xml_chars = []
    while not (cur_row == 0 and cur_col == 0):
        if pointers[cur_row, cur_col] == -1:
            cur_row -= 1
            aligned_xml_chars.append(gap_character)
            aligned_page_chars.append(page_string[cur_row])
        elif pointers[cur_row, cur_col] == 1:
            cur_col -= 1
            aligned_page_chars.append(gap_character)
            aligned_xml_chars.append(xml_string[cur_col])
        else:
            cur_row -= 1
            cur_col -= 1
            aligned_xml_chars.append(xml_string[cur_col])
            aligned_page_chars.append(page_string[cur_row])

    aligned_page_string = ''.join(reversed(aligned_page_chars))
    aligned_xml_string = ''.join(reversed(aligned_xml_chars))

    alignment = [aligned_page_string, aligned_xml_string]

    return alignment, score


def locate_table(page_words, table):
    """
    Locate the words of ``table`` among ``page_words`` by aligning their text.

    ``page_words`` is indexed as ``word[0:4]`` for the bounding box and
    ``word[4]`` for the text. ``table['cells']`` entries must carry
    ``'text_content'``.

    Returns:
        ``(cell_bboxes, inliers)`` where ``cell_bboxes`` maps each cell index
        to the union bounding box of its matched words (or None when no word
        matched) and ``inliers`` lists the matched page words in page order.
        Returns ``(None, None)`` when no page word matched.
    """
    #sorted_words = sorted(words, key=functools.cmp_to_key(compare_meta))
    sorted_words = page_words

    page_text = " ".join([word[4] for word in sorted_words])

    # For every character of page_text, the index of the word it came from
    # (None for the joining spaces).
    page_text_source = []
    for num, word in enumerate(sorted_words):
        page_text_source.extend([num] * len(word[4]))
        page_text_source.append(None)
    page_text_source = page_text_source[:-1]

    table_text = table_to_text(table)
    # Same mapping for the table text: character position -> cell index.
    table_text_source = []
    for num, cell in enumerate(table['cells']):
        table_text_source.extend([num] * len(cell['text_content'].strip()))
        table_text_source.append(None)
    table_text_source = table_text_source[:-1]

    # "~" is used as the gap character below, so it must not occur in the inputs.
    X = page_text.replace("~", "^")
    Y = table_text.replace("~", "^")

    match_reward = 3
    mismatch_penalty = -2
    #new_gap_penalty = -10
    #continue_gap_penalty = -0.05  (assigned in the original but never passed to align)
    page_boundary_gap_reward = 0.2

    alignment, score = align(X, Y, match_reward=match_reward,
                             mismatch_penalty=mismatch_penalty,
                             page_boundary_gap_reward=page_boundary_gap_reward,
                             score_only=False,
                             gap_character='~')

    table_words = set()
    cell_words = dict()
    page_count = 0
    table_count = 0

    for char1, char2 in zip(alignment[0], alignment[1]):
        if not char1 == "~":
            if char1 == char2:
                table_words.add(page_text_source[page_count])
                cell_num = table_text_source[table_count]
                if cell_num is not None:
                    cell_words.setdefault(cell_num, set()).add(page_text_source[page_count])
            page_count += 1
        if not char2 == "~":
            table_count += 1

    # Word index 0 is a valid match, so test against None rather than
    # truthiness; sort so the result does not depend on set ordering.
    inliers = [sorted_words[word_num]
               for word_num in sorted(num for num in table_words if num is not None)]

    if len(inliers) == 0:
        return None, None

    cell_bboxes = {}
    for cell_num, cell in enumerate(table['cells']):
        cell_bbox = None
        if cell_num in cell_words:
            for word_num in cell_words[cell_num]:
                if word_num is not None:
                    word_bbox = sorted_words[word_num][0:4]
                    if not cell_bbox:
                        cell_bbox = list(word_bbox)
                    else:
                        cell_bbox[0] = min(cell_bbox[0], word_bbox[0])
                        cell_bbox[1] = min(cell_bbox[1], word_bbox[1])
                        cell_bbox[2] = max(cell_bbox[2], word_bbox[2])
                        cell_bbox[3] = max(cell_bbox[3], word_bbox[3])
        cell_bboxes[cell_num] = cell_bbox

    return cell_bboxes, inliers


def string_similarity(string1, string2):
    """Return the ``difflib.SequenceMatcher`` ratio of the two strings."""
    return SequenceMatcher(None, string1, string2).ratio()


# My current theory is that this is the correct code but that some examples are simply wrong
# (for example, the bolded text is aligned correctly but not the normal text)
def adjust_bbox_coordinates(data, doc):
    """
    Change cell bbox coordinates, in place, to be relative to the PyMuPDF
    page.rect coordinate space of the first page of ``doc``.

    Cells are read from ``data['html']['cells']`` when ``data`` has an
    ``'html'`` key and from ``data['cells']`` otherwise. Cells without a
    ``'bbox'`` entry are left untouched.
    """
    media_box = doc[0].mediabox
    mat = doc[0].transformation_matrix
    cells = data['html']['cells'] if 'html' in data else data['cells']
    for cell in cells:
        if 'bbox' not in cell:
            continue
        bbox = list(Rect(cell['bbox']) * mat)
        bbox = [bbox[0] + media_box[0],
                bbox[1] - media_box[1],
                bbox[2] + media_box[0],
                bbox[3] - media_box[1]]
        cell['bbox'] = bbox


def create_document_page_image(doc, page_num, zoom=None, output_image_max_dim=1000):
    """
    Render page ``page_num`` of ``doc`` to an RGB PIL image.

    When ``zoom`` is None it is set to ``output_image_max_dim`` divided by the
    largest value in ``page.rect``.
    """
    page = doc[page_num]

    if zoom is None:
        zoom = output_image_max_dim / max(page.rect)

    mat = fitz.Matrix(zoom, zoom)
    pix = page.get_pixmap(matrix = mat, alpha = False)
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    return img


class AnnotationMismatchException(Exception):
    pass

class HTMLParseOverlappingGridCellsException(Exception):
    pass

class HTMLParseMissingGridCellsException(Exception):
    pass

class SmallTableException(Exception):
    pass

class AmbiguousHeaderException(Exception):
    pass

class OversizedHeaderException(Exception):
    pass

class RowColumnOverlapException(Exception):
    pass

# Average edit distance > 0.05
class TextAnnotationQualityException(Exception):
    pass

class UndeterminedRowBoundaryException(Exception):
    pass

class UndeterminedColumnBoundaryException(Exception):
    pass

# For cases where the data contains a cent symbol in its own column/cell, etc.
class OversegmentedColumnsException(Exception):
    pass

# For cases where the iterative cell text bounding box adjustment doesn't quickly converge
class RunawayTextAdjustmentException(Exception):
    pass

# For cases where the same grid cell is assigned to multiple spanning cells
class AmbiguousSpanningCellException(Exception):
    pass

class RowsIntersectException(Exception):
    pass

class ColumnsIntersectException(Exception):
    pass

class DotsInCellTextBboxException(Exception):
    pass

class PoorTextCellFitException(Exception):
    pass

# Use for interrupting after a specific event occurs for debugging
class DebugException(Exception):
    pass

# Specific to this dataset; cells at the top of the table can be incorrectly merged
# Merged cells at the top of the table are not inherently bad, but for this dataset we need to catch these
class OvermergedCellsException(Exception):
    pass

# Headers that are incomplete and stopped at a projected row header
class IncompleteHeaderException(Exception):
    pass

class NoTableBodyException(Exception):
    pass

class DotsRetainedException(Exception):
    pass

class BadProjectedRowHeaderException(Exception):
    pass

class MultipleColumnHeadersException(Exception):
    pass


def create_table_dict(annotation_data):
    """
    Build the working table dictionary from one annotation.

    Each entry of ``annotation_data['cells']`` must provide ``'content'`` (an
    iterable of strings, joined with spaces) and inclusive ``'start_row'``,
    ``'end_row'``, ``'start_col'``, ``'end_col'`` indices.

    Returns a dict with keys ``'cells'``, ``'rows'``, ``'columns'``,
    ``'reject'`` and ``'fix'``. ``'reject'`` receives
    "HTML overlapping grid cells" when two cells claim the same grid location.
    """
    table_dict = {}
    table_dict['reject'] = []
    table_dict['fix'] = []

    cells = []
    for cell in annotation_data['cells']:
        new_cell = {}
        new_cell['text_content'] = ' '.join(cell['content']).strip()
        new_cell['pdf_text_tight_bbox'] = []
        new_cell['column_nums'] = list(range(cell['start_col'], cell['end_col']+1))
        new_cell['row_nums'] = list(range(cell['start_row'], cell['end_row']+1))
        new_cell['is_column_header'] = False
        cells.append(new_cell)

    # Make sure no grid locations are duplicated
    # Could be bad data or bad parsing algorithm
    grid_cell_locations = []
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                grid_cell_locations.append((row_num, column_num))
    if not len(grid_cell_locations) == len(set(grid_cell_locations)):
        table_dict['reject'].append("HTML overlapping grid cells")

    num_rows = max([max(cell['row_nums']) for cell in cells]) + 1
    num_columns = max([max(cell['column_nums']) for cell in cells]) + 1

    table_dict['cells'] = cells
    table_dict['rows'] = {row_num: {'is_column_header': False} for row_num in range(num_rows)}
    table_dict['columns'] = {column_num: {} for column_num in range(num_columns)}

    return table_dict


def complete_table_grid(table_dict):
    """
    Derive table, row, column and grid-cell bounding boxes, in place, from the
    cells' ``'pdf_text_tight_bbox'`` values.

    Writes ``table_dict['pdf_table_bbox']``, ``'pdf_row_bbox'`` on every row,
    ``'pdf_column_bbox'`` on every column and ``'pdf_bbox'`` on every cell.
    Appends to ``table_dict['reject']`` once for every row/column bbox
    coordinate that could not be determined and for every pair of consecutive
    rows that overlap too much.

    NOTE(review): relies on an ``overlap(bbox1, bbox2)`` helper that is not
    defined in this part of the file -- presumably defined further down.
    """
    rects_by_row = defaultdict(lambda: [None, None, None, None])
    rects_by_column = defaultdict(lambda: [None, None, None, None])
    table_rect = Rect()

    # Determine bounding box for rows and columns
    for cell in table_dict['cells']:
        if not 'pdf_text_tight_bbox' in cell or len(cell['pdf_text_tight_bbox']) == 0:
            continue

        bbox = cell['pdf_text_tight_bbox']

        table_rect.include_rect(list(bbox))

        # The top of a cell bounds its first row; the bottom bounds its last row
        min_row = min(cell['row_nums'])
        if rects_by_row[min_row][1] is None:
            rects_by_row[min_row][1] = bbox[1]
        else:
            rects_by_row[min_row][1] = min(rects_by_row[min_row][1], bbox[1])

        max_row = max(cell['row_nums'])
        if rects_by_row[max_row][3] is None:
            rects_by_row[max_row][3] = bbox[3]
        else:
            rects_by_row[max_row][3] = max(rects_by_row[max_row][3], bbox[3])

        # Likewise left/right edges bound the first/last column
        min_column = min(cell['column_nums'])
        if rects_by_column[min_column][0] is None:
            rects_by_column[min_column][0] = bbox[0]
        else:
            rects_by_column[min_column][0] = min(rects_by_column[min_column][0], bbox[0])

        max_column = max(cell['column_nums'])
        if rects_by_column[max_column][2] is None:
            rects_by_column[max_column][2] = bbox[2]
        else:
            rects_by_column[max_column][2] = max(rects_by_column[max_column][2], bbox[2])

    table_bbox = list(table_rect)
    table_dict['pdf_table_bbox'] = table_bbox

    # Rows span the full table width; columns span the full table height
    for row_num, row_rect in rects_by_row.items():
        row_rect[0] = table_bbox[0]
        row_rect[2] = table_bbox[2]

    for col_num, col_rect in rects_by_column.items():
        col_rect[1] = table_bbox[1]
        col_rect[3] = table_bbox[3]

    # Rows/columns that received no text get the default all-None rectangle here
    for k, row in table_dict['rows'].items():
        v = rects_by_row[k]
        table_dict['rows'][k]['pdf_row_bbox'] = list(v)

    for k, column in table_dict['columns'].items():
        v = rects_by_column[k]
        table_dict['columns'][k]['pdf_column_bbox'] = list(v)

    for k, row in table_dict['rows'].items():
        for elem in row['pdf_row_bbox']:
            if elem is None:
                table_dict['reject'].append("undetermined row boundary")

    for k, column in table_dict['columns'].items():
        for elem in column['pdf_column_bbox']:
            if elem is None:
                table_dict['reject'].append("undetermined column boundary")

    # Adjust bounding boxes if minor overlap
    num_rows = len(table_dict['rows'])
    for row_num in range(num_rows-1):
        row1_bbox = table_dict['rows'][row_num]['pdf_row_bbox']
        row2_bbox = table_dict['rows'][row_num+1]['pdf_row_bbox']
        overlap1 = overlap(row1_bbox, row2_bbox)
        overlap2 = overlap(row2_bbox, row1_bbox)
        if overlap1 > 0 and overlap2 > 0:
            if overlap1 < 0.18 and overlap2 < 0.18:
                # Split the difference between the two rows
                midpoint = 0.5 * (row1_bbox[3] + row2_bbox[1])
                table_dict['rows'][row_num]['pdf_row_bbox'][3] = midpoint
                table_dict['rows'][row_num+1]['pdf_row_bbox'][1] = midpoint
            else:
                table_dict['reject'].append("rows intersect")

    # Intersect each row and column to determine grid cell bounding boxes
    #page_words = page.get_text_words()
    for cell in table_dict['cells']:
        rows_rect = Rect()
        cols_rect = Rect()
        for row_num in cell['row_nums']:
            rows_rect.include_rect(table_dict['rows'][row_num]['pdf_row_bbox'])
        for col_num in cell['column_nums']:
            cols_rect.include_rect(table_dict['columns'][col_num]['pdf_column_bbox'])
        pdf_bbox = rows_rect.intersect(cols_rect)
        cell['pdf_bbox'] = list(pdf_bbox)


def identify_projected_row_headers(table_dict):
    """
    Return the set of row numbers that look like projected row headers.

    A row qualifies when it is not a column header row, exactly one cell with
    text touches it, every cell touching it occupies a single row, and a cell
    with text in column 0 touches it.
    """
    cells_with_text_count_by_row = defaultdict(int)
    all_cells_in_row_only_in_one_row_by_row = defaultdict(lambda: True)
    has_first_column_cell_with_text_by_row = defaultdict(bool)
    for cell in table_dict['cells']:
        if len(cell['text_content']) > 0:
            in_first_column = 0 in cell['column_nums']
            for row_num in cell['row_nums']:
                cells_with_text_count_by_row[row_num] += 1
                # Set inside the loop so the flag never depends on a loop
                # variable leaking out of (or missing from) this iteration.
                if in_first_column:
                    has_first_column_cell_with_text_by_row[row_num] = True

        one_row_only = len(cell['row_nums']) == 1
        for row_num in cell['row_nums']:
            all_cells_in_row_only_in_one_row_by_row[row_num] = (
                all_cells_in_row_only_in_one_row_by_row[row_num] and one_row_only)

    projected_row_header_rows = set()
    for row_num, row in table_dict['rows'].items():
        if (not row['is_column_header']
                and cells_with_text_count_by_row[row_num] == 1
                and all_cells_in_row_only_in_one_row_by_row[row_num]
                and has_first_column_cell_with_text_by_row[row_num]):
            projected_row_header_rows.add(row_num)

    return projected_row_header_rows


def annotate_projected_row_headers(table_dict):
    """
    Mark projected row headers on cells and rows, in place.

    In a projected row header row the cell with text is widened to span all
    columns and blank cells are removed. Projected row header rows at the very
    bottom of the table are deleted together with their cells. Every change is
    logged in ``table_dict['fix']``.
    """
    num_cols = len(table_dict['columns'])
    projected_row_header_rows = identify_projected_row_headers(table_dict)

    cells_to_remove = []
    for cell in table_dict['cells']:
        if len(set(cell['row_nums']).intersection(projected_row_header_rows)) > 0:
            if len(cell['text_content']) > 0:
                cell['column_nums'] = list(range(num_cols))
                cell['is_projected_row_header'] = True
            else:
                # Consolidate blank cells after the first cell into the projected row header
                cells_to_remove.append(cell)
        else:
            cell['is_projected_row_header'] = False

    for cell in cells_to_remove:
        table_dict['fix'].append('merged projected row header')
        table_dict['cells'].remove(cell)

    for row_num, row in table_dict['rows'].items():
        row['is_projected_row_header'] = row_num in projected_row_header_rows

    # Delete projected row headers in last rows
    num_rows = len(table_dict['rows'])
    row_nums_to_delete = []
    for row_num in range(num_rows-1, -1, -1):
        if table_dict['rows'][row_num]['is_projected_row_header']:
            row_nums_to_delete.append(row_num)
        else:
            break

    for row_num in row_nums_to_delete:
        del table_dict['rows'][row_num]
        table_dict['fix'].append('removed projected row header at bottom of table')
        # Iterate over a copy because cells are removed while scanning
        for cell in table_dict['cells'][:]:
            if row_num in cell['row_nums']:
                table_dict['cells'].remove(cell)


def merge_group(table_dict, group):
    """
    Merge every cell of ``group`` into the cell with the smallest first row.

    The surviving cell receives the union of row and column numbers (sorted),
    the space-joined text, and the union of the text bounding boxes; the other
    cells are removed from ``table_dict['cells']``. If a cell to remove is no
    longer present, "ambiguous spanning cell" is appended to
    ``table_dict['reject']``.

    Returns ``table_dict``.
    """
    if len(group) <= 1:
        return table_dict

    group = sorted(group, key=lambda k: min(k['row_nums']))
    cell = group[0]
    cells_to_delete = []

    # A missing or empty tight bbox cannot be turned into a Rect
    try:
        cell_text_rect = Rect(cell['pdf_text_tight_bbox'])
    except Exception:
        cell_text_rect = Rect()

    for cell2 in group[1:]:
        cell['row_nums'] = sorted(set(cell['row_nums'] + cell2['row_nums']))
        cell['column_nums'] = sorted(set(cell['column_nums'] + cell2['column_nums']))
        cell['text_content'] = (cell['text_content'].strip() + " " + cell2['text_content'].strip()).strip()
        try:
            cell2_text_rect = Rect(cell2['pdf_text_tight_bbox'])
        except Exception:
            cell2_text_rect = Rect()
        cell_text_rect = cell_text_rect.include_rect(list(cell2_text_rect))
        if cell_text_rect.get_area() == 0:
            cell['pdf_text_tight_bbox'] = []
        else:
            cell['pdf_text_tight_bbox'] = list(cell_text_rect)
        cell['is_projected_row_header'] = False
        cells_to_delete.append(cell2)

    try:
        for merged_cell in cells_to_delete:
            table_dict['cells'].remove(merged_cell)
            table_dict['fix'].append('merged oversegmented spanning cell')
    except ValueError:
        # list.remove found no such cell: it was already claimed elsewhere
        table_dict['reject'].append("ambiguous spanning cell")
        #raise AmbiguousSpanningCellException

    return table_dict


def remove_empty_rows(table_dict):
    """
    Drop rows in which no cell has text and renumber the remaining rows, in place.

    Cells left without any row are removed; each removal appends
    'removed empty row' to ``table_dict['fix']``.
    """
    num_rows = len(table_dict['rows'])
    has_content_by_row = defaultdict(bool)
    for cell in table_dict['cells']:
        has_content = len(cell['text_content'].strip()) > 0
        for row_num in cell['row_nums']:
            has_content_by_row[row_num] = has_content_by_row[row_num] or has_content
    # row_num_corrections[r] = number of empty rows at or before row r
    row_num_corrections = np.cumsum([int(not has_content_by_row[row_num]) for row_num in range(num_rows)]).tolist()

    # Delete cells in empty rows and renumber other cells
    cells_to_delete = []
    for cell in table_dict['cells']:
        new_row_nums = []
        for row_num in cell['row_nums']:
            if has_content_by_row[row_num]:
                new_row_nums.append(row_num - row_num_corrections[row_num])
        cell['row_nums'] = new_row_nums
        if len(new_row_nums) == 0:
            cells_to_delete.append(cell)
    for cell in cells_to_delete:
        table_dict['fix'].append('removed empty row')
        table_dict['cells'].remove(cell)

    rows = {}
    for row_num, has_content in has_content_by_row.items():
        if has_content:
            new_row_num = row_num - row_num_corrections[row_num]
            rows[new_row_num] = table_dict['rows'][row_num]
    table_dict['rows'] = rows


def merge_rows(table_dict):
    """
    Merge consecutive rows that are spanned together in every column, in place.

    Two consecutive rows are merged when the cells spanning both of them cover
    all columns. Each merge appends a message to ``table_dict['fix']``.
    """
    num_rows = len(table_dict['rows'])
    num_columns = len(table_dict['columns'])
    # co_occurrence_matrix[r1, r2] = number of columns in which one cell covers both rows
    co_occurrence_matrix = np.zeros((num_rows, num_rows))
    for cell in table_dict['cells']:
        for row_num1 in cell['row_nums']:
            for row_num2 in cell['row_nums']:
                if row_num1 >= row_num2:
                    continue
                co_occurrence_matrix[row_num1, row_num2] += len(cell['column_nums'])

    new_row_num = 0
    keep_row = [True]
    row_grouping = [0]
    for row_num in range(num_rows-1):
        if not co_occurrence_matrix[row_num, row_num+1] == num_columns:
            keep_row.append(True)
            new_row_num += 1
        else:
            table_dict['fix'].append('merged rows spanned together in every column')
            keep_row.append(False)
        row_grouping.append(new_row_num)

    for cell in table_dict['cells']:
        cell['row_nums'] = [row_grouping[row_num] for row_num in cell['row_nums'] if keep_row[row_num]]

    table_dict['rows'] = {row_grouping[row_num]: table_dict['rows'][row_num]
                          for row_num in range(num_rows) if keep_row[row_num]}


def remove_empty_columns(table_dict):
    """
    Drop columns in which no cell has text and renumber the remaining columns, in place.

    Cells left without any column are removed; each removal appends
    'removed empty column' to ``table_dict['fix']``.
    """
    num_columns = len(table_dict['columns'])
    has_content_by_column = defaultdict(bool)
    for cell in table_dict['cells']:
        has_content = len(cell['text_content'].strip()) > 0
        for column_num in cell['column_nums']:
            has_content_by_column[column_num] = has_content_by_column[column_num] or has_content
    # column_num_corrections[c] = number of empty columns at or before column c
    column_num_corrections = np.cumsum([int(not has_content_by_column[column_num]) for column_num in range(num_columns)]).tolist()

    # Delete cells in empty columns and renumber other cells
    cells_to_delete = []
    for cell in table_dict['cells']:
        new_column_nums = []
        for column_num in cell['column_nums']:
            if has_content_by_column[column_num]:
                new_column_nums.append(column_num - column_num_corrections[column_num])
        cell['column_nums'] = new_column_nums
        if len(new_column_nums) == 0:
            cells_to_delete.append(cell)
    for cell in cells_to_delete:
        table_dict['fix'].append('removed empty column')
        table_dict['cells'].remove(cell)

    columns = {}
    for column_num, has_content in has_content_by_column.items():
        if has_content:
            new_column_num = column_num - column_num_corrections[column_num]
            columns[new_column_num] = table_dict['columns'][column_num]
    table_dict['columns'] = columns


def merge_columns(table_dict):
    """
    Merge consecutive columns that are spanned together in every row, in place.

    Two consecutive columns are merged when the cells spanning both of them
    cover all rows. Each merge appends a message to ``table_dict['fix']``.
    """
    num_rows = len(table_dict['rows'])
    num_columns = len(table_dict['columns'])
    # co_occurrence_matrix[c1, c2] = number of rows in which one cell covers both columns
    co_occurrence_matrix = np.zeros((num_columns, num_columns))
    for cell in table_dict['cells']:
        for column_num1 in cell['column_nums']:
            for column_num2 in cell['column_nums']:
                if column_num1 >= column_num2:
                    continue
                co_occurrence_matrix[column_num1, column_num2] += len(cell['row_nums'])

    new_column_num = 0
    keep_column = [True]
    column_grouping = [0]
    for column_num in range(num_columns-1):
        if not co_occurrence_matrix[column_num, column_num+1] == num_rows:
            keep_column.append(True)
            new_column_num += 1
        else:
            table_dict['fix'].append('merged columns spanned together in every row')
            keep_column.append(False)
        column_grouping.append(new_column_num)

    for cell in table_dict['cells']:
        cell['column_nums'] = [column_grouping[column_num] for column_num in cell['column_nums'] if keep_column[column_num]]

    table_dict['columns'] = {column_grouping[column_num]: table_dict['columns'][column_num]
                             for column_num in range(num_columns) if keep_column[column_num]}


# Look for tables with blank cells to merge in the first column
def merge_spanning_cells_in_first_column(table_dict):
    """
    Merge blank first-column cells into the filled first-column cell above them.

    First-column cells are walked in order of their last row; a new group is
    started at every cell that has text or touches an excluded row, and each
    group of more than one cell is merged with ``merge_group`` unless its
    first cell is a projected row header or a column header.

    Excluded rows are: trailing blank first-column rows (only when there are
    no other blank first-column rows above them), non-header rows whose full
    text equals that of a header row, and rows of non-header cells whose text
    equals the non-blank text of some header cell.
    """
    first_column_cells = [cell for cell in table_dict['cells'] if 0 in cell['column_nums']]
    first_column_cells = sorted(first_column_cells, key=lambda item: max(item['row_nums']))

    first_column_merge_exclude = set()

    # Look for blank cells at bottom of first column
    text_by_row_num = {}
    for cell in table_dict['cells']:
        if 0 in cell['column_nums']:
            for row_num in cell['row_nums']:
                if not cell['is_column_header']:
                    text_by_row_num[row_num] = cell['text_content'].strip()
                else:
                    # Header rows always count as filled
                    text_by_row_num[row_num] = "_"
    bottom_blank_rows = set()
    still_bottom = True
    add_bottom_rows = True
    for row_num in sorted(table_dict['rows'].keys(), reverse=True):
        if len(text_by_row_num[row_num]) > 0:
            still_bottom = False
        elif still_bottom:
            bottom_blank_rows.add(row_num)
        else:
            # A blank row above a filled row: blanks are not confined to the bottom
            add_bottom_rows = False
            break
    if add_bottom_rows:
        first_column_merge_exclude |= bottom_blank_rows

    # Look for tables with multiple headers
    num_rows = len(table_dict['rows'])
    num_columns = len(table_dict['columns'])
    cell_grid = np.zeros((num_rows, num_columns)).astype('str').tolist()
    for cell in table_dict['cells']:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                cell_grid[row_num][column_num] = cell['text_content']
    for row_num1 in range(num_rows-1):
        row1 = table_dict['rows'][row_num1]
        if not row1['is_column_header']:
            continue
        for row_num2 in range(row_num1+1, num_rows):
            row2 = table_dict['rows'][row_num2]
            if row2['is_column_header']:
                continue
            if cell_grid[row_num1] == cell_grid[row_num2]:
                first_column_merge_exclude.add(row_num2)

    # Exclude rows of body cells that repeat the text of a header cell
    header_texts = {cell['text_content'] for cell in table_dict['cells']
                    if cell['is_column_header'] and len(cell['text_content'].strip()) > 0}
    for cell in table_dict['cells']:
        if not cell['is_column_header'] and cell['text_content'] in header_texts:
            first_column_merge_exclude.update(cell['row_nums'])

    groups = defaultdict(list)
    group_num = -1
    for cell in first_column_cells:
        if len(set(cell['row_nums']).intersection(first_column_merge_exclude)) > 0:
            group_num += 1
        elif len(cell['text_content']) > 0:
            group_num += 1
        if group_num >= 0:
            groups[group_num].append(cell)

    for group_num, group in groups.items():
        if len(group) > 1 and not group[0]['is_projected_row_header'] and not group[0]['is_column_header']:
            merge_group(table_dict, group)


# STANDARDS:
# 1. Column header, if it exists, is a tree structure. FinTabNet contains no header annotation so we can only
#    infer the header given some assumptions. If the top row does not contain all leaf nodes, complete the tree down to the
#    leaf nodes.
# 2. There should be no blank cells in the column header. Blank cells should be aggregated into supercells where possible.
#    - First, blank supercells should be split into blank grid cells.
#    - If a column header cell has only blank grid cells directly below it, extend the cell downward to consume
#      any rows of entirely blank cells directly below it.
#    - After doing this for all column header cells, if a column header cell has only blank cells above it, consume
#      any rows of entirely blank cells directly above it.
#    - Blank supercells that occur after this grouping are arguably rightly annotated as supercells, but we will not
#      annotate these as supercells at the moment for consistency with previous ms_datasets that annotated these as blank
#      cells only. Detecting a blank supercell would not impact the structure inferred for the table.
#    - Any remaining blank cells are ambiguous, and while it is not good table design to have these, they're not likely
#      to be a nuisance.
# 3. There should be no blank cells in the row header. This is trickier because the row header is not explicitly
#    annotated and must be inferred. See below for more on determining which columns are in the row header.
#    - For columns in the row header, blank cells should be aggregated under the first cell that does not span the
#      entire row. This assumes "top" vertical alignment for text. "Middle" vertical alignment is normally already
#      associated with supercells and is already explicit.
# 4. Inferring the row header.
#    - The row header is explicit whenever the first N columns do not have a column header. In other words, when
#      the stub header is blank. Otherwise it is implicit which columns, if any, correspond to the column header.
#    - If a row header exists, it is also a tree just like the column header and must end at a column of leaf nodes.
#      Not only does this mean supercells cannot be the final column of a row header, but repeated values in a column
#      mean that the column cannot be the final column of a row header.
#    - A column that is not part of the row header (possibly the first column) can have repeated values. Having repeated
#      values is an indication of row header continuation but not of row header status to begin with.
#    - In most cases, numeric values are data. If the numeric values are integer and sorted, this may be part of the row
#      header.
#    - Rows where only one cell has content, either left justified or centered across the table, are part of an implicit
#      first column that begins a row header. The stub header belongs in this first column if there is not a row cell.
# 5. A row cell at the top of the table is either the title of the table (if there are no other row cells in the table),
#    or part of the row header if there are additional row cells below, and belongs in an implicit column.
# 6. Tables have at least one row and two columns. A table with only one column is a list.
def correct_header(table_dict, assume_header_if_more_than_two_columns=True):
    """Infer whether the table has a column header and mark its full extent.

    Works in place on ``table_dict``: every cell and row at or above the
    inferred last header row gets ``is_column_header = True``. Problems are
    recorded by appending a reason to ``table_dict['reject']`` ("small table",
    "bad projected row header", "ambiguous header") rather than by raising.

    Args:
        table_dict: Table with 'rows' (dict keyed by row number), 'columns',
            'cells' (dicts with 'row_nums', 'column_nums', 'text_content',
            'is_column_header') and a 'reject' list.
        assume_header_if_more_than_two_columns: When True, every table with
            more than two columns is treated as having a column header.

    Raises:
        ValueError: From ``max()`` when a header is assumed but no cell
            occupies exactly one column, or no cell has any text.
    """
    num_columns = len(table_dict['columns'])
    num_rows = len(table_dict['rows'])

    if num_columns < 2 or num_rows < 1:
        # Recorded as a rejection; processing deliberately continues.
        table_dict['reject'].append("small table")

    # --- Determine whether there is definitely a column header (four cases) ---
    table_has_column_header = False

    # 1. Caller wants a header assumed for all tables with more than two columns.
    if assume_header_if_more_than_two_columns and num_columns > 2:
        table_has_column_header = True

    # 2. The annotation already marks row 0 as a header row.
    if not table_has_column_header:
        header_rows = [row_num for row_num, row in table_dict['rows'].items()
                       if row['is_column_header']]
        if 0 in header_rows:
            table_has_column_header = True

    # 3. The cell occupying the first row and first column is blank.
    if not table_has_column_header:
        for cell in table_dict['cells']:
            if (0 in cell['column_nums'] and 0 in cell['row_nums']
                    and len(cell['text_content'].strip()) == 0):
                table_has_column_header = True
                break

    # 4. There is a horizontally spanning cell in the first row.
    if not table_has_column_header:
        for cell in table_dict['cells']:
            if 0 in cell['row_nums'] and len(cell['column_nums']) > 1:
                table_has_column_header = True
                break

    # --- Extend the header past its annotated extent where required ---
    #  1. The header continues at least until each column has a cell occupying
    #     only that column.
    #  2. A header column with a blank cell continues at least as long as the
    #     blank cells continue.
    # (A "flatten until column headers are unique" rule is deliberately not
    #  used: columns are allowed to share the same header text.)
    if table_has_column_header:
        # Rows in which the first column holds non-empty text.
        first_column_filled_by_row = defaultdict(bool)
        for cell in table_dict['cells']:
            if 0 in cell['column_nums'] and len(cell['text_content']) > 0:
                for row_num in cell['row_nums']:
                    first_column_filled_by_row[row_num] = True

        # Per column: earliest bottom row of a cell that spans only that column.
        first_single_node_row_by_column = defaultdict(lambda: len(table_dict['rows']) - 1)
        for cell in table_dict['cells']:
            if len(cell['column_nums']) == 1:
                column_num = cell['column_nums'][0]
                first_single_node_row_by_column[column_num] = min(
                    first_single_node_row_by_column[column_num], max(cell['row_nums']))

        # Same, restricted to cells whose text is not blank.
        first_filled_single_node_row_by_column = defaultdict(lambda: len(table_dict['rows']) - 1)
        for cell in table_dict['cells']:
            if len(cell['column_nums']) == 1 and len(cell['text_content'].strip()) > 0:
                column_num = cell['column_nums'][0]
                first_filled_single_node_row_by_column[column_num] = min(
                    first_filled_single_node_row_by_column[column_num], max(cell['row_nums']))

        # Per column: first row in which any cell covering the column has text.
        first_filled_cell_by_column = defaultdict(lambda: len(table_dict['rows']) - 1)
        for cell in table_dict['cells']:
            if len(cell['text_content']) > 0:
                min_row_num = min(cell['row_nums'])
                for column_num in cell['column_nums']:
                    first_filled_cell_by_column[column_num] = min(
                        first_filled_cell_by_column[column_num], min_row_num)

        projected_row_header_rows = identify_projected_row_headers(table_dict)
        if 0 in projected_row_header_rows:
            table_dict['reject'].append("bad projected row header")

        # Header must continue until at least this row.
        minimum_grid_cell_single_node_row = max(first_single_node_row_by_column.values())

        # Header can stop prior to the first of these rows that occurs after the above row.
        minimum_first_body_row = min(num_rows - 1, max(first_filled_cell_by_column.values()))

        # Max row for which a column N has been single and filled but a later column has not.
        minimum_all_following_filled = -1
        for row_num in range(num_rows):
            for column_num1 in range(num_columns - 1):
                for column_num2 in range(column_num1 + 1, num_columns):
                    if (first_filled_single_node_row_by_column[column_num2] > row_num
                            and first_filled_single_node_row_by_column[column_num1]
                            < first_filled_single_node_row_by_column[column_num2]):
                        minimum_all_following_filled = row_num + 1

        if len(projected_row_header_rows) > 0:
            minimum_projected_row_header_row = min(projected_row_header_rows)
        else:
            minimum_projected_row_header_row = num_rows

        first_possible_last_header_row = minimum_first_body_row - 1

        last_header_row = max(minimum_all_following_filled,
                              minimum_grid_cell_single_node_row,
                              first_possible_last_header_row)

        # Absorb following rows whose first column is empty. The bound keeps the
        # probe inside the table; the clamp below yields the same final value
        # as probing one row past the end did.
        while (last_header_row + 1 < num_rows
               and not first_column_filled_by_row[last_header_row + 1]):
            last_header_row += 1

        # The header never reaches into the first projected row header.
        if minimum_projected_row_header_row <= last_header_row:
            last_header_row = minimum_projected_row_header_row - 1

        for cell in table_dict['cells']:
            if max(cell['row_nums']) <= last_header_row:
                cell['is_column_header'] = True

        for row_num, row in table_dict['rows'].items():
            if row_num <= last_header_row:
                row['is_column_header'] = True

    if not table_has_column_header and num_columns == 2:
        # Two columns and no header evidence: cannot decide unambiguously.
        table_dict['reject'].append("ambiguous header")
def _index_cells_by_grid_position(cells):
    """Map every (row_num, column_num) position to the cell occupying it."""
    cell_grid_index = {}
    for cell in cells:
        for column_num in cell['column_nums']:
            for row_num in cell['row_nums']:
                cell_grid_index[(row_num, column_num)] = cell
    return cell_grid_index


def canonicalize(table_dict):
    """Canonicalize the column header and first column of ``table_dict`` in place.

    Steps, in order:
      1. Split every blank spanning cell in the column header into blank grid cells.
      2. Bottom-up: extend non-blank header cells upward over contiguous blank rows.
      3. Top-down: extend non-blank header cells downward over contiguous blank
         header rows.
      4. Top-down: merge vertically adjacent header cells that occupy exactly the
         same columns, blank or not.
      5. Merge spanning cells in the first column via
         ``merge_spanning_cells_in_first_column``.

    Merging is delegated to ``merge_group(table_dict, group)``.

    Raises:
        KeyError: If a grid position inside the header has no cell.
    """
    # Step 1. Iterate over a snapshot because new grid cells are appended to
    # the live list while splitting.
    cells_to_delete = []
    try:
        for cell in list(table_dict['cells']):
            if (cell['is_column_header'] and len(cell['text_content'].strip()) == 0
                    and (len(cell['column_nums']) > 1 or len(cell['row_nums']) > 1)):
                cells_to_delete.append(cell)
                for column_num in cell['column_nums']:
                    for row_num in cell['row_nums']:
                        new_cell = {'text_content': '',
                                    'column_nums': [column_num],
                                    'row_nums': [row_num],
                                    'is_column_header': cell['is_column_header'],
                                    'pdf_text_tight_bbox': [],
                                    'is_projected_row_header': False}
                        table_dict['cells'].append(new_cell)
    except Exception:
        # Best effort: report and continue. Not a bare ``except`` so that
        # KeyboardInterrupt still reaches the caller.
        print(traceback.format_exc())
    for cell in cells_to_delete:
        table_dict['cells'].remove(cell)

    # Step 2. Go bottom up, try to extend non-blank cells up to absorb blank cells.
    cell_grid_index = _index_cells_by_grid_position(table_dict['cells'])
    header_groups = []
    for cell in table_dict['cells']:
        if not cell['is_column_header'] or len(cell['text_content']) == 0:
            continue
        header_group = [cell]
        next_row_num = min(cell['row_nums']) - 1
        for row_num in range(next_row_num, -1, -1):
            all_are_blank = True
            for column_num in cell['column_nums']:
                cell2 = cell_grid_index[(row_num, column_num)]
                all_are_blank = all_are_blank and len(cell2['text_content']) == 0
            if all_are_blank:
                for column_num in cell['column_nums']:
                    header_group.append(cell_grid_index[(row_num, column_num)])
            else:
                break  # Stop looking; must be contiguous
        if len(header_group) > 1:
            header_groups.append(header_group)
    for group in header_groups:
        merge_group(table_dict, group)

    # Step 3. Go top down, try to extend non-blank cells down to absorb blank cells.
    cell_grid_index = _index_cells_by_grid_position(table_dict['cells'])
    num_rows = len(table_dict['rows'])
    header_groups = []
    for cell in table_dict['cells']:
        if not cell['is_column_header'] or len(cell['text_content']) == 0:
            continue
        header_group = [cell]
        next_row_num = max(cell['row_nums']) + 1
        for row_num in range(next_row_num, num_rows):
            if not table_dict['rows'][row_num]['is_column_header']:
                break
            all_are_blank = True
            for column_num in cell['column_nums']:
                cell2 = cell_grid_index[(row_num, column_num)]
                all_are_blank = all_are_blank and len(cell2['text_content']) == 0
            if all_are_blank:
                for column_num in cell['column_nums']:
                    header_group.append(cell_grid_index[(row_num, column_num)])
            else:
                break  # Stop looking; must be contiguous
        if len(header_group) > 1:
            header_groups.append(header_group)
    for group in header_groups:
        merge_group(table_dict, group)

    # Step 4. Go top down, merge any neighboring cells occupying the same
    # columns, whether they are blank or not.
    cell_grid_index = _index_cells_by_grid_position(table_dict['cells'])
    header_groups_by_row_column = defaultdict(list)
    header_groups = []
    do_full_break = False
    for row_num in table_dict['rows']:
        for column_num in table_dict['columns']:
            cell = cell_grid_index[(row_num, column_num)]
            if not cell['is_column_header']:
                # First non-header cell ends the scan of the whole grid.
                do_full_break = True
                break
            if len(header_groups_by_row_column[(row_num, column_num)]) > 0:
                continue  # position already belongs to a group
            # NOTE(review): this is the original operator precedence,
            # "(row is not the cell's first row) and (column is the cell's
            # first column)". Confirm it is what was intended.
            if row_num != min(cell['row_nums']) and column_num == min(cell['column_nums']):
                continue
            # Start new header group
            header_group = [cell]
            next_row_num = max(cell['row_nums']) + 1
            while next_row_num < num_rows:
                cell2 = cell_grid_index[(next_row_num, column_num)]
                if (cell2['is_column_header']
                        and set(cell['column_nums']) == set(cell2['column_nums'])):
                    header_group.append(cell2)
                    for row_num2 in cell2['row_nums']:
                        for column_num2 in cell2['column_nums']:
                            header_groups_by_row_column[(row_num2, column_num2)] = header_group
                else:
                    break
                next_row_num = max(cell2['row_nums']) + 1
            for row_num2 in cell['row_nums']:
                for column_num2 in cell['column_nums']:
                    header_groups_by_row_column[(row_num2, column_num2)] = header_group
            if len(header_group) > 1:
                header_groups.append(header_group)
        if do_full_break:
            break
    for group in header_groups:
        merge_group(table_dict, group)

    # Step 5. Merge spanning cells in the row header.
    merge_spanning_cells_in_first_column(table_dict)


def is_all_dots(text):
    """Return True when ``text`` is non-empty and consists only of '.' characters."""
    return len(text) > 0 and len(text.replace('.', '')) == 0
def _fraction_of_word_in_bbox(word_bbox, bbox):
    """Return the share of ``word_bbox``'s area that lies inside ``bbox``.

    A word box with zero area yields 0 instead of dividing by zero.
    """
    word_area = Rect(word_bbox).get_area()
    if word_area == 0:
        return 0
    return Rect(word_bbox).intersect(list(bbox)).get_area() / word_area


def extract_pdf_text(table_dict, page_words, threshold=0.5):
    """Assign page words to cells and refresh each cell's text and tight bbox.

    For every cell, the words whose area lies more than ``threshold`` inside
    the cell's 'pdf_bbox' are collected; trailing all-dot words (leaders) are
    dropped when the collected text ends in "..". The cell receives
    'pdf_text_content' and, when the words cover a non-zero area, an updated
    'pdf_text_tight_bbox'.

    Args:
        table_dict: Table with 'cells', 'fix' and 'reject' lists.
        page_words: Sequences whose first four items are a bbox and whose
            fifth item (index 4) is the word text.
        threshold: Minimum fraction of a word's area inside the cell bbox.

    Returns:
        True if any cell's 'pdf_text_tight_bbox' changed, else False.
    """
    adjusted_text_tight_bbox = False
    for cell in table_dict['cells']:
        pdf_text_tight_bbox = cell['pdf_text_tight_bbox']
        pdf_bbox = cell['pdf_bbox']
        cell_page_words = [w for w in page_words
                           if _fraction_of_word_in_bbox(w[:4], pdf_bbox) > threshold]
        cell_text = ''.join(w[4] for w in cell_page_words)

        # Remove trailing dots from cell_page_words. Some of the original
        # annotations include dots in the pdf_text_tight_bbox when they
        # shouldn't; this also avoids adding dots by extracting text from
        # the entire grid cell.
        if len(cell_text) > 2 and cell_text.endswith('..'):
            while cell_page_words and is_all_dots(cell_page_words[-1][4]):
                table_dict['fix'].append('removed dots from text cell')
                cell_page_words.pop()

        cell_words_rect = Rect()
        for w in cell_page_words:
            cell_words_rect.include_rect(w[:4])

        cell_text = ' '.join(w[4] for w in cell_page_words)
        cell_text = cell_text.replace(' .', '.').replace(' ,', ',')
        if cell_text.endswith('..'):
            table_dict['reject'].append("dots retained")
        cell['pdf_text_content'] = cell_text

        if cell_words_rect.get_area() > 0:
            new_pdf_text_tight_bbox = list(cell_words_rect)
            if not pdf_text_tight_bbox == new_pdf_text_tight_bbox:
                adjusted_text_tight_bbox = True
            cell['pdf_text_tight_bbox'] = new_pdf_text_tight_bbox

    return adjusted_text_tight_bbox


def overlap(bbox1, bbox2):
    """Return the fraction of ``bbox1``'s area covered by ``bbox2``.

    Any failure (for example a zero-area ``bbox1``) yields 1, as before.
    Exceptions that do not derive from ``Exception`` (KeyboardInterrupt)
    now propagate.
    """
    try:
        return Rect(bbox1).intersect(list(bbox2)).get_area() / Rect(bbox1).get_area()
    except Exception:
        return 1


def table_text_edit_distance(cells):
    """Return the mean normalized edit distance between annotated and PDF text.

    Whitespace and leading/trailing periods are ignored. Each cell contributes
    ``editdistance / max(len)``; cells where both texts are empty contribute 0.
    Returns 0 for an empty cell list.
    """
    if len(cells) == 0:
        return 0

    D = 0
    for cell in cells:
        # Remove spaces and surrounding periods
        xml_text = ''.join(cell['text_content'].split()).strip('.')
        pdf_text = ''.join(cell['pdf_text_content'].split()).strip('.')
        L = max(len(xml_text), len(pdf_text))
        if L > 0:
            D += editdistance.eval(xml_text, pdf_text) / L

    return D / len(cells)


def quality_control(table_dict, page_words):
    """Append rejection reasons to ``table_dict['reject']`` for low-quality tables.

    Checks, in order: adjacent rows overlapping by more than 1 unit ("rows
    intersect"), adjacent columns likewise ("columns intersect"), mean text
    edit distance above 0.05 ("text annotation quality"), and mean best
    word-to-cell overlap below 0.9 ("poor text cell fit"). The last check is
    skipped when no word lies inside the table, since there is nothing to fit.
    """
    rows = table_dict['rows']
    for row_num, row in rows.items():
        next_row = rows.get(row_num + 1)
        if next_row is not None and row['pdf_row_bbox'][3] > next_row['pdf_row_bbox'][1] + 1:
            table_dict['reject'].append("rows intersect")

    columns = table_dict['columns']
    for column_num, column in columns.items():
        next_column = columns.get(column_num + 1)
        if (next_column is not None
                and column['pdf_column_bbox'][2] > next_column['pdf_column_bbox'][0] + 1):
            table_dict['reject'].append("columns intersect")

    D = table_text_edit_distance(table_dict['cells'])
    if D > 0.05:
        table_dict['reject'].append("text annotation quality")

    word_overlaps = []
    table_bbox = table_dict['pdf_table_bbox']
    for w in page_words:
        if w[4] == '.':
            continue
        if overlap(w[:4], table_bbox) < 0.5:
            continue
        # default=0: with no cells at all, a word inside the table fits nothing.
        word_overlaps.append(max((overlap(w[:4], cell['pdf_bbox'])
                                  for cell in table_dict['cells']), default=0))

    if word_overlaps:
        C = sum(word_overlaps) / len(word_overlaps)
        if C < 0.9:
            table_dict['reject'].append("poor text cell fit")


def remove_html_tags_in_text(table_dict):
    """Replace inline HTML formatting tags in each cell's text with a space.

    Updates ``cell['text_content']`` in place for every cell, then collapses
    runs of spaces and strips the ends.

    NOTE(review): the tag literals were empty strings in the reviewed source,
    and ``str.replace("", " ")`` inserts a space between every character.
    The tag set below (i, b, sup, sub) is an assumption; confirm against the
    upstream script.
    """
    import re

    tag_pattern = re.compile(r'</?(?:i|b|sup|sub)>')
    for cell in table_dict['cells']:
        text = tag_pattern.sub(' ', cell['text_content'])
        text = re.sub(r' {2,}', ' ', text)
        cell['text_content'] = text.strip()


def is_good_bbox(bbox, page_bbox):
    """Return True when ``bbox`` is complete, inside the page, and not degenerate.

    Requires no coordinate to be None, a non-negative top-left corner, a
    bottom-right corner within ``page_bbox[2]``/``page_bbox[3]``, and width and
    height strictly greater than 1.
    """
    if any(bbox[i] is None for i in range(4)):
        return False
    return (bbox[0] >= 0 and bbox[1] >= 0
            and bbox[2] <= page_bbox[2] and bbox[3] <= page_bbox[3]
            and bbox[0] < bbox[2] - 1 and bbox[1] < bbox[3] - 1)


def create_document_page_image(doc, page_num, output_image_max_dim=1000):
    """Render page ``page_num`` of ``doc`` as an RGB PIL image.

    The zoom is chosen so the larger of ``page.rect[2]`` and ``page.rect[3]``
    maps to ``output_image_max_dim`` pixels.
    """
    page = doc[page_num]
    page_width = page.rect[2]
    page_height = page.rect[3]

    zoom = output_image_max_dim / max(page_width, page_height)
    mat = fitz.Matrix(zoom, zoom)
    pix = page.get_pixmap(matrix=mat, alpha=False)
    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    return img
def create_pascal_voc_page_element(image_filename, output_image_width, output_image_height, database):
    """Build the root PASCAL VOC ``<annotation>`` element for one page image.

    Children, in order: folder (empty), filename, path (both set to
    ``image_filename``), source/database, size/width, size/height,
    size/depth ("3"), segmented ("0").

    Returns:
        The ``xml.etree.ElementTree.Element`` named "annotation".
    """
    annotation = ET.Element("annotation")
    ET.SubElement(annotation, "folder").text = ""
    ET.SubElement(annotation, "filename").text = image_filename
    ET.SubElement(annotation, "path").text = image_filename
    source = ET.SubElement(annotation, "source")
    ET.SubElement(source, "database").text = database
    size = ET.SubElement(annotation, "size")
    ET.SubElement(size, "width").text = str(output_image_width)
    ET.SubElement(size, "height").text = str(output_image_height)
    ET.SubElement(size, "depth").text = "3"
    ET.SubElement(annotation, "segmented").text = "0"

    return annotation


def create_pascal_voc_object_element(class_name, bbox, page_bbox, output_image_max_dim=1000):
    """Build a PASCAL VOC ``<object>`` element for ``bbox`` in image coordinates.

    Args:
        class_name: Text of the ``<name>`` child.
        bbox: Box in PDF page coordinates; must have non-zero area and lie
            inside ``page_bbox`` (area difference of at most 0.1).
        page_bbox: Page box in the same coordinate space.
        output_image_max_dim: Size in pixels of the longer image side.

    Raises:
        ValueError: If ``bbox`` has zero area or extends outside ``page_bbox``.
    """
    bbox_area = fitz.Rect(bbox).get_area()
    if bbox_area == 0:
        raise ValueError("bbox {} has zero area".format(bbox))
    intersect_area = fitz.Rect(page_bbox).intersect(fitz.Rect(bbox)).get_area()
    if abs(intersect_area - bbox_area) > 0.1:
        raise ValueError(
            "bbox {} (area {}) is not fully inside page bbox {} (intersection area {})".format(
                bbox, bbox_area, page_bbox, intersect_area))

    object_ = ET.Element("object")
    ET.SubElement(object_, "name").text = class_name
    ET.SubElement(object_, "pose").text = "Frontal"
    ET.SubElement(object_, "truncated").text = "0"
    ET.SubElement(object_, "difficult").text = "0"
    ET.SubElement(object_, "occluded").text = "0"
    bndbox = ET.SubElement(object_, "bndbox")

    # Same PDF-to-image scaling as bbox_pdf_to_image, so both stay in sync.
    xmin, ymin, xmax, ymax = bbox_pdf_to_image(bbox, page_bbox,
                                               output_image_max_dim=output_image_max_dim)
    ET.SubElement(bndbox, "xmin").text = str(xmin)
    ET.SubElement(bndbox, "ymin").text = str(ymin)
    ET.SubElement(bndbox, "xmax").text = str(xmax)
    ET.SubElement(bndbox, "ymax").text = str(ymax)

    return object_


def save_xml_pascal_voc(page_annotation, filepath):
    """Pretty-print ``page_annotation`` and write it to ``filepath`` as UTF-8."""
    xmlstr = minidom.parseString(ET.tostring(page_annotation)).toprettyxml(indent=" ")
    # Explicit encoding: the pretty-printed string may hold non-ASCII text and
    # the platform default encoding is not guaranteed to represent it.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(xmlstr)


def bbox_pdf_to_image(bbox, page_bbox, output_image_max_dim=1000):
    """Convert ``bbox`` from PDF page coordinates to page-image pixel coordinates.

    The longer page side maps to ``output_image_max_dim`` pixels; the shorter
    side is scaled proportionally and truncated to an int before use.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` relative to the page's top-left corner.
    """
    page_width = page_bbox[2] - page_bbox[0]
    page_height = page_bbox[3] - page_bbox[1]

    if page_width > page_height:
        output_image_width = output_image_max_dim
        output_image_height = int(output_image_max_dim * page_height / page_width)
    else:
        output_image_height = output_image_max_dim
        output_image_width = int(output_image_max_dim * page_width / page_height)

    xmin = (bbox[0] - page_bbox[0]) * output_image_width / page_width
    ymin = (bbox[1] - page_bbox[1]) * output_image_height / page_height
    xmax = (bbox[2] - page_bbox[0]) * output_image_width / page_width
    ymax = (bbox[3] - page_bbox[1]) * output_image_height / page_height

    return [xmin, ymin, xmax, ymax]


def get_tokens_in_table_img(page_words, table_img_bbox):
    """Return the words whose ``utils.iob`` with ``table_img_bbox`` is at least 0.5.

    Each word dict in ``page_words`` is modified in place: 'flags', 'line_num'
    and 'block_num' are set to 0 and 'span_num' to the word's position.

    NOTE(review): ``utils`` is not among the imports visible at the top of
    this file; confirm it is imported before this function is called.
    """
    tokens = []
    for word_num, word in enumerate(page_words):
        word['flags'] = 0
        word['span_num'] = word_num
        word['line_num'] = 0
        word['block_num'] = 0
        tokens.append(word)

    tokens_in_table = [token for token in tokens
                       if utils.iob(token['bbox'], table_img_bbox) >= 0.5]

    return tokens_in_table


def get_args():
    """Parse and return the command-line arguments of the script."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir',
                        help="Root directory for source data to process")
    parser.add_argument('--output_dir',
                        help="Root directory for output data")
    parser.add_argument('--train_padding', type=int, default=30,
                        help="The amount of padding to add around a table in the training set when cropping.")
    parser.add_argument('--test_padding', type=int, default=5,
                        help="The amount of padding to add around a table in the val and test sets when cropping.")
    parser.add_argument('--skip_large', action='store_true')
    return parser.parse_args()
default=5, help="The amount of padding to add around a table in the val and test sets when cropping.") parser.add_argument('--skip_large', action='store_true') return parser.parse_args() def main(): args = get_args() data_directory = args.data_dir output_json_directory = os.path.join(args.output_dir, "SciTSR.c-PDF_Annotations_JSON") if not os.path.exists(output_json_directory): os.makedirs(output_json_directory) output_subdirs = ['images', 'train', 'test', 'val'] output_detection_directory = os.path.join(args.output_dir, "SciTSR.c-Image_Detection_PASCAL_VOC") if not os.path.exists(output_detection_directory): os.makedirs(output_detection_directory) for subdir in output_subdirs: subdirectory = os.path.join(output_detection_directory, subdir) if not os.path.exists(subdirectory): os.makedirs(subdirectory) output_page_words_directory = os.path.join(args.output_dir, "SciTSR.c-Image_Page_Words_JSON") if not os.path.exists(output_page_words_directory): os.makedirs(output_page_words_directory) output_structure_directory = os.path.join(args.output_dir, "SciTSR.c-Image_Structure_PASCAL_VOC") if not os.path.exists(output_structure_directory): os.makedirs(output_structure_directory) for subdir in output_subdirs: subdirectory = os.path.join(output_structure_directory, subdir) if not os.path.exists(subdirectory): os.makedirs(subdirectory) output_table_words_directory = os.path.join(args.output_dir, "SciTSR.c-Image_Table_Words_JSON") if not os.path.exists(output_table_words_directory): os.makedirs(output_table_words_directory) train_structure_files = os.listdir(os.path.join(data_directory, "train", "structure")) test_structure_files = os.listdir(os.path.join(data_directory, "test", "structure")) structure_filepaths = [os.path.join(data_directory, "train", "structure", elem) for elem in train_structure_files] structure_filepaths += [os.path.join(data_directory, "test", "structure", elem) for elem in test_structure_files] with open(os.path.join(data_directory, "train", "structure", 
structure_filepaths[1]), 'r') as infile: data = json.load(infile) splits_by_filepath = dict() test_filepaths = [os.path.join(data_directory, "test", "structure", elem) for elem in test_structure_files] train_filepaths = [os.path.join(data_directory, "train", "structure", elem) for elem in train_structure_files] for filepath in test_filepaths: splits_by_filepath[filepath] = 'test' n = len(train_filepaths) print(n) order = np.random.permutation(n) split_point = int(n * 0.875) for idx in order[:split_point]: splits_by_filepath[train_filepaths[idx]] = 'train' for idx in order[split_point:]: splits_by_filepath[train_filepaths[idx]] = 'val' processed_count = 0 good_count = 0 reject_count = 0 reject_reasons = defaultdict(list) fixes = defaultdict(list) kept_as_is_count = 0 output_image_max_dim = 1000 do_break = False for idx, structure_filepath in tqdm(enumerate(structure_filepaths)): split = splits_by_filepath[structure_filepath] try: with open(os.path.join(data_directory, "train", "structure", structure_filepath), 'r') as infile: data = json.load(infile) data['cells'] = sorted(sorted(data['cells'], key=lambda x: x['start_col']), key=lambda x: x['start_row']) table_dict = create_table_dict(data) img_filepath = structure_filepath.replace("structure", "img").replace(".json", ".png") pdf_filepath = structure_filepath.replace("structure", "pdf").replace(".json", ".pdf") doc = fitz.open(pdf_filepath) page = doc[0] page_words = doc[0].get_text_words() if split == 'val' or split == 'test': padding = args.test_padding else: padding = args.train_padding # For SciTSR, the table isn't always completely inside the PDF page page_rect = page.mediabox for word in page_words: bbox = word[:4] bbox = [bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[2]+padding] page_rect.include_rect(bbox) page.set_mediabox(page_rect) img = create_document_page_image(doc, 0, output_image_max_dim=1000) cell_bboxes, inliers = locate_table(page_words, table_dict) for cell, cell_bbox in 
zip(table_dict['cells'], cell_bboxes.values()): if cell_bbox is None: cell_bbox = [] cell['pdf_text_tight_bbox'] = cell_bbox #adjust_bbox_coordinates(table_dict, doc) tables = [table_dict] except: traceback.print_exc() continue document_tables = [] for table_index, table_dict in enumerate(tables): try: page_num = 0 page = doc[page_num] exclude_for_structure = False exclude_for_detection = False table_dict['exclude_for_structure'] = exclude_for_structure table_dict['exclude_for_detection'] = exclude_for_detection table_dict['split'] = split table_dict['pdf_file_name'] = pdf_filepath.split("/")[-1] table_dict['pdf_page_index'] = page_num table_dict['document_id'] = table_dict['pdf_file_name'].replace(".pdf", "") table_dict['source_file_name'] = structure_filepath.split("/")[-1] table_dict['pdf_full_page_bbox'] = list(page.rect) table_dict['document_table_index'] = table_index table_dict['structure_id'] = "{}_{}".format(table_dict['document_id'], table_dict['document_table_index']) merged = False debug = False remove_empty_columns(table_dict) merge_columns(table_dict) remove_empty_rows(table_dict) merge_rows(table_dict) include = [] exclude = [] annotate_projected_row_headers(table_dict) correct_header(table_dict, assume_header_if_more_than_two_columns=True) annotate_projected_row_headers(table_dict) # Look for tables with multiple headers num_rows = len(table_dict['rows']) num_columns = len(table_dict['columns']) cell_grid = np.zeros((num_rows, num_columns)).astype('str').tolist() for cell in table_dict['cells']: for row_num in cell['row_nums']: for column_num in cell['column_nums']: cell_grid[row_num][column_num] = cell['text_content'] for row_num1 in range(num_rows-1): row1 = table_dict['rows'][row_num1] if not row1['is_column_header']: continue for row_num2 in range(row_num1+1, num_rows): row2 = table_dict['rows'][row_num2] if row2['is_column_header']: continue if cell_grid[row_num1] == cell_grid[row_num2]: print('multiple column headers') 
#table_dict['reject'].append("multiple column headers") for cell1 in table_dict['cells']: for cell2 in table_dict['cells']: if cell1['is_column_header'] and not cell2['is_column_header']: if cell1['text_content'] == cell2['text_content'] and len(cell1['text_content'].strip()) > 0: #table_dict['reject'].append("multiple column headers") print('multiple column headers') first_column_merge_exclude = [] canonicalize(table_dict) remove_empty_columns(table_dict) merge_columns(table_dict) remove_empty_rows(table_dict) merge_rows(table_dict) for row_num, row in table_dict['rows'].items(): if row['is_column_header'] and row_num > 4: table_dict['reject'].append("oversized header") # Iterative process because a grid cell bounding box depends on surrounding text, which can # change the bounding box for the cell, which can change the text that falls in the bounding box, # which can change the bounding boxes for other cells, and so on... adjust_text = True iterations = 0 while(adjust_text and iterations < 3): #look_for_dots_in_text_tight_bbox(table_dict, page_words, threshold=0.5) complete_table_grid(table_dict) adjust_text = extract_pdf_text(table_dict, page_words) iterations += 1 if adjust_text: table_dict['reject'].append("runaway text adjustment") num_rows = len(table_dict['rows']) num_cells_in_last_row = 0 for cell in table_dict['cells']: if num_rows-1 in cell['row_nums']: num_cells_in_last_row += 1 # Do manual visual inspection for box-text fit quality_control(table_dict, page_words) has_body = False for row_num, row in table_dict['rows'].items(): if not row['is_column_header']: has_body = True break if not has_body: table_dict['reject'].append("no table body") #if table_dict['rows'][0]['is_projected_row_header']: # table_dict['reject'].append("bad projected row header") num_rows = len(table_dict['rows']) if table_dict['rows'][num_rows-1]['is_projected_row_header']: table_dict['reject'].append("bad projected row header") except KeyboardInterrupt: do_break = True break 
except: print(traceback.format_exc()) table_dict['reject'].append('unknown exception') print('not ok') #continue processed_count += 1 if len(table_dict['reject']) > 0: reject_count += 1 for reject_reason in set(table_dict['reject']): reject_reasons[reject_reason].append(table_dict['structure_id']) table_dict['exclude_for_detection'] = True table_dict['exclude_for_structure'] = True else: good_count += 1 if len(table_dict['fix']) > 0: for fix in set(table_dict['fix']): fixes[fix].append(table_dict['structure_id']) else: kept_as_is_count += 1 document_tables.append(table_dict) del table_dict['reject'] del table_dict['fix'] if do_break: break if do_break: break # If not all tables present and included for detection, then exclude all for detection if not sum([1 for elem in document_tables if not elem['exclude_for_detection']]) == len(tables): for table_dict in document_tables: table_dict['exclude_for_detection'] = True if len(document_tables) == 0: continue save_filename = pdf_filepath.split("/")[-1].replace(".pdf", "") + "_tables.json" save_filepath = os.path.join(output_json_directory, save_filename) with open(save_filepath, 'w') as out_file: json.dump(document_tables, out_file, ensure_ascii=False, indent=4) # Create detection PASCAL VOC data if not document_tables[0]['exclude_for_detection']: detection_boxes_by_page = defaultdict(list) # Each table has associated bounding boxes for table_dict in document_tables: try: table_boxes = [] page_num = table_dict['pdf_page_index'] # Create detection data class_label = 'table' dict_entry = {'class_label': class_label, 'bbox': table_dict['pdf_table_bbox']} detection_boxes_by_page[page_num].append(dict_entry) except Exception as err: print(traceback.format_exc()) # Create detection PASCAL VOC XML file and page image for page_num, boxes in detection_boxes_by_page.items(): try: page_bbox = table_dict['pdf_full_page_bbox'] if not all([is_good_bbox(entry['bbox'], page_bbox) for entry in boxes]): raise Exception("At least one 
bounding box has non-positive area or is outside of image") # Create page image document_id = table_dict['document_id'] image_filename = document_id + "_" + str(page_num) + ".jpg" image_filepath = os.path.join(output_detection_directory, "images", image_filename) page_img = create_document_page_image(doc, page_num, output_image_max_dim=output_image_max_dim) # Initialize PASCAL VOC XML page_annotation = create_pascal_voc_page_element(image_filename, page_img.width, page_img.height, database="SciTSR.c-Detection") for entry in boxes: bbox = entry['bbox'] # Add to PASCAl VOC element = create_pascal_voc_object_element(entry['class_label'], entry['bbox'], page_bbox, output_image_max_dim=output_image_max_dim) page_annotation.append(element) xml_filename = document_id + "_" + str(page_num) + ".xml" xml_filepath = os.path.join(output_detection_directory, split, xml_filename) # Page words # output_page_words_directory page_rect = list(doc[page_num].rect) scale = output_image_max_dim / max(page_rect) tokens = [] for word_num, word in enumerate(doc[page_num].get_text_words()): token = {} token['flags'] = 0 token['span_num'] = word_num token['line_num'] = 0 token['block_num'] = 0 bbox = [round(scale * v, 5) for v in word[:4]] if Rect(bbox).get_area() > 0 and overlap(bbox, page_rect) > 0.75: bbox = [max(0, bbox[0]), max(0, bbox[1]), min(page_rect[2], bbox[2]), min(page_rect[3], bbox[3])] if Rect(bbox).get_area() > 0: token['bbox'] = bbox token['text'] = word[4] tokens.append(token) words_save_filepath = os.path.join(output_page_words_directory, document_id + "_" + str(page_num) + "_words.json") # Save page_img.save(image_filepath) save_xml_pascal_voc(page_annotation, xml_filepath) with open(words_save_filepath, 'w', encoding='utf8') as f: json.dump(tokens, f) except: print("Exception; skipping page") pass # Create structure PASCAL VOC data # output_structure_directory for table_dict in document_tables: if table_dict['exclude_for_structure']: continue page_num = 
table_dict['pdf_page_index'] page_rect = list(doc[page_num].rect) scale = output_image_max_dim / max(page_rect) page_img = create_document_page_image(doc, page_num, output_image_max_dim=output_image_max_dim) table_num = table_dict['document_table_index'] table_boxes = [] # Create structure recognition data class_label = 'table' dict_entry = {'class_label': class_label, 'bbox': table_dict['pdf_table_bbox']} table_boxes.append(dict_entry) rows = table_dict['rows'].values() rows = sorted(rows, key=lambda k: k['pdf_row_bbox'][1]) if len(rows) > 1: for row1, row2 in zip(rows[:-1], rows[1:]): mid_point = (row1['pdf_row_bbox'][3] + row2['pdf_row_bbox'][1]) / 2 row1['pdf_row_bbox'][3] = mid_point row2['pdf_row_bbox'][1] = mid_point columns = table_dict['columns'].values() columns = sorted(columns, key=lambda k: k['pdf_column_bbox'][0]) for col1, col2 in zip(columns[:-1], columns[1:]): mid_point = (col1['pdf_column_bbox'][2] + col2['pdf_column_bbox'][0]) / 2 col1['pdf_column_bbox'][2] = mid_point col2['pdf_column_bbox'][0] = mid_point for cell in table_dict['cells']: column_nums = cell['column_nums'] row_nums = cell['row_nums'] column_rect = Rect() row_rect = Rect() for column_num in column_nums: column_rect.include_rect(columns[column_num]['pdf_column_bbox']) for row_num in row_nums: row_rect.include_rect(rows[row_num]['pdf_row_bbox']) cell_rect = column_rect.intersect(row_rect) cell['pdf_bbox'] = list(cell_rect) header_rect = Rect() for cell in table_dict['cells']: cell_bbox = cell['pdf_bbox'] is_blank = len(cell['text_content'].strip()) == 0 is_spanning_cell = len(cell['row_nums']) > 1 or len(cell['column_nums']) > 1 is_column_header = cell['is_column_header'] is_projected_row_header = cell['is_projected_row_header'] if is_projected_row_header: dict_entry = {'class_label': 'table projected row header', 'bbox': cell['pdf_bbox']} table_boxes.append(dict_entry) elif is_spanning_cell and not is_blank: dict_entry = {'class_label': 'table spanning cell', 'bbox': 
cell['pdf_bbox']} table_boxes.append(dict_entry) if is_column_header: header_rect.include_rect(cell_bbox) if header_rect.get_area() > 0: dict_entry = {'class_label': 'table column header', 'bbox': list(header_rect)} table_boxes.append(dict_entry) for row in rows: row_bbox = row['pdf_row_bbox'] dict_entry = {'class_label': 'table row', 'bbox': row_bbox} table_boxes.append(dict_entry) # table_entry['columns'] for column in columns: dict_entry = {'class_label': 'table column', 'bbox': column['pdf_column_bbox']} table_boxes.append(dict_entry) # Crop table_bbox = table_dict['pdf_table_bbox'] # Convert to image coordinates crop_bbox = [int(round(scale * elem)) for elem in table_bbox] split = table_dict['split'] if split == 'val' or split == 'test': padding = args.test_padding else: padding = args.train_padding # Pad crop_bbox = [crop_bbox[0]-padding, crop_bbox[1]-padding, crop_bbox[2]+padding, crop_bbox[3]+padding] # Keep within image crop_bbox = [max(0, crop_bbox[0]), max(0, crop_bbox[1]), min(page_img.size[0], crop_bbox[2]), min(page_img.size[1], crop_bbox[3])] table_img = page_img.crop(crop_bbox) for entry in table_boxes: bbox = entry['bbox'] bbox = [scale*elem for elem in bbox] bbox = [max(0, bbox[0]-crop_bbox[0]-1), max(0, bbox[1]-crop_bbox[1]-1), min(table_img.size[0], bbox[2]-crop_bbox[0]-1), min(table_img.size[1], bbox[3]-crop_bbox[1]-1)] entry['bbox'] = bbox # Initialize PASCAL VOC XML table_image_filename = document_id + "_table_" + str(table_num) + ".jpg" table_image_filepath = os.path.join(output_structure_directory, "images", table_image_filename) table_annotation = create_pascal_voc_page_element(table_image_filename, table_img.width, table_img.height, database="SciTSR.c-Structure") table_img_bbox = [0, 0, table_img.width, table_img.height] try: if not all([is_good_bbox(entry['bbox'], table_img_bbox) for entry in table_boxes]): raise Exception("At least one bounding box has non-positive area or is outside of image") for entry in table_boxes: bbox = 
entry['bbox'] # Add to PASCAl VOC element = create_pascal_voc_object_element(entry['class_label'], entry['bbox'], [0, 0, table_img.size[0], table_img.size[1]], output_image_max_dim=max(table_img.size)) table_annotation.append(element) xml_filename = table_image_filename.replace(".jpg", ".xml") xml_filepath = os.path.join(output_structure_directory, split, xml_filename) # Table words # output_table_words_directory tokens = [] for word_num, word in enumerate(doc[page_num].get_text_words()): token = {} token['flags'] = 0 token['span_num'] = word_num token['line_num'] = 0 token['block_num'] = 0 bbox = [round(scale * v, 5) for v in word[:4]] if overlap(bbox, crop_bbox) > 0.75: bbox = [max(0, bbox[0]-crop_bbox[0]-1), max(0, bbox[1]-crop_bbox[1]-1), min(table_img.size[0], bbox[2]-crop_bbox[0]-1), min(table_img.size[1], bbox[3]-crop_bbox[1]-1)] if Rect(bbox).get_area() > 0: token['bbox'] = bbox token['text'] = word[4] tokens.append(token) else: print("REMOVED BAD TABLE WORD") words_save_filepath = os.path.join(output_table_words_directory, table_image_filename.replace(".jpg", "_words.json")) # Save everything table_img.save(table_image_filepath) save_xml_pascal_voc(table_annotation, xml_filepath) print(xml_filepath) with open(words_save_filepath, 'w', encoding='utf8') as f: json.dump(tokens, f) except: print("Exception; skipping table") pass del doc # Just removes from memory, not from disk if __name__ == "__main__": main()