#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import os
import re
from collections import Counter

import numpy as np
from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory
from rag.nlp import rag_tokenizer
from .recognizer import Recognizer


class TableStructureRecognizer(Recognizer):
    """Detect table structure elements and assemble text boxes into tables.

    Calling an instance runs the underlying ``Recognizer`` and normalises the
    detected row/header/column boxes.  ``construct_table`` is a static helper
    that turns text boxes (already annotated with row/column hints) into either
    an HTML table or a list of descriptive text lines.
    """

    labels = [
        "table",
        "table column",
        "table row",
        "table column header",
        "table projected row header",
        "table spanning cell",
    ]

    def __init__(self):
        """Load the "tsr" model from the local resource directory.

        If loading from ``rag/res/deepdoc`` fails for any reason, the model
        files are downloaded from the ``InfiniFlow/deepdoc`` repository into
        that same directory and loading is retried.
        """
        model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
        try:
            super().__init__(self.labels, "tsr", model_dir)
        except Exception as e:
            # Deliberate best-effort fallback: fetch the model files, then retry.
            logging.warning(
                "Loading the table structure model from %s failed (%s); "
                "downloading it instead.", model_dir, e)
            super().__init__(
                self.labels, "tsr",
                snapshot_download(repo_id="InfiniFlow/deepdoc",
                                  local_dir=model_dir,
                                  local_dir_use_symlinks=False))

    def __call__(self, images, thr=0.2):
        """Run detection on ``images`` and align the detected components.

        Row-like boxes (labels containing "row" or "header" after the first
        character) are stretched to a common left/right edge, and
        "table column" boxes are stretched to a common top/bottom edge.

        Args:
            images: passed through unchanged to ``Recognizer.__call__``.
            thr: score threshold passed through to ``Recognizer.__call__``.

        Returns:
            A list with one entry per kept table; each entry is a list of
            dicts with keys ``label``, ``score``, ``x0``, ``x1``, ``top`` and
            ``bottom``.
        """
        tbls = super().__call__(images, thr)
        res = []
        # align left&right for rows, align top&bottom for columns
        for tbl in tbls:
            lts = [{"label": b["type"],
                    "score": b["score"],
                    "x0": b["bbox"][0], "x1": b["bbox"][2],
                    "top": b["bbox"][1], "bottom": b["bbox"][-1]
                    } for b in tbl]
            # NOTE(review): tables with no detections, or with no row/header
            # detections, are skipped entirely, so ``res`` can be shorter than
            # ``tbls`` -- confirm callers do not index results by image.
            if not lts:
                continue

            row_like = [b for b in lts
                        if b["label"].find("row") > 0
                        or b["label"].find("header") > 0]
            if not row_like:
                continue
            left = [b["x0"] for b in row_like]
            right = [b["x1"] for b in row_like]
            # With few samples use the extreme value, otherwise the mean.
            left = np.mean(left) if len(left) > 4 else np.min(left)
            right = np.mean(right) if len(right) > 4 else np.max(right)
            for b in row_like:
                if b["x0"] > left:
                    b["x0"] = left
                if b["x1"] < right:
                    b["x1"] = right

            columns = [b for b in lts if b["label"] == "table column"]
            if not columns:
                res.append(lts)
                continue
            top = [b["top"] for b in columns]
            bottom = [b["bottom"] for b in columns]
            # With few samples use the extreme value, otherwise the median.
            top = np.median(top) if len(top) > 4 else np.min(top)
            bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
            for b in columns:
                if b["top"] > top:
                    b["top"] = top
                if b["bottom"] < bottom:
                    b["bottom"] = bottom

            res.append(lts)
        return res

    @staticmethod
    def is_caption(bx):
        """Return True if box ``bx`` looks like a figure/table caption.

        A box is a caption when its stripped text matches one of the caption
        patterns or when its ``layout_type`` contains "caption".
        """
        patt = [
            r"[图表]+[ 0-9::]{2,}"
        ]
        text = bx["text"].strip()
        if any(re.match(p, text) for p in patt):
            return True
        return "caption" in bx["layout_type"]

    @staticmethod
    def blockType(b):
        """Classify the text of box ``b`` into a short type code.

        The first matching regular expression decides the code ("Dt", "Nu",
        "Ca", "En", "NE" or "Sg").  Otherwise the text is tokenised: more than
        three tokens longer than one character gives "Tx" (fewer than 12) or
        "Lx"; a single token tagged "nr" by the tokenizer gives "Nr"; anything
        else is "Ot".
        """
        patt = [
            ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
            (r"^(20|19)[0-9]{2}年$", "Dt"),
            (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
            ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
            (r"^第*[一二三四1-4]季度$", "Dt"),
            (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
            (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
            ("^[0-9.,+%/ -]+$", "Nu"),
            (r"^[0-9A-Z/\._~-]+$", "Ca"),
            (r"^[A-Z]*[a-z' -]+$", "En"),
            (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
            (r"^.{1}$", "Sg")
        ]
        text = b["text"].strip()
        for p, n in patt:
            if re.search(p, text):
                return n

        tks = [t for t in rag_tokenizer.tokenize(b["text"]).split(" ")
               if len(t) > 1]
        if len(tks) > 3:
            return "Tx" if len(tks) < 12 else "Lx"

        if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
            return "Nr"

        return "Ot"

    @staticmethod
    def construct_table(boxes, is_english=False, html=False):
        """Arrange text boxes into a row/column grid and render it.

        Caption boxes are removed from ``boxes`` (the list is mutated) and
        their text is concatenated into the caption.  The remaining boxes get
        ``btype``, ``rn`` (row number) and ``cn`` (column number) keys, lone
        cells are merged into neighbouring columns/rows, header rows are
        guessed, and spanning cells are resolved.

        Args:
            boxes: list of dicts; the code reads ``text``, ``layout_type``,
                ``top``, ``bottom``, ``x0``, ``x1``, ``page_number`` and the
                optional hint keys ``R``, ``R_top``, ``R_bott``, ``C``,
                ``C_left``, ``C_right``, ``H``, ``SP`` and ``H_*``.
            is_english: selects the separators/connectives of the output text.
            html: when True return an HTML string, otherwise text lines.

        Returns:
            ``[]`` when no non-caption box remains; otherwise an HTML string
            (``html=True``) or a list of row description strings.
        """
        cap = ""
        i = 0
        while i < len(boxes):
            if TableStructureRecognizer.is_caption(boxes[i]):
                # Separate caption fragments with a space in English text.
                # (The original statement ``cap + " "`` had no effect.)
                if is_english and cap:
                    cap += " "
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
            i += 1

        if not boxes:
            return []
        for b in boxes:
            b["btype"] = TableStructureRecognizer.blockType(b)
        # Most frequent block type; the first one seen wins ties.
        max_type = Counter([b["btype"] for b in boxes]).items()
        max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
        logging.debug("MAXTYPE: " + max_type)

        # ---- assign row numbers -------------------------------------------
        rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
        rowh = np.min(rowh) if rowh else 0
        boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
        boxes[0]["rn"] = 0
        rows = [[boxes[0]]]
        btm = boxes[0]["bottom"]
        for b in boxes[1:]:
            b["rn"] = len(rows) - 1
            lst_r = rows[-1]
            # The differing defaults ("-1" vs "-2") make two boxes that both
            # lack an "R" hint compare as different in the second test.
            if lst_r[-1].get("R", "") != b.get("R", "") \
                    or (b["top"] >= btm - 3
                        and lst_r[-1].get("R", "-1") != b.get("R", "-2")):
                # new row
                btm = b["bottom"]
                b["rn"] += 1
                rows.append([b])
                continue
            btm = (btm + b["bottom"]) / 2.
            rows[-1].append(b)

        # ---- assign column numbers ----------------------------------------
        colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
        colwm = np.min(colwm) if colwm else 0
        crosspage = len(set([b["page_number"] for b in boxes])) > 1
        if crosspage:
            boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False)
        else:
            boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
        boxes[0]["cn"] = 0
        cols = [[boxes[0]]]
        right = boxes[0]["x1"]
        for b in boxes[1:]:
            b["cn"] = len(cols) - 1
            lst_c = cols[-1]
            if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1
                    and b["page_number"] == lst_c[-1]["page_number"]) \
                    or (b["x0"] >= right
                        and lst_c[-1].get("C", "-1") != b.get("C", "-2")):
                # new col
                right = b["x1"]
                b["cn"] += 1
                cols.append([b])
                continue
            right = (right + b["x1"]) / 2.
            cols[-1].append(b)

        tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
        for b in boxes:
            tbl[b["rn"]][b["cn"]].append(b)

        if len(rows) >= 4:
            # remove single in column
            j = 0
            while j < len(tbl[0]):
                # e counts occupied cells in column j (stops at 2);
                # ii is the row of the last occupied cell seen.
                e, ii = 0, 0
                for i in range(len(tbl)):
                    if tbl[i][j]:
                        e += 1
                        ii = i
                    if e > 1:
                        break
                if e > 1:
                    j += 1
                    continue
                # f / ff: the left / right neighbour in the same row holds
                # text, or there is no column on that side.
                f = (j > 0 and tbl[ii][j - 1]
                     and tbl[ii][j - 1][0].get("text")) or j == 0
                ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1]
                      and tbl[ii][j + 1][0].get("text")) \
                    or j + 1 >= len(tbl[ii])
                if f and ff:
                    j += 1
                    continue
                bx = tbl[ii][j][0]
                logging.debug("Relocate column single: " + bx["text"])
                # j column only has one value
                # Smallest horizontal gap to the boxes of each free neighbour.
                left, right = 100000, 100000
                if j > 0 and not f:
                    for i in range(len(tbl)):
                        if tbl[i][j - 1]:
                            left = min(left, np.min(
                                [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
                if j + 1 < len(tbl[0]) and not ff:
                    for i in range(len(tbl)):
                        if tbl[i][j + 1]:
                            right = min(right, np.min(
                                [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
                assert left < 100000 or right < 100000
                if left < right:
                    # Merge into the left column; column j and everything
                    # after it shift one position to the left.
                    for jj in range(j, len(tbl[0])):
                        for i in range(len(tbl)):
                            for a in tbl[i][jj]:
                                a["cn"] -= 1
                    if tbl[ii][j - 1]:
                        tbl[ii][j - 1].extend(tbl[ii][j])
                    else:
                        tbl[ii][j - 1] = tbl[ii][j]
                else:
                    # Merge into the right column; only the columns after j
                    # are renumbered.
                    for jj in range(j + 1, len(tbl[0])):
                        for i in range(len(tbl)):
                            for a in tbl[i][jj]:
                                a["cn"] -= 1
                    if tbl[ii][j + 1]:
                        tbl[ii][j + 1].extend(tbl[ii][j])
                    else:
                        tbl[ii][j + 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)
                cols.pop(j)
        assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
            len(cols), len(tbl[0]))

        if len(cols) >= 4:
            # remove single in row
            i = 0
            while i < len(tbl):
                # e counts occupied cells in row i (stops at 2);
                # jj is the column of the last occupied cell seen.
                e, jj = 0, 0
                for j in range(len(tbl[i])):
                    if tbl[i][j]:
                        e += 1
                        jj = j
                    if e > 1:
                        break
                if e > 1:
                    i += 1
                    continue
                # f / ff: the cell above / below in the same column holds
                # text, or there is no row on that side.
                f = (i > 0 and tbl[i - 1][jj]
                     and tbl[i - 1][jj][0].get("text")) or i == 0
                ff = (i + 1 < len(tbl) and tbl[i + 1][jj]
                      and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
                if f and ff:
                    i += 1
                    continue
                bx = tbl[i][jj][0]
                logging.debug("Relocate row single: " + bx["text"])
                # i row only has one value
                # Smallest vertical gap to the boxes of each free neighbour.
                up, down = 100000, 100000
                if i > 0 and not f:
                    for j in range(len(tbl[i - 1])):
                        if tbl[i - 1][j]:
                            up = min(up, np.min(
                                [bx["top"] - a["bottom"]
                                 for a in tbl[i - 1][j]]))
                if i + 1 < len(tbl) and not ff:
                    for j in range(len(tbl[i + 1])):
                        if tbl[i + 1][j]:
                            down = min(down, np.min(
                                [a["top"] - bx["bottom"]
                                 for a in tbl[i + 1][j]]))
                assert up < 100000 or down < 100000
                if up < down:
                    # Merge into the row above; row i and everything below
                    # it shift one position up.
                    for ii in range(i, len(tbl)):
                        for j in range(len(tbl[ii])):
                            for a in tbl[ii][j]:
                                a["rn"] -= 1
                    if tbl[i - 1][jj]:
                        tbl[i - 1][jj].extend(tbl[i][jj])
                    else:
                        tbl[i - 1][jj] = tbl[i][jj]
                else:
                    # Merge into the row below; only the rows after i are
                    # renumbered.
                    for ii in range(i + 1, len(tbl)):
                        for j in range(len(tbl[ii])):
                            for a in tbl[ii][j]:
                                a["rn"] -= 1
                    if tbl[i + 1][jj]:
                        tbl[i + 1][jj].extend(tbl[i][jj])
                    else:
                        tbl[i + 1][jj] = tbl[i][jj]
                tbl.pop(i)
                rows.pop(i)

        # which rows are headers
        hdset = set()
        for i in range(len(tbl)):
            cnt, h = 0, 0
            for arr in tbl[i]:
                if not arr:
                    continue
                cnt += 1
                if max_type == "Nu" and arr[0]["btype"] == "Nu":
                    continue
                # NOTE(review): ``a.get("H")`` is falsy when the stored value
                # is 0 -- confirm whether "H" can legitimately be index 0.
                if any([a.get("H") for a in arr]) \
                        or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                    h += 1
            # Guard against a row without any occupied cell (cnt == 0).
            if cnt and h / cnt > 0.5:
                hdset.add(i)

        if html:
            return TableStructureRecognizer.__html_table(
                cap, hdset,
                TableStructureRecognizer.__cal_spans(
                    boxes, rows, cols, tbl, True))

        return TableStructureRecognizer.__desc_table(
            cap, hdset,
            TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False),
            is_english)

    @staticmethod
    def __html_table(cap, hdset, tbl):
        """Render the grid ``tbl`` as an HTML ``<table>`` string.

        Args:
            cap: caption text; emitted as ``<caption>`` when non-empty.
            hdset: set of header row indices.  It is mutated: the cell texts
                of emitted header rows are added so that a later header row
                repeating only already-seen texts is skipped.
            tbl: grid of cells; a cell is a list of boxes, an empty list
                (rendered as an empty cell) or ``None`` (covered by a
                spanning cell, rendered as nothing).

        Returns:
            The HTML markup, one row per line.  Cell text is inserted as-is
            (it is not HTML-escaped).
        """
        # construct HTML
        head = "<table>"
        if cap:
            head += f"<caption>{cap}</caption>"
        lines = [head]
        for i in range(len(tbl)):
            is_header = i in hdset
            tag = "th" if is_header else "td"
            row = "<tr>"
            txts = []
            for arr in tbl[i]:
                if arr is None:
                    continue
                if not arr:
                    row += f"<{tag}></{tag}>"
                    continue
                # Tolerance for grouping boxes on one text line: half the
                # smallest box height, capped at 10.
                h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
                txt = " ".join(
                    [c["text"] for c in Recognizer.sort_Y_firstly(arr, h)])
                txts.append(txt)
                sp = ""
                if arr[0].get("colspan"):
                    sp = "colspan={}".format(arr[0]["colspan"])
                if arr[0].get("rowspan"):
                    sp += " rowspan={}".format(arr[0]["rowspan"])
                row += f"<{tag} {sp} >" + txt + f"</{tag}>"

            if is_header:
                # Skip a header row whose texts were all emitted before.
                if all([t in hdset for t in txts]):
                    continue
                for t in txts:
                    hdset.add(t)

            if row != "<tr>":
                row += "</tr>"
            else:
                row = ""
            lines.append(row)
        lines.append("</table>")
        return "\n".join(lines)

    @staticmethod
    def __desc_table(cap, hdr_rowno, tbl, is_english):
        """Describe every non-header row of ``tbl`` as a line of text.

        Args:
            cap: caption text appended to every line when non-empty.
            hdr_rowno: set of header row indices; header rows without any
                text are removed from it (the set is mutated).
            tbl: grid of cells, each a list of boxes with a ``text`` key.
            is_english: selects the connectives used in the output.

        Returns:
            A list of strings, generally one per data row (short header-less
            rows of narrow tables may be merged into the previous entry).
        """
        # get text of every column in header row to become header text
        clmno = len(tbl[0])
        rowno = len(tbl)
        headers = {}
        lst_hdr = []
        de = "的" if not is_english else " for "
        for r in sorted(hdr_rowno):
            headers[r] = ["" for _ in range(clmno)]
            for i in range(clmno):
                if not tbl[r][i]:
                    continue
                headers[r][i] = " ".join(
                    [a["text"].strip() for a in tbl[r][i]])
            if all([not t for t in headers[r]]):
                del headers[r]
                hdr_rowno.remove(r)
                continue
            # Fill empty header cells from the previous header row.
            for j in range(clmno):
                if headers[r][j]:
                    continue
                if j >= len(lst_hdr):
                    break
                headers[r][j] = lst_hdr[j]
            lst_hdr = headers[r]

        # Combine the texts of consecutive header rows column by column.
        for i in range(rowno):
            if i not in hdr_rowno:
                continue
            for j in range(i + 1, rowno):
                if j not in hdr_rowno:
                    break
                for k in range(clmno):
                    upper, lower = headers[j - 1][k], headers[j][k]
                    if not upper:
                        continue
                    if lower.find(upper) >= 0:
                        continue
                    if len(lower) > len(upper):
                        headers[j][k] = lower + (de if lower else "") + upper
                    else:
                        headers[j][k] = upper + (de if upper else "") + lower

        logging.debug(
            f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
        row_txt = []
        for i in range(rowno):
            if i in hdr_rowno:
                continue
            rtxt = []

            # r: the closest header row above row i (0 when there is none).
            r = 0
            if headers:
                _arr = [(i - hr, hr) for hr in headers if hr < i]
                if _arr:
                    _, r = min(_arr, key=lambda x: x[0])

            if r not in headers and clmno <= 2:
                # No usable header and at most two columns: join the cells
                # and pack short results onto the previous entry.
                for j in range(clmno):
                    if not tbl[i][j]:
                        continue
                    txt = "".join([a["text"].strip() for a in tbl[i][j]])
                    if txt:
                        rtxt.append(txt)
                if rtxt:
                    line = ":".join(rtxt)
                    if row_txt and len(row_txt[-1]) + len(line) < 64:
                        row_txt[-1] += "\n" + line
                    else:
                        row_txt.append(line)
                continue

            for j in range(clmno):
                if not tbl[i][j]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[i][j]])
                if not txt:
                    continue
                ctt = headers[r][j] if r in headers else ""
                if ctt:
                    ctt += ":"
                ctt += txt
                if ctt:
                    rtxt.append(ctt)

            if rtxt:
                row_txt.append("; ".join(rtxt))

        if cap:
            if is_english:
                from_ = " in "
            else:
                from_ = "来自"
            row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
        return row_txt

    @staticmethod
    def __cal_spans(boxes, rows, cols, tbl, html=True):
        """Resolve spanning cells in ``tbl`` and return the updated grid.

        Boxes carrying an ``SP`` key get the list of columns/rows whose
        midpoint lies inside the box's ``H_left``/``H_right`` and
        ``H_top``/``H_bott`` extent.  Cells covered by a span are merged into
        the span's top-left cell; the boxes of that cell receive integer
        ``rowspan``/``colspan`` counts.

        Args:
            boxes: all boxes of the table (mutated).
            rows, cols: row and column groupings used to estimate extents.
            tbl: grid of cells (mutated in place).
            html: when True covered cells are set to ``None``; otherwise they
                reference the merged box list.

        Returns:
            ``tbl``.
        """
        # calculate span
        clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
                for cln in cols]
        crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
                for cln in cols]
        rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
                for row in rows]
        # The row-bottom hint key is "R_bott" (the key read together with
        # "R_top" in construct_table); the previous "R_btm" never matched.
        rbtm = [np.mean([c.get("R_bott", c["bottom"]) for c in row])
                for row in rows]
        for b in boxes:
            if "SP" not in b:
                continue
            b["colspan"] = [b["cn"]]
            b["rowspan"] = [b["rn"]]
            # col span
            for j in range(0, len(clft)):
                if j == b["cn"]:
                    continue
                if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
                    continue
                if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
                    continue
                b["colspan"].append(j)
            # row span
            for j in range(0, len(rtop)):
                if j == b["rn"]:
                    continue
                if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
                    continue
                if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
                    continue
                b["rowspan"].append(j)

        def join(arr):
            if not arr:
                return ""
            return "".join([t["text"] for t in arr])

        # remove the spanning cells
        for i in range(len(tbl)):
            for j, arr in enumerate(tbl[i]):
                if not arr:
                    continue
                if all(["rowspan" not in a and "colspan" not in a
                        for a in arr]):
                    continue
                # Only list-valued spans (set above) are collected; integer
                # values written by an earlier merge are ignored here.
                rowspan, colspan = [], []
                for a in arr:
                    if isinstance(a.get("rowspan", 0), list):
                        rowspan.extend(a["rowspan"])
                    if isinstance(a.get("colspan", 0), list):
                        colspan.extend(a["colspan"])
                rowspan, colspan = set(rowspan), set(colspan)
                if len(rowspan) < 2 and len(colspan) < 2:
                    # Not really spanning: drop the markers.
                    for a in arr:
                        if "rowspan" in a:
                            del a["rowspan"]
                        if "colspan" in a:
                            del a["colspan"]
                    continue
                rowspan, colspan = sorted(rowspan), sorted(colspan)
                # Make the spans contiguous ranges.
                rowspan = list(range(rowspan[0], rowspan[-1] + 1))
                colspan = list(range(colspan[0], colspan[-1] + 1))
                assert i in rowspan, rowspan
                assert j in colspan, colspan
                arr = []
                for r in rowspan:
                    for c in colspan:
                        arr_txt = join(arr)
                        # Skip a covered cell whose text equals what has been
                        # collected so far.
                        if tbl[r][c] and join(tbl[r][c]) != arr_txt:
                            arr.extend(tbl[r][c])
                        tbl[r][c] = None if html else arr
                for a in arr:
                    if len(rowspan) > 1:
                        a["rowspan"] = len(rowspan)
                    elif "rowspan" in a:
                        del a["rowspan"]
                    if len(colspan) > 1:
                        a["colspan"] = len(colspan)
                    elif "colspan" in a:
                        del a["colspan"]
                tbl[rowspan[0]][colspan[0]] = arr

        return tbl