# Hugging Face Space status: Running
import functools

import nltk
import torch
from nltk.tokenize import sent_tokenize
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

import src.exception.Exception as ExceptionCustom
# Method identifier passed to ExceptionCustom.checkForException when
# validating paraphrase requests.
METHOD = "PARAPHRASE"
# Beam width for generation; also the number of candidates returned per
# sentence, so a paraphrase that differs from the input can be chosen.
_NUM_BEAMS = 5

# NLTK resources needed by sent_tokenize: (download package, data path).
_NLTK_RESOURCES = (
    ("punkt", "tokenizers/punkt"),
    ("punkt_tab", "tokenizers/punkt_tab"),
)


def _ensure_nltk_resources():
    """Download the NLTK sentence-tokenizer data only if it is missing locally."""
    for package, resource_path in _NLTK_RESOURCES:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, quiet=True)


@functools.lru_cache(maxsize=4)
def _load_pegasus(model_name):
    """Load the Pegasus tokenizer and model for *model_name*, cached per process."""
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    pegasus = PegasusForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, pegasus


def paraphraseParaphraseMethod(requestValue: str, model: str = 'tuner007/pegasus_paraphrase'):
    """Paraphrase *requestValue* sentence by sentence with a Pegasus model.

    Args:
        requestValue: Text to paraphrase. It is split into sentences with
            NLTK's ``sent_tokenize``, and each sentence is paraphrased
            independently.
        model: Model id or path passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(result, exception)`` pair. If ``checkForException`` returns a
        non-empty value, the pair is ``("", that_value)``. Otherwise it is
        ``(paraphrased_text, "")``, where the per-sentence paraphrases are
        joined with single spaces. A sentence for which every generated
        candidate equals the input (case-insensitive) is kept unchanged.
    """
    # Cheap after the first call: only downloads when data is missing.
    _ensure_nltk_resources()

    # Validate before loading the (expensive) model.
    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception != "":
        return "", exception

    tokenizer, pegasus = _load_pegasus(model)

    paraphrased = []
    for sentence in sent_tokenize(requestValue):
        # NOTE(review): the tuner007/pegasus_paraphrase model card feeds raw
        # text without a task prefix; confirm "paraphrase: " is intended.
        encoding = tokenizer(
            "paraphrase: " + sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        beam_outputs = pegasus.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_length=512,
            num_beams=_NUM_BEAMS,
            # Without this, generate() returns only the top beam, leaving no
            # alternative when that beam just echoes the input sentence.
            num_return_sequences=_NUM_BEAMS,
            length_penalty=0.8,
            early_stopping=True,
        )
        candidates = tokenizer.batch_decode(
            beam_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        normalized_input = sentence.lower().strip()
        # First candidate that actually differs from the input; fall back to
        # the original sentence rather than silently dropping it.
        chosen = next(
            (c for c in candidates if c.lower().strip() != normalized_input),
            sentence,
        )
        paraphrased.append(chosen)

    return " ".join(paraphrased), ""