# NOTE: removed a scraped page banner ("Spaces: / Build error / Build error")
# that was captured by extraction and is not part of this Python source.
import os | |
import re | |
import gc | |
import sys | |
import time | |
import torch | |
from PIL import Image, ImageDraw | |
from torchvision import transforms | |
from torch.utils.data import DataLoader | |
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..')) | |
from pdf_extract_kit.utils.data_preprocess import load_pdf | |
from pdf_extract_kit.tasks.ocr.task import OCRTask | |
from pdf_extract_kit.dataset.dataset import MathDataset | |
from pdf_extract_kit.registry.registry import TASK_REGISTRY | |
from pdf_extract_kit.utils.merge_blocks_and_spans import ( | |
fill_spans_in_blocks, | |
fix_block_spans, | |
merge_para_with_text | |
) | |
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.

    First collapses the spaces inside ``\\operatorname/\\mathrm/\\text/\\mathbf``
    argument groups, then repeatedly deletes whitespace between adjacent
    tokens — except between two letters (where removal would merge
    identifiers) and after the explicit LaTeX space command ``"\\ "``.

    Args:
        s: LaTeX source string.

    Returns:
        The LaTeX string with superfluous whitespace removed.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = '[a-zA-Z]'
    # Raw string: '\W' / '\d' in a plain literal are invalid escape sequences
    # (SyntaxWarning since Python 3.12); the regex meaning is unchanged.
    noletter = r'[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        # Drop whitespace for noletter-noletter, noletter-letter and
        # letter-noletter pairs; the lookahead preserves the literal "\ ".
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == s:
            # Fixed point reached: no more whitespace can be removed.
            break
    return s
def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0):
    """Crop the bbox described by ``input_res['poly']`` out of a PIL image.

    The crop is pasted onto a white canvas enlarged by ``padding_x`` /
    ``padding_y`` on each side, and the geometry needed to map coordinates
    back to the page is returned alongside the image.

    Args:
        input_res: detection dict with an 8-value 'poly'
            [x0, y0, x1, y0, x1, y1, x0, y1].
        input_pil_img: source page as a PIL image.
        padding_x: horizontal white margin added on each side.
        padding_y: vertical white margin added on each side.

    Returns:
        tuple: (padded PIL.Image,
                [paste_x, paste_y, xmin, ymin, xmax, ymax, width, height])
    """
    poly = input_res['poly']
    xmin, ymin = int(poly[0]), int(poly[1])
    xmax, ymax = int(poly[4]), int(poly[5])
    canvas_w = (xmax - xmin) + 2 * padding_x
    canvas_h = (ymax - ymin) + 2 * padding_y
    # White canvas with the requested margins, crop pasted in the middle.
    canvas = Image.new('RGB', (canvas_w, canvas_h), 'white')
    canvas.paste(input_pil_img.crop((xmin, ymin, xmax, ymax)), (padding_x, padding_y))
    geometry = [padding_x, padding_y, xmin, ymin, xmax, ymax, canvas_w, canvas_h]
    return canvas, geometry
