# MCQ-Rake / summarizer / summarizer.py
# Last commit: aa19b06 ("fix") by mikymatt
import torch
from transformers import T5ForConditionalGeneration,T5Tokenizer
import random
import numpy as np
import nltk
# Fetch the NLTK data this module needs at import time.
# sent_tokenize() (used by Summarizer.postprocesstext) loads the 'punkt_tab'
# resource on NLTK >= 3.9 and the pickled 'punkt' model on older releases,
# so both are requested to work across NLTK versions.
# 'brown' and 'wordnet' are not used in this file's visible code beyond the
# `wordnet as wn` import — NOTE(review): presumably needed elsewhere; confirm.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('brown')
nltk.download('wordnet')
from nltk.corpus import wordnet as wn
from nltk.tokenize import sent_tokenize
import locale
# Force a UTF-8 preferred encoding process-wide (a common workaround for
# notebook environments whose locale reports a non-UTF-8 encoding —
# NOTE(review): confirm this is still needed). The replacement must keep the
# stdlib signature getpreferredencoding(do_setlocale=True), because callers
# commonly invoke it as getpreferredencoding(False); a zero-argument lambda
# makes those calls raise TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"
class Summarizer:
    """Abstractive text summarizer built on a pretrained T5 checkpoint.

    The model and tokenizer are loaded once in ``__init__`` and reused by
    :meth:`summarizer`; a different model/tokenizer pair may be passed per call.
    """

    def __init__(self, model_name: str = "t5-base"):
        """Load the T5 model and tokenizer and move the model to GPU if available.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained`` for both model and tokenizer.
                Defaults to ``"t5-base"``.
        """
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.set_seed(42)

    def set_seed(self, seed: int):
        """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def postprocesstext(self, content):
        """Upper-case the first character of every sentence in ``content``.

        Only the first character of each sentence is changed; the rest keeps
        its original case, so proper nouns and acronyms survive intact
        (``str.capitalize`` would lowercase them).

        Args:
            content: Text to split into sentences with ``sent_tokenize``.

        Returns:
            The processed sentences joined by single spaces (``""`` for empty
            input).
        """
        return " ".join(
            sentence[:1].upper() + sentence[1:]
            for sentence in sent_tokenize(content)
        )

    def summarizer(
        self,
        text,
        model=None,
        tokenizer=None,
        *,
        max_input_length: int = 512,
        min_length: int = 75,
        max_length: int = 300,
        num_beams: int = 3,
    ):
        """Return an abstractive summary of ``text``.

        Args:
            text: Input document. All runs of whitespace (newlines, tabs, ...)
                are collapsed to single spaces before encoding.
            model: Seq2seq model with a ``generate`` method; defaults to
                ``self.model``.
            tokenizer: Tokenizer matching ``model``; defaults to
                ``self.tokenizer``.
            max_input_length: The prompt is truncated to this many tokens.
            min_length: Minimum generated length in tokens.
            max_length: Maximum generated length in tokens.
            num_beams: Beam width for beam search.

        Returns:
            The summary with each sentence's first letter upper-cased, or
            ``""`` when ``text`` is empty or whitespace only.
        """
        text = " ".join(text.split())
        if not text:
            # Nothing to summarize: generating here would force min_length
            # tokens of output unrelated to any input.
            return ""
        if model is None:
            model = self.model
        if tokenizer is None:
            tokenizer = self.tokenizer

        # Put the inputs on the device the model actually lives on; a
        # caller-supplied model need not be on self.device.
        device = getattr(model, "device", self.device)
        encoding = tokenizer(
            "summarize: " + text,
            max_length=max_input_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        outs = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            early_stopping=True,
            num_beams=num_beams,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            min_length=min_length,
            max_length=max_length,
        )
        # Exactly one sequence is requested, so only outs[0] is decoded.
        summary = tokenizer.decode(outs[0], skip_special_tokens=True)
        return self.postprocesstext(summary).strip()