PPTGenerator / src /summarizer.py
Davidsamuel101's picture
Remove error
28ad2d2
raw
history blame contribute delete
No virus
3.49 kB
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from src.text_extractor import TextExtractor
from mdutils.mdutils import MdUtils
import torch
import fitz
import copy
class Summarizer:
    """Summarize tagged slide text with a Pegasus model and render it as Markdown.

    The instance bundles a Pegasus tokenizer/model pair (moved to the GPU when
    one is available) with the project's ``TextExtractor``.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model identified by ``model_name``.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        # Prefer CUDA when torch reports it, otherwise run on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.preprocess = TextExtractor()

    def extract_text(self, document: object) -> Dict[str, List[Tuple[str, str]]]:
        """Open ``document`` with ``fitz`` and return its tagged slide content.

        As a side effect, ``self.filename`` is set to the document name with
        everything up to the last ``/`` and everything from the first ``.``
        removed; ``convert2markdown`` uses it as the default output name.

        Args:
            document: Whatever ``fitz.open`` accepts as its first argument.

        Returns:
            The value produced by ``TextExtractor.get_slides``.
        """
        doc = fitz.open(document)
        try:
            # NOTE(review): only '/' is treated as a path separator here, so a
            # name using another separator keeps its directory part.
            self.filename = doc.name.split('/')[-1].split('.')[0]
            font_counts, styles = self.preprocess.get_font_info(doc, granularity=False)
            size_tag = self.preprocess.get_font_tags(font_counts, styles)
            texts = self.preprocess.assign_tags(doc, size_tag)
            return self.preprocess.get_slides(texts)
        finally:
            # Release the document handle even when extraction fails.
            # Assumes get_slides returns fully materialised data -- TODO confirm.
            doc.close()

    def __call__(self, slides: Dict[str, List[Tuple[str, str]]]) -> Dict[str, List[Tuple[str, str]]]:
        """Return a deep copy of ``slides`` with paragraph entries summarized.

        Every ``(tag, content)`` entry whose tag starts with ``'p'`` has its
        content replaced by the model's summary. Other entries are kept as-is.
        If summarizing an entry raises, the error is printed and that entry
        keeps its original text.

        Args:
            slides: Mapping of page key to a list of ``(tag, content)`` tuples.

        Returns:
            The summarized copy; ``slides`` itself is not modified.
        """
        summarized_slides = copy.deepcopy(slides)
        for contents in tqdm(summarized_slides.values()):
            for idx, (tag, content) in enumerate(contents):
                # NOTE(review): this matches any tag starting with 'p', while
                # convert2markdown only renders tags equal to 'p' -- confirm
                # that no other 'p...' tags exist.
                if not tag.startswith('p'):
                    continue
                try:
                    encoded = self.tokenizer(
                        content,
                        truncation=True,
                        padding="longest",
                        return_tensors="pt",
                    ).to(self.device)
                    generated = self.model.generate(**encoded)
                    summary = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    contents[idx] = (tag, summary)
                except Exception as e:
                    # Best effort: report the failure and keep the original text.
                    print(f"Summarization Fails, Error: {e}")
        return summarized_slides

    def convert2markdown(self, summarized_slides: Dict[str, List[Tuple[str, str]]], target_path: Optional[str]=None) -> str:
        """Write the slides to a Markdown file through ``MdUtils``.

        Each page starts with a ``---`` line. Entries whose tag starts with
        ``'h'`` become headers whose level is the digit at ``tag[1]``; entries
        tagged exactly ``'p'`` are split on the literal ``<n>`` marker
        (presumably the newline token produced by the Pegasus tokenizer) and
        written one piece per line. Headers that cannot be created are skipped.

        Args:
            summarized_slides: Mapping of page key to ``(tag, content)`` tuples.
            target_path: Output file name. When empty or ``None`` the name
                recorded by ``extract_text`` is used.

        Returns:
            Whatever ``MdUtils.create_md_file()`` returns.
            NOTE(review): the ``str`` annotation is kept for compatibility --
            confirm it against the installed mdutils version.

        Raises:
            AttributeError: If ``target_path`` is not given and
                ``extract_text`` has not been called on this instance.
        """
        # Only fall back to self.filename when no target was supplied, so an
        # explicit target_path works without a prior extract_text call.
        filename = target_path or self.filename
        mdFile = MdUtils(file_name=filename)
        for sections in summarized_slides.values():
            mdFile.new_line('---\n')
            for tag, content in sections:
                if tag.startswith('h'):
                    try:
                        mdFile.new_header(level=int(tag[1]), title=content)
                    except Exception:
                        # Malformed tag or unsupported level: skip this header.
                        continue
                elif tag == 'p':
                    for piece in content.split('<n>'):
                        mdFile.new_line(f"{piece}\n")
        return mdFile.create_md_file()

    def remove_leading_empty_lines(self, file_path) -> None:
        """Rewrite ``file_path`` in place without its leading blank lines.

        Lines that are empty or whitespace-only are dropped until the first
        line with visible content; everything from that line on is kept
        unchanged, including later blank lines. A file with no visible content
        ends up empty.

        Args:
            file_path: Path of the text file to rewrite. It is read and
                written as UTF-8.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()

        # Index of the first line with visible content, or len(lines) if none.
        first_content = next(
            (idx for idx, line in enumerate(lines) if line.strip()),
            len(lines),
        )

        with open(file_path, 'w', encoding='utf-8') as file:
            file.writelines(lines[first_content:])