| import cv2 |
| import copy |
| import Polygon |
| import numpy as np |
|
|
|
|
def cal_mean_lr(optimizer):
    """Return the arithmetic mean of the ``'lr'`` entries of an optimizer.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of mappings
            that each contain an ``'lr'`` key.

    Returns:
        The sum of all ``'lr'`` values divided by the number of groups.

    Raises:
        ZeroDivisionError: If ``param_groups`` is empty.
    """
    total = 0
    num_groups = 0
    for group in optimizer.param_groups:
        total = total + group['lr']
        num_groups += 1
    return total / num_groups
|
|
|
|
def cal_pr_f1(pr_info):
    """Compute precision, recall and F1 from three accumulated counts.

    Args:
        pr_info: Indexable with at least three entries. ``pr_info[0]`` is the
            shared numerator, ``pr_info[1]`` the precision denominator and
            ``pr_info[2]`` the recall denominator (presumably correct /
            predicted / ground-truth counts -- confirm against the caller).

    Returns:
        Tuple ``(precision, recall, f1)``. A ratio whose division raises
        ``ZeroDivisionError`` (e.g. nothing predicted, nothing matched) is
        reported as ``0.0`` instead of propagating the exception.
    """
    def safe_ratio(numerator, denominator):
        # Only the case that used to crash is intercepted; every input that
        # previously produced a value still produces exactly that value.
        try:
            return numerator / denominator
        except ZeroDivisionError:
            return 0.0

    precision = safe_ratio(pr_info[0], pr_info[1])
    recall = safe_ratio(pr_info[0], pr_info[2])
    f1 = safe_ratio(2*precision*recall, precision+recall)
    return precision, recall, f1
|
|
|
|
def match_segment_spans(segments, spans):
    """Pair segment positions with the spans that contain them.

    A span ``(start, end, ...)`` contains a segment value ``s`` when
    ``start <= s < end``. Each span is matched at most once; a single segment
    may be paired with several not-yet-matched spans, in which case its index
    appears once per pairing.

    Args:
        segments: Sequence of comparable positions.
        spans: Sequence of indexables whose items 0 and 1 are the half-open
            bounds.

    Returns:
        Tuple ``(matched_segments, matched_spans)`` of equally long index
        lists; entry ``k`` of each list describes one pairing.
    """
    segment_hits = []
    span_hits = []

    for seg_pos, seg_value in enumerate(segments):
        for span_pos, bounds in enumerate(spans):
            if span_pos in span_hits:
                continue
            if bounds[0] <= seg_value < bounds[1]:
                segment_hits.append(seg_pos)
                span_hits.append(span_pos)

    return segment_hits, span_hits
|
|
|
|
def find_unmatch_segment_spans(segments, spans):
    """Return the indices of segments that fall inside no span.

    A segment value ``s`` is inside a span when ``span[0] <= s < span[1]``.

    Args:
        segments: Sequence of comparable positions.
        spans: Sequence of indexables whose items 0 and 1 are the half-open
            bounds.

    Returns:
        List of indices into ``segments``, in ascending order.
    """
    return [
        seg_pos
        for seg_pos, seg_value in enumerate(segments)
        if not any(bounds[0] <= seg_value < bounds[1] for bounds in spans)
    ]
|
|
|
|
def parse_layout(spans, num_rows, num_cols):
    """Build a cell-id grid from cell spans.

    Each span ``(x1, y1, x2, y2)`` (inclusive column/row indices) is painted
    onto a ``num_rows x num_cols`` grid; later spans overwrite earlier ones.
    The ids are then renumbered so that they increase in row-major order of
    first appearance, starting at 0.

    Grid positions covered by no span keep the fill value ``-1`` before
    renumbering, so all of them end up sharing one id.

    Args:
        spans: Iterable of ``(x1, y1, x2, y2)`` tuples.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.

    Returns:
        Integer ``numpy.ndarray`` of shape ``(num_rows, num_cols)``.
    """
    # ``np.int`` was removed from NumPy; the builtin ``int`` is the same dtype.
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    for cell_count, (x1, y1, x2, y2) in enumerate(spans):
        layout[y1:y2+1, x1:x2+1] = cell_count

    # Map each original id to its rank of first appearance. Every position is
    # read before it is overwritten, so the in-place update is safe.
    new_ids = {}
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            old_id = int(layout[row_idx, col_idx])
            layout[row_idx, col_idx] = new_ids.setdefault(old_id, len(new_ids))
    return layout