class PDF2MARKDOWN(OCRTask):
    """PDF/image -> layout + formula + OCR extraction -> markdown task.

    Combines a page-layout detector, a math-formula detector (MFD), a
    math-formula recognizer (MFR) and an OCR engine. Any model may be
    None to disable that stage, except that MFR requires MFD (it consumes
    the crops MFD produces).
    """

    def __init__(self, layout_model, mfd_model, mfr_model, ocr_model):
        """Store model handles, build the MFR transform and color palette.

        Args:
            layout_model: layout detector exposing ``predict`` and
                ``id_to_names`` (YOLO-style), or None.
            mfd_model: math-formula detector (same interface), or None.
            mfr_model: math-formula recognizer exposing ``vis_processor``,
                ``batch_size``, ``device`` and ``model.generate``; requires
                ``mfd_model``.
            ocr_model: text OCR engine exposing ``ocr(...)``, or None.
        """
        # NOTE(review): super().__init__() is not called — presumably
        # OCRTask needs no initialization here; confirm against the base class.
        self.layout_model = layout_model
        self.mfd_model = mfd_model
        self.mfr_model = mfr_model
        self.ocr_model = ocr_model
        if self.mfr_model is not None:
            # Formula recognition operates on crops from formula detection.
            assert self.mfd_model is not None, "formula recognition based on formula detection, mfd_model can not be None."
            self.mfr_transform = transforms.Compose([self.mfr_model.vis_processor, ])
        # RGB color per category, used when visualizing detection results.
        self.color_palette = {
            'title': (255, 64, 255),
            'plain text': (255, 255, 0),
            'abandon': (0, 255, 255),
            'figure': (255, 215, 135),
            'figure_caption': (215, 0, 95),
            'table': (100, 0, 48),
            'table_caption': (0, 175, 0),
            'table_footnote': (95, 0, 95),
            'isolate_formula': (175, 95, 0),
            'formula_caption': (95, 95, 0),
            'inline': (0, 0, 255),
            'isolated': (0, 255, 0),
            'text': (255, 0, 0)
        }
