# PPTGenerator/src/summarizer.py
# Page header residue: Davidsamuel101 · "Tidy Up Code" · 9f2dd14 · 2.86 kB
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from src.text_extractor import TextExtractor
from mdutils.mdutils import MdUtils
import torch
import fitz
import copy
class Summarizer:
    """Summarize tagged text from a document and render it as Markdown.

    Typical flow: ``extract_text`` -> ``__call__`` -> ``convert2markdown``.
    ``convert2markdown`` reads ``self.filename``, which is only set by
    ``extract_text``, so ``extract_text`` must run first.
    """

    def __init__(self, model_name: str):
        """Load the Pegasus tokenizer and model identified by *model_name*.

        The model is moved to ``"cuda"`` when ``torch.cuda.is_available()``
        is true and to ``"cpu"`` otherwise.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.preprocess = TextExtractor()

    def extract_text(self, document: object) -> Dict[str, List[Tuple[str, str]]]:
        """Open *document* with ``fitz`` and return its per-slide tagged text.

        Side effect: stores the document's base name (text before the first
        ``.`` of the last ``/``-separated path component) in ``self.filename``.

        Args:
            document: Anything accepted by ``fitz.open``.

        Returns:
            The value produced by ``TextExtractor.get_slides``.
        """
        # The context manager closes the document once extraction is done;
        # previously the handle was left open.
        # NOTE(review): assumes get_slides returns plain data that does not
        # reference the open document -- confirm against TextExtractor.
        with fitz.open(document) as doc:
            self.filename = doc.name.split('/')[-1].split('.')[0]
            font_counts, styles = self.preprocess.get_font_info(doc, granularity=False)
            size_tag = self.preprocess.get_font_tags(font_counts, styles)
            texts = self.preprocess.assign_tags(doc, size_tag)
            return self.preprocess.get_slides(texts)

    def __call__(self, slides: Dict[str, List[Tuple[str, str]]]) -> Dict[str, List[Tuple[str, str]]]:
        """Return a deep copy of *slides* with paragraph entries summarized.

        Every ``(tag, content)`` entry whose tag starts with ``'p'`` has its
        content replaced by the model's summary. Other entries are copied
        unchanged, and the input mapping is not modified.

        Summarization is best-effort: if an entry fails, the error is printed
        and that entry keeps its original content.
        """
        summarized_slides = copy.deepcopy(slides)
        for contents in tqdm(summarized_slides.values()):
            for idx, (tag, content) in enumerate(contents):
                if not tag.startswith('p'):
                    continue
                try:
                    inputs = self.tokenizer(
                        content, truncation=True, padding="longest", return_tensors="pt"
                    ).to(self.device)
                    generated = self.model.generate(**inputs)
                    summary = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    contents[idx] = (tag, summary)
                except Exception as e:
                    # Deliberate best-effort: report and keep the original text.
                    print(f"Summarization Fails, Error: {e}")
        return summarized_slides

    def convert2markdown(self, summarized_slides: Dict[str, List[Tuple[str, str]]], target_path: Optional[str]=None) -> str:
        """Write *summarized_slides* to a Markdown file via ``MdUtils``.

        Each slide is preceded by a ``---`` line. Entries whose tag starts
        with ``'h'`` become headers whose level is the digit at ``tag[1]``.
        Entries tagged exactly ``'p'`` are split on ``'<n>'`` and written one
        piece per line.

        Args:
            summarized_slides: Mapping of slide key to ``(tag, content)`` entries.
            target_path: Output file name; defaults to ``self.filename``.

        Returns:
            Whatever ``MdUtils.create_md_file()`` returns.
            NOTE(review): the ``str`` annotation is inherited from the
            original signature -- confirm against the installed mdutils.
        """
        filename = target_path or self.filename
        md_file = MdUtils(file_name=filename, title=f'{self.filename} Presentation')
        for sections in summarized_slides.values():
            md_file.new_line('---\n')
            for tag, content in sections:
                if tag.startswith('h'):
                    md_file.new_header(level=int(tag[1]), title=content)
                elif tag == 'p':
                    # Presumably '<n>' is the model's line-break marker.
                    for line in content.split('<n>'):
                        md_file.new_line(f"{line}\n")
        return md_file.create_md_file()