|
|
|
|
def parse_cells(layout, spans, row_segments, col_segments):
    """Turn a cell-id grid into a list of cell dicts with pixel polygons.

    For every id ``0 .. max(layout)`` the bounding grid rectangle of the id is
    looked up and converted to coordinates through the segment lists.

    Args:
        layout: 2-D integer array of cell ids.
        spans: Iterable of indexables; ``span[0]`` / ``span[1]`` are the
            column / row of a grid position whose cell gets a ``'transcript'``
            entry.
        row_segments: Row boundary coordinates, indexed by grid row
            (``row + 1`` is the lower boundary).
        col_segments: Column boundary coordinates, indexed the same way.

    Returns:
        List of dicts, one per cell id, each with a ``'segmentation'`` key
        holding one four-point contour ``[[x1, y1], [x2, y1], [x2, y2],
        [x1, y2]]``; cells addressed by ``spans`` also carry
        ``'transcript': 'None'``.
    """
    cells = []
    for cell_id in range(np.max(layout) + 1):
        positions = np.argwhere(layout == cell_id)
        rows = positions[:, 0]
        cols = positions[:, 1]
        top, bottom = np.min(rows), np.max(rows)
        left, right = np.min(cols), np.max(cols)
        # NOTE(review): the slice stops before the last row/column of the
        # bounding box, so this check does not cover the full rectangle.
        assert np.all(layout[top:bottom, left:right] == cell_id)
        x1, x2 = col_segments[left], col_segments[right+1]
        y1, y2 = row_segments[top], row_segments[bottom+1]
        cells.append({'segmentation': [[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]})

    for span in spans:
        cells[layout[span[1], span[0]]]['transcript'] = 'None'
    return cells
|
|
|
|
def segmentation_to_bbox(segmentation):
    """Return the axis-aligned bounding box of a multi-contour segmentation.

    Args:
        segmentation: Sequence of contours, each a non-empty sequence of
            points indexable as ``pt[0]`` (x) and ``pt[1]`` (y).

    Returns:
        ``[x1, y1, x2, y2]`` with the minimum and maximum coordinates over
        all contours.
    """
    def extreme(pick, axis):
        # ``pick`` is ``min`` or ``max``; applied per contour, then across contours.
        return pick(pick(pt[axis] for pt in contour) for contour in segmentation)

    x1 = extreme(min, 0)
    y1 = extreme(min, 1)
    x2 = extreme(max, 0)
    y2 = extreme(max, 1)
    return [x1, y1, x2, y2]
|
|
|
|
def extend_cell_lines(cells, lines):
    """Assign every text line to the cell it overlaps most.

    Each line with non-zero area goes to the cell whose polygon covers the
    largest fraction of the line's area (the first such cell on ties); lines
    overlapping no cell are dropped. The resulting line indices are stored,
    sorted by the top edge (``y1``) of each line's bounding box, under
    ``cell['lines_idx']``.

    Args:
        cells: List of dicts with a ``'segmentation'`` key; modified in place.
        lines: List of dicts with a ``'segmentation'`` key; not modified.

    Returns:
        None.
    """
    def to_polygon(segmentation):
        # Combine all contours of one segmentation into a single polygon.
        merged = Polygon.Polygon()
        for contour in segmentation:
            merged = merged + Polygon.Polygon(contour)
        return merged

    lines = copy.deepcopy(lines)

    cell_polys = [to_polygon(cell['segmentation']) for cell in cells]
    line_polys = [to_polygon(line['segmentation']) for line in lines]

    assigned = [[] for _ in range(len(cells))]
    for line_idx, line_poly in enumerate(line_polys):
        line_area = line_poly.area()
        if line_area == 0:
            continue
        best_idx, best_ratio = None, 0
        for cell_idx, cell_poly in enumerate(cell_polys):
            ratio = (cell_poly & line_poly).area()/line_area
            if ratio > best_ratio:
                best_idx, best_ratio = cell_idx, ratio
        if best_ratio > 0:
            assigned[best_idx].append(line_idx)

    top_edges = [segmentation_to_bbox(line['segmentation'])[1] for line in lines]
    ordered = [sorted(ids, key=lambda idx: top_edges[idx]) for ids in assigned]

    for cell, line_ids in zip(cells, ordered):
        cell['lines_idx'] = line_ids
|
|
|
|
def rerange_layout(table):
    """Renumber cell ids in row-major order of first appearance.

    The ``'layout'`` array is rewritten in place so that ids start at 0 and
    grow in reading order; ``'cells'`` is replaced by a new list reordered to
    match and restricted to the ids that actually occur in the layout.

    Args:
        table: Dict with ``'layout'`` (2-D integer array) and ``'cells'``
            (list indexable by the ids found in the layout).

    Returns:
        None; ``table`` is updated in place.
    """
    layout = table['layout']
    old_cells = table['cells']

    # old id -> new id; dict insertion order doubles as first-seen order.
    new_id_of = {}
    for row_idx, col_idx in np.ndindex(layout.shape[0], layout.shape[1]):
        old_id = layout[row_idx, col_idx]
        layout[row_idx, col_idx] = new_id_of.setdefault(old_id, len(new_id_of))

    table['layout'] = layout
    table['cells'] = [old_cells[old_id] for old_id in new_id_of]
|
|
def cal_cell_spans(table):
    """Compute the grid bounding box of every cell in a table.

    Args:
        table: Dict with ``'layout'`` (2-D integer array of cell ids) and
            ``'cells'`` (its length determines how many ids are inspected).

    Returns:
        List with one ``[x1, y1, x2, y2]`` entry per cell id, holding the
        inclusive min/max column and row indices where the id occurs.
    """
    layout = table['layout']
    spans = []
    for cell_id in range(len(table['cells'])):
        positions = np.argwhere(layout == cell_id)
        rows = positions[:, 0]
        cols = positions[:, 1]
        top, bottom = np.min(rows), np.max(rows)
        left, right = np.min(cols), np.max(cols)
        # NOTE(review): the slice stops before the last row/column of the
        # bounding box, so this check does not cover the full rectangle.
        assert np.all(layout[top:bottom, left:right] == cell_id)
        spans.append([left, top, right, bottom])
    return spans
|
|
|
|
def remove_repeat_rcs(table):
    """Drop redundant rows and columns from a table layout.

    Repeats until nothing changes:
      * a row is dropped when it is listed in ``body_rows`` and holds a
        single cell id, or when an identical row was already kept;
      * a column is dropped when it holds a single cell id, or when an
        identical column was already kept.
    ``head_rows`` / ``body_rows`` are re-indexed after every pass, and the
    final layout is renumbered through ``rerange_layout``.

    Args:
        table: Dict with ``'layout'`` (2-D array), ``'head_rows'``,
            ``'body_rows'`` (row index lists) and ``'cells'``.

    Returns:
        None; ``table`` is updated in place.
    """
    grid = table['layout']
    head_rows = table['head_rows']
    body_rows = table['body_rows']

    while True:
        num_rows, num_cols = grid.shape[0], grid.shape[1]

        kept_rows = []
        seen_row_keys = []
        for row_idx in range(num_rows):
            row = grid[row_idx, :]
            if len(np.unique(row)) == 1 and row_idx in body_rows:
                continue
            row_key = ','.join(str(value) for value in row)
            if row_key in seen_row_keys:
                continue
            kept_rows.append(row_idx)
            seen_row_keys.append(row_key)

        kept_cols = []
        seen_col_keys = []
        for col_idx in range(num_cols):
            col = grid[:, col_idx]
            if len(np.unique(col)) == 1:
                continue
            col_key = ','.join(str(value) for value in col)
            if col_key in seen_col_keys:
                continue
            kept_cols.append(col_idx)
            seen_col_keys.append(col_key)

        # Fixed point reached: nothing was removed in this pass.
        if len(kept_rows) == num_rows and len(kept_cols) == num_cols:
            break

        grid = grid[kept_rows][:, kept_cols]
        # Translate the old row indices into positions within the kept rows.
        head_rows = [new for new, old in enumerate(kept_rows) if old in head_rows]
        body_rows = [new for new, old in enumerate(kept_rows) if old in body_rows]

    table['layout'] = grid
    table['head_rows'] = head_rows
    table['body_rows'] = body_rows
    rerange_layout(table)
|
|
|
|
def pred_result_to_table(pred_result):
    """Convert a raw prediction tuple into a table dict.

    Args:
        pred_result: ``(row_segments, col_segments, divide, spans)``. The
            segment lists hold boundary coordinates (one more entry than
            there are rows / columns); rows ``0 .. divide-1`` are treated as
            head rows and the rest as body rows.

    Returns:
        Dict with keys ``'layout'``, ``'head_rows'``, ``'body_rows'`` and
        ``'cells'``.
    """
    row_segments, col_segments, divide, spans = pred_result
    num_rows = len(row_segments) - 1
    num_cols = len(col_segments) - 1

    layout = parse_layout(spans, num_rows, num_cols)
    cells = parse_cells(layout, spans, row_segments, col_segments)

    return {
        'layout': layout,
        'head_rows': list(range(divide)),
        'body_rows': list(range(divide, num_rows)),
        'cells': cells,
    }
|
|
|
|
def is_simple_table(table):
    """Tell whether every grid position of the table is its own cell.

    Args:
        table: Dict with ``'layout'`` (2-D array) and ``'cells'`` (sized).

    Returns:
        ``True`` when rows x columns equals the number of cells, else
        ``False``.
    """
    num_rows, num_cols = table['layout'].shape
    grid_size = num_rows * num_cols
    return grid_size == len(table['cells'])
|
|
|
|
def tensor_to_image(tensor):
    """Convert a tensor into a ``uint8`` image array scaled to 0..255.

    A 3-D input whose first dimension is neither 1 nor 3 is first reduced to
    one channel by taking the L2 norm over that dimension. Values are then
    min-max normalised to ``[0, 255]`` and cast to ``uint8``. A 3-D result is
    transposed from channel-first to channel-last, and a single trailing
    channel is squeezed away.

    Args:
        tensor: Object supporting ``.detach().cpu().numpy()``.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``. A constant input (max == min)
        yields an all-zero image.
    """
    image = tensor.detach().cpu().numpy()
    if image.ndim == 3 and image.shape[0] not in (1, 3):
        image = np.sqrt(np.sum(np.power(image, 2), axis=0, keepdims=True))

    low = np.min(image)
    value_range = np.max(image) - low
    if value_range == 0:
        # 0/0 would produce NaN, and casting NaN to uint8 is undefined.
        image = np.zeros(image.shape, dtype=np.uint8)
    else:
        image = 255 * (image - low) / value_range
        image = image.astype(np.uint8)

    if image.ndim == 3:
        image = np.transpose(image, (1, 2, 0)).copy()
        if image.shape[2] == 1:
            image = image[:, :, 0]
    return image
|
|
|
|
def visualize_layout(image, table):
    """Draw the outline of every cell polygon onto an image.

    Args:
        image: Image passed to ``cv2.polylines``.
        table: Dict with a ``'cells'`` list; cells without a
            ``'segmentation'`` key are skipped.

    Returns:
        The image returned by the last ``cv2.polylines`` call, or the input
        image unchanged when nothing was drawn.
    """
    outline_color = (255, 0, 0)
    for cell in table['cells']:
        if 'segmentation' not in cell:
            continue
        for contour in cell['segmentation']:
            points = np.array(contour, dtype=np.int32)
            # ``True`` closes each contour.
            image = cv2.polylines(image, [points], True, outline_color)
    return image
|
|
# Inline formatting tags that carry no visible text of their own.
virtual_chars = [
    "<b>", "</b>",
    "<i>", "</i>",
    "<sup>", "</sup>",
    "<sub>", "</sub>",
    "<overline>", "</overline>",
    "<underline>", "</underline>",
    "<strike>", "</strike>",
]


def is_blank(content):
    """Return True when ``content`` is empty once formatting tags are removed.

    Args:
        content: Cell text, possibly containing tags from ``virtual_chars``.

    Returns:
        ``True`` if only whitespace remains after deleting every tag listed
        in ``virtual_chars``.
    """
    remainder = content
    for tag in virtual_chars:
        remainder = remainder.replace(tag, '')
    return remainder.strip() == ''


def filt_content(content, filt_blank=False, filt_virtual=False, filt_pad=False):
    """Clean the text of one table cell.

    The filters are applied in this order:

    Args:
        content: Cell text.
        filt_blank: Replace the text by ``''`` when ``is_blank`` says it
            holds nothing but formatting tags and whitespace.
        filt_virtual: Delete every tag listed in ``virtual_chars``.
        filt_pad: Strip leading and trailing whitespace.

    Returns:
        The filtered text.
    """
    if filt_blank and is_blank(content):
        content = ''

    if filt_virtual:
        # Iterate over the tag list, not over the characters of ``content``:
        # the latter would delete every character and always return ''.
        for tag in virtual_chars:
            content = content.replace(tag, '')

    if filt_pad:
        content = content.strip()

    return content
|
|
|
|
def filt_transcript(html, filt_blank=False, filt_virtual=False, filt_pad=False):
    """Apply ``filt_content`` to the text of every ``<td ...>...</td>`` cell.

    The string is scanned once from left to right; everything outside the
    cell contents (tags, attributes, surrounding markup) is copied verbatim.

    Args:
        html: Table markup.
        filt_blank: Forwarded to ``filt_content``.
        filt_virtual: Forwarded to ``filt_content``.
        filt_pad: Forwarded to ``filt_content``.

    Returns:
        ``html`` with each cell content replaced by its filtered version.

    Raises:
        ValueError: If a ``<td`` has no following ``>`` or ``</td>``.
    """
    close_tag = '</td>'
    pieces = []
    copied_upto = 0   # html[:copied_upto] is already represented in pieces
    search_from = 0   # where to look for the next '<td'

    while True:
        td_start = html.find('<td', search_from)
        if td_start == -1:
            break
        content_start = html.index('>', td_start) + 1
        content_end = html.index(close_tag, content_start)

        pieces.append(html[copied_upto:content_start])
        pieces.append(
            filt_content(html[content_start:content_end], filt_blank, filt_virtual, filt_pad)
        )
        copied_upto = content_end
        # Resume after the closing tag; filtered content is never rescanned.
        search_from = content_end + len(close_tag)

    pieces.append(html[copied_upto:])
    return ''.join(pieces)
|
|