# Hugging Face Spaces status at time of capture: Runtime error
from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
class PegasusParaphraser:
    """Pegasus model for paraphrase generation.

    Wraps the ``tuner007/pegasus_paraphrase`` checkpoint. The model and
    tokenizer are loaded once at construction time and reused for every
    :meth:`paraphrase` call.

    Args:
        num_return_sequences: Number of paraphrases returned per input text.
        num_beams: Beam width passed to ``model.generate``. With beam search
            (the default), it must be >= ``num_return_sequences``.
        max_length: Token limit used both to truncate the input and as the
            maximum length of each generated paraphrase.
        temperature: Sampling temperature. Only used when ``do_sample`` is
            True. Plain beam search ignores temperature, so it is not
            forwarded to ``generate`` in that case.
        device: Device the model and the tokenized inputs are moved to
            (e.g. ``"cpu"``).
        do_sample: If True, use sample-based decoding (beam sampling when
            ``num_beams > 1``) so that ``temperature`` takes effect. Defaults
            to False, which keeps deterministic beam search.

    Raises:
        ValueError: If the generation settings are inconsistent. Raised
            before the model is loaded.
    """

    def __init__(self, num_return_sequences=3, num_beams=10, max_length=60,
                 temperature=1.5, device="cpu", do_sample=False):
        self.model_name = "tuner007/pegasus_paraphrase"
        self.device = device
        self.num_return_sequences = num_return_sequences
        self.num_beams = num_beams
        self.max_length = max_length
        self.temperature = temperature
        self.do_sample = do_sample
        # Fail fast: check settings before spending time loading the model.
        self._validate_settings()
        self.model = self.load_model()
        self.tokenizer = PegasusTokenizer.from_pretrained(self.model_name)

    def _validate_settings(self):
        """Raise ValueError if the generation settings cannot work together."""
        if self.num_return_sequences < 1:
            raise ValueError(
                f"num_return_sequences must be >= 1, got {self.num_return_sequences}"
            )
        if self.num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {self.num_beams}")
        if self.max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {self.max_length}")
        # Beam search (and beam sampling) can return at most one sequence per
        # beam. Only pure sampling (num_beams == 1) may return more.
        beam_based = not self.do_sample or self.num_beams > 1
        if beam_based and self.num_return_sequences > self.num_beams:
            raise ValueError(
                f"num_return_sequences ({self.num_return_sequences}) must be "
                f"<= num_beams ({self.num_beams}) when using beam search"
            )
        if self.do_sample and self.temperature <= 0:
            raise ValueError(
                f"temperature must be > 0 when sampling, got {self.temperature}"
            )

    def load_model(self):
        """Load the Pegasus paraphrase model and move it to ``self.device``."""
        model = PegasusForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
        return model

    @staticmethod
    def _as_batch(input_text):
        """Normalize a single string or an iterable of strings to a list."""
        if isinstance(input_text, str):
            return [input_text]
        return list(input_text)

    def _generation_kwargs(self):
        """Build the decoding keyword arguments passed to ``model.generate``."""
        kwargs = {
            "max_length": self.max_length,
            "num_beams": self.num_beams,
            "num_return_sequences": self.num_return_sequences,
        }
        if self.do_sample:
            # temperature only affects sample-based decoding. Under plain
            # beam search, transformers ignores it and emits a warning.
            kwargs["do_sample"] = True
            kwargs["temperature"] = self.temperature
        return kwargs

    def paraphrase(self, input_text):
        """Return paraphrases of ``input_text``.

        Args:
            input_text: A single string, or a list/tuple of strings that are
                paraphrased together in one batch.

        Returns:
            A flat list of decoded strings, ``num_return_sequences`` per
            input text and grouped in input order. A single string therefore
            yields exactly ``num_return_sequences`` strings. An empty list
            yields ``[]``.
        """
        texts = self._as_batch(input_text)
        if not texts:
            # Nothing to tokenize. Avoid handing an empty batch to the model.
            return []
        batch = self.tokenizer(
            texts,
            truncation=True,
            padding="longest",
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        generated = self.model.generate(**batch, **self._generation_kwargs())
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)