"""Gradio app that turns a PDF into a summarized Markdown slide deck.

Pipeline: extract tagged text from the PDF with ``TextExtractor``, summarize
every paragraph with a distilled Pegasus model, write the result to a
Markdown file with ``mdutils``, and return that file's content to the UI.
"""

import copy
import os
from pathlib import Path

import fitz
import gradio as gr
import torch
from mdutils.mdutils import MdUtils
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import pipeline

from text_extractor import TextExtractor

# Stem of the most recently processed PDF.  Still updated by ``inference`` for
# backward compatibility, but ``convert2markdown`` now accepts the name
# explicitly so concurrent requests do not depend on this shared value.
FILENAME = ""

# Marker on which a summary is split into separate Markdown paragraphs.
# NOTE(review): Pegasus CNN/DailyMail checkpoints are expected to emit the
# literal text "<n>" between sentences - confirm against real model output.
# If the marker is absent the summary is simply kept as one paragraph.
SENTENCE_SEPARATOR = "<n>"

# Header levels accepted when writing Markdown headers.
MIN_HEADER_LEVEL = 1
MAX_HEADER_LEVEL = 6

preprocess = TextExtractor()
model_name = "sshleifer/distill-pegasus-cnn-16-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): ``max_length`` is handed to the tokenizer constructor; whether
# it actually caps truncation in ``summarize`` is not evident from this file.
tokenizer = PegasusTokenizer.from_pretrained(model_name, max_length=500)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)


def summarize(slides):
    """Return a deep copy of *slides* with paragraph sections summarized.

    Args:
        slides: Mapping whose values are mutable sequences of
            ``(tag, content)`` pairs.

    Returns:
        A deep copy of *slides* in which the content of every section whose
        tag starts with ``'p'`` is replaced by the model's summary.  Sections
        that fail to summarize keep their original content (best effort).
    """
    generated_slides = copy.deepcopy(slides)
    for contents in tqdm(generated_slides.values()):
        for idx, (tag, content) in enumerate(contents):
            if not tag.startswith('p'):
                continue
            try:
                encoded = tokenizer(
                    content,
                    truncation=True,
                    padding="longest",
                    return_tensors="pt",
                ).to(device)
                generated = model.generate(**encoded)
                summary = tokenizer.batch_decode(
                    generated, skip_special_tokens=True
                )[0]
                contents[idx] = (tag, summary)
            except Exception as e:
                # Deliberate best effort: one bad section must not abort the
                # whole document, so report it and keep the original text.
                print(e)
                print("Summarization Fails")
    return generated_slides


def convert2markdown(generate_slides, filename=None):
    """Write *generate_slides* to ``<filename>.md`` and return that path.

    Args:
        generate_slides: Mapping whose values are sequences of
            ``(tag, content)`` sections, as produced by ``summarize``.
        filename: Base name (without extension) for the Markdown file and its
            title.  Defaults to the module-level ``FILENAME``.

    Returns:
        The Markdown file name, ``f"{filename}.md"``.
    """
    if filename is None:
        filename = FILENAME

    md_file = MdUtils(file_name=filename, title=f'{filename} Presentation')
    for sections in generate_slides.values():
        # A horizontal rule separates consecutive slides.
        md_file.new_paragraph('---')
        for section in sections:
            tag, content = section[0], section[1]
            if tag.startswith('h'):
                # The digit after 'h' is the header level; clamp it so an
                # out-of-range level cannot abort the export.
                level = min(max(int(tag[1]), MIN_HEADER_LEVEL), MAX_HEADER_LEVEL)
                md_file.new_header(level=level, title=content)
            elif tag == 'p':
                # Splitting on an empty separator raises ValueError; split on
                # the sentence marker and drop blank fragments instead.
                for paragraph in content.split(SENTENCE_SEPARATOR):
                    paragraph = paragraph.strip()
                    if paragraph:
                        md_file.new_paragraph(paragraph)

    md_file.create_md_file()
    return f"{filename}.md"


def inference(document):
    """Summarize an uploaded PDF and return the generated Markdown text.

    Args:
        document: The uploaded file as delivered by the ``gr.File`` input;
            it is passed to ``fitz.open`` unchanged.

    Returns:
        The content of the generated Markdown file as a string.
    """
    global FILENAME

    # Open the PDF in a context manager so it is closed even if extraction
    # fails part-way through.
    with fitz.open(document) as doc:
        stem = Path(doc.name).stem
        font_counts, styles = preprocess.get_font_info(doc, granularity=False)
        size_tag = preprocess.get_font_tags(font_counts, styles)
        texts = preprocess.assign_tags(doc, size_tag)
        slides = preprocess.get_slides(texts)

    FILENAME = stem
    generated_slides = summarize(slides)
    markdown_path = convert2markdown(generated_slides, filename=stem)

    # Read with an explicit encoding rather than the platform default.
    # NOTE(review): assumes mdutils writes the file as UTF-8 - confirm.
    with open(markdown_path, 'rt', encoding='utf-8') as f:
        return f.read()


with gr.Blocks() as demo:
    # File-type filters are extensions with a leading dot (or a category
    # name such as "image"); a bare 'pdf' is neither.
    inp = gr.File(file_types=['.pdf'])
    out = gr.Textbox(label="Markdown Content")
    inference_btn = gr.Button("Summarized PDF")
    inference_btn.click(
        fn=inference,
        inputs=inp,
        outputs=out,
        show_progress=True,
        api_name="summarize",
    )

demo.launch()