def convert_format(self, yolo_res, id_to_names, ): | |
""" | |
convert yolo format to pdf-extract format. | |
""" | |
res_list = [] | |
for xyxy, conf, cla in zip(yolo_res.boxes.xyxy.cpu(), yolo_res.boxes.conf.cpu(), yolo_res.boxes.cls.cpu()): | |
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] | |
new_item = { | |
'category_type': id_to_names[int(cla.item())], | |
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], | |
'score': round(float(conf.item()), 2), | |
} | |
res_list.append(new_item) | |
return res_list | |
def process_single_pdf(self, image_list): | |
"""predict on one image, reture text detection and recognition results. | |
Args: | |
image_list: List[PIL.Image.Image] | |
Returns: | |
List[dict]: list of PDF extract results | |
Return example: | |
[ | |
{ | |
"layout_dets": [ | |
{ | |
"category_type": "text", | |
"poly": [ | |
380.6792698635707, | |
159.85058512958923, | |
765.1419999999998, | |
159.85058512958923, | |
765.1419999999998, | |
192.51073013642917, | |
380.6792698635707, | |
192.51073013642917 | |
], | |
"text": "this is an example text", | |
"score": 0.97 | |
}, | |
... | |
], | |
"page_info": { | |
"page_no": 0, | |
"height": 2339, | |
"width": 1654, | |
} | |
}, | |
... | |
] | |
""" | |
pdf_extract_res = [] | |
mf_image_list = [] | |
latex_filling_list = [] | |
for idx, image in enumerate(image_list): | |
img_W, img_H = image.size | |
if self.layout_model is not None: | |
ori_layout_res = self.layout_model.predict([image], "")[0] | |
layout_res = self.convert_format(ori_layout_res, self.layout_model.id_to_names) | |
else: | |
layout_res = [] | |
single_page_res = {'layout_dets': layout_res} | |
single_page_res['page_info'] = dict( | |
page_no = idx, | |
height = img_H, | |
width = img_W | |
) | |
if self.mfd_model is not None: | |
mfd_res = self.mfd_model.predict([image], "")[0] | |
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): | |
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] | |
new_item = { | |
'category_type': self.mfd_model.id_to_names[int(cla.item())], | |
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax], | |
'score': round(float(conf.item()), 2), | |
'latex': '', | |
} | |
single_page_res['layout_dets'].append(new_item) | |
if self.mfr_model is not None: | |
latex_filling_list.append(new_item) | |
bbox_img = image.crop((xmin, ymin, xmax, ymax)) | |
mf_image_list.append(bbox_img) | |
pdf_extract_res.append(single_page_res) | |
del mfd_res | |
torch.cuda.empty_cache() | |
gc.collect() | |
# Formula recognition, collect all formula images in whole pdf file, then batch infer them. | |
if self.mfr_model is not None: | |
a = time.time() | |
dataset = MathDataset(mf_image_list, transform=self.mfr_transform) | |
dataloader = DataLoader(dataset, batch_size=self.mfr_model.batch_size, num_workers=0) | |
mfr_res = [] | |
for imgs in dataloader: | |
imgs = imgs.to(self.mfr_model.device) | |
output = self.mfr_model.model.generate({'image': imgs}) | |
mfr_res.extend(output['pred_str']) | |
for res, latex in zip(latex_filling_list, mfr_res): | |
res['latex'] = latex_rm_whitespace(latex) | |
b = time.time() | |
print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2)) | |
# ocr_res = self.ocr_model.predict(image) | |
# ocr and table recognition | |
for idx, image in enumerate(image_list): | |
layout_res = pdf_extract_res[idx]['layout_dets'] | |
pil_img = image.copy() | |
ocr_res_list = [] | |
table_res_list = [] | |
single_page_mfdetrec_res = [] | |
for res in layout_res: | |
if res['category_type'] in self.mfd_model.id_to_names.values(): | |
single_page_mfdetrec_res.append({ | |
"bbox": [int(res['poly'][0]), int(res['poly'][1]), | |
int(res['poly'][4]), int(res['poly'][5])], | |
}) | |
elif res['category_type'] in [self.layout_model.id_to_names[cid] for cid in [0, 1, 2, 4, 6, 7]]: | |
ocr_res_list.append(res) | |
elif res['category_type'] in [self.layout_model.id_to_names[5]]: | |
table_res_list.append(res) | |
ocr_start = time.time() | |
# Process each area that requires OCR processing | |
for res in ocr_res_list: | |
new_image, useful_list = crop_img(res, pil_img, padding_x=25, padding_y=25) | |
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list | |
# Adjust the coordinates of the formula area | |
adjusted_mfdetrec_res = [] | |
for mf_res in single_page_mfdetrec_res: | |
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] | |
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area | |
x0 = mf_xmin - xmin + paste_x | |
y0 = mf_ymin - ymin + paste_y | |
x1 = mf_xmax - xmin + paste_x | |
y1 = mf_ymax - ymin + paste_y | |
# Filter formula blocks outside the graph | |
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): | |
continue | |
else: | |
adjusted_mfdetrec_res.append({ | |
"bbox": [x0, y0, x1, y1], | |
}) | |
# OCR recognition | |
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] | |
# Integration results | |
if ocr_res: | |
for box_ocr_res in ocr_res: | |
p1, p2, p3, p4 = box_ocr_res[0] | |
text, score = box_ocr_res[1] | |
# Convert the coordinates back to the original coordinate system | |
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] | |
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] | |
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] | |
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin] | |
layout_res.append({ | |
'category_type': 'text', | |
'poly': p1 + p2 + p3 + p4, | |
'score': round(score, 2), | |
'text': text, | |
}) | |
ocr_cost = round(time.time() - ocr_start, 2) | |
print(f"ocr cost: {ocr_cost}") | |
return pdf_extract_res | |
def order_blocks(self, blocks): | |
def calculate_oder(poly): | |
xmin, ymin, _, _, xmax, ymax, _, _ = poly | |
return ymin*3000 + xmin | |
return sorted(blocks, key=lambda item: calculate_oder(item['poly'])) | |
    def convert2md(self, extract_res):
        """Convert one page's extraction result into a markdown string.

        Text / inline-formula / isolated-formula detections are treated as
        spans and merged into the layout blocks they fall inside (via
        ``fill_spans_in_blocks``); the merged paragraphs are flattened to
        text or LaTeX and emitted in reading order. Figures and tables are
        skipped.

        Args:
            extract_res (dict): one element of ``process_single_pdf``'s
                output ('layout_dets' + 'page_info').

        Returns:
            str: markdown text for the page.
        """
        blocks = []
        spans = []
        for item in extract_res['layout_dets']:
            if item['category_type'] in ['inline', 'text', 'isolated']:
                # These detections carry recognized content and act as spans.
                text_key = 'text' if item['category_type'] == 'text' else 'latex'
                xmin, ymin, _, _, xmax, ymax, _, _ = item['poly']
                spans.append(
                    {
                        "type": item['category_type'],
                        "bbox": [xmin, ymin, xmax, ymax],
                        "content": item[text_key]
                    }
                )
                if item['category_type'] == "isolated":
                    # Displayed formulas are also standalone blocks; rename to
                    # the block-level category used below.
                    item['category_type'] = "isolate_formula"
                    blocks.append(item)
            else:
                blocks.append(item)
        # Block categories whose content comes from the merged spans.
        blocks_types = ["title", "plain text", "figure_caption", "table_caption", "table_footnote", "isolate_formula", "formula_caption"]
        need_fix_bbox = []
        final_block = []
        for block in blocks:
            block_type = block["category_type"]
            if block_type in blocks_types:
                need_fix_bbox.append(block)
            else:
                final_block.append(block)
        # 0.6 is presumably the span/block overlap threshold — confirm against
        # fill_spans_in_blocks' signature.
        block_with_spans, spans = fill_spans_in_blocks(need_fix_bbox, spans, 0.6)
        fix_blocks = fix_block_spans(block_with_spans)
        for para_block in fix_blocks:
            result = merge_para_with_text(para_block)
            # NOTE(review): 'saved_info' appears to be the original layout dict
            # preserved by fill_spans_in_blocks — verify in merge_blocks_and_spans.
            if para_block['type'] == "isolate_formula":
                para_block['saved_info']['latex'] = result
            else:
                para_block['saved_info']['text'] = result
            final_block.append(para_block['saved_info'])
        final_block = self.order_blocks(final_block)
        # Emit markdown in reading order; unhandled categories are dropped.
        md_text = ""
        for block in final_block:
            if block['category_type'] == "title":
                md_text += "\n# "+block['text'] +"\n"
            elif block['category_type'] in ["isolate_formula"]:
                md_text += "\n"+block['latex']+"\n"
            elif block['category_type'] in ["plain text", "figure_caption", "table_caption"]:
                md_text += " "+block['text']+" "
            elif block['category_type'] in ["figure", "table"]:
                continue
            else:
                continue
        return md_text
