G.Hemanth Sai
qgen
32d9382
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import random
import numpy as np
from nltk.tokenize import sent_tokenize
class T5_Base:
def __init__(self,path,device,model_max_length):
self.model=T5ForConditionalGeneration.from_pretrained(path)
self.tokenizer=T5Tokenizer.from_pretrained(path,model_max_length=model_max_length)
self.device=torch.device(device)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def preprocess(self,data):
preprocess_text=data.strip().replace('\n',' ')
return preprocess_text
def post_process(self,data):
final=""
for sent in sent_tokenize(data):
sent=sent.capitalize()
final+=sent+" "+sent
return final
def getSummary(self,data):
data=self.preprocess(data)
t5_prepared_Data="summarize: "+data
tokenized_text=self.tokenizer.encode_plus(t5_prepared_Data,max_length=512,pad_to_max_length=False,truncation=True,return_tensors='pt').to(self.device)
input_ids,attention_mask=tokenized_text['input_ids'],tokenized_text['attention_mask']
summary_ids=self.model.generate(input_ids=input_ids,
attention_mask=attention_mask,
early_stopping=True,
num_beams=3,
num_return_sequences=1,
no_repeat_ngram_size=2,
min_length = 75,
max_length=300)
output=[self.tokenizer.decode(ids,skip_special_tokens=True) for ids in summary_ids]
summary=output[0]
summary=self.post_process(summary)
summary=summary.strip()
return summary