# Paraphrase / app.py
# jaimin's picture
# Update app.py
# a671dff
import gradio as gr
from gradio.mix import Parallel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
from transformers import T5TokenizerFast, T5ForConditionalGeneration
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import pytorch_lightning as pl
import torch
import itertools
import random
import nltk
from nltk.tokenize import sent_tokenize
import requests
import json
nltk.download('punkt')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer1 = PegasusTokenizer.from_pretrained('jaimin/pegasus')
model1 = PegasusForConditionalGeneration.from_pretrained('jaimin/pegasus').to(device)
def listToDict(lst):
op = { i : lst[i] for i in range(0, len(lst) ) }
return op
def get_paraphrases_pytorchlight(text, n_predictions=3, top_k=50, max_length=256, device="cpu"):
para = []
sentence = text
for sent in sent_tokenize(sentence):
text = "paraphrase: "+sent + " </s>"
encoding = tokenizer1.encode_plus(text, padding=True, return_tensors="pt", truncation=True)
input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
model_output = model1.generate(
input_ids=input_ids,attention_mask=attention_masks,
max_length = 512,
early_stopping=True,
num_beams=15,
num_beam_groups = 3,
num_return_sequences=n_predictions,
diversity_penalty = 0.70,
temperature=0.7,
no_repeat_ngram_size=2 )
outputs = []
for output in model_output:
generated_sent = tokenizer1.decode(
output, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
if (
generated_sent.lower() != sentence.lower()
and generated_sent not in outputs
):
outputs.append(generated_sent)
para.append(outputs)
a = list(itertools.product(*para))
random.shuffle(a)
l=[]
for i in range(len(a)):
l.append(" ".join(a[i]))
final_output=[]
for i in range(len(l)):
final_output.append("* " + l[i] + ".")
paraphrase = "\n".join(final_output)
return paraphrase
iface = gr.Interface(fn=get_paraphrases_pytorchlight, inputs=[gr.inputs.Textbox(lines=5)],outputs="text")
#iface1 = gr.Interface(fn=get_paraphrases_pytorchlight, inputs=[gr.inputs.Textbox(lines=5)],outputs="text")
iface.launch(enable_queue = True)