def process(self, input_path, save_dir=None, visualize=False, merge2markdown=False): | |
file_list = self.prepare_input_files(input_path) | |
res_list = [] | |
for fpath in file_list: | |
basename = os.path.basename(fpath)[:-4] | |
if fpath.endswith(".pdf") or fpath.endswith(".PDF"): | |
images = load_pdf(fpath) | |
else: | |
images = [Image.open(fpath)] | |
pdf_extract_res = self.process_single_pdf(images) | |
res_list.append(pdf_extract_res) | |
if save_dir: | |
os.makedirs(save_dir, exist_ok=True) | |
self.save_json_result(pdf_extract_res, os.path.join(save_dir, f"{basename}.json")) | |
if merge2markdown: | |
md_content = [] | |
for extract_res in pdf_extract_res: | |
md_text = self.convert2md(extract_res) | |
md_content.append(md_text) | |
with open(os.path.join(save_dir, f"{basename}.md"), "w") as f: | |
f.write("\n\n".join(md_content)) | |
if visualize: | |
for image, page_res in zip(images, pdf_extract_res): | |
self.visualize_image(image, page_res['layout_dets'], cate2color=self.color_palette) | |
if fpath.endswith(".pdf") or fpath.endswith(".PDF"): | |
first_page = images.pop(0) | |
first_page.save(os.path.join(save_dir, f'{basename}.pdf'), 'PDF', resolution=100, save_all=True, append_images=images) | |
else: | |
images[0].save(os.path.join(save_dir, f"{basename}.png")) | |
return res_list | |