# MediaRephraser / app.py

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

device = "cpu"

# Paraphrasing model: T5-base fine-tuned for paraphrasing
tokenizer = AutoTokenizer.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
model = AutoModelForSeq2SeqLM.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base").to(device)

# English-to-French translation pipeline (NLLB alternative kept commented out)
#translator = pipeline("translation", model="facebook/nllb-200-distilled-600M")
translator = pipeline("translation_en_to_fr", model="google-t5/t5-large")

def paraphrase(
    question,
    num_beams=5,
    num_beam_groups=5,
    num_return_sequences=1,
    repetition_penalty=10.0,
    diversity_penalty=8.0,
    no_repeat_ngram_size=2,
    temperature=0.7,
    max_length=1024,
):
    # Encode the input with the "paraphrase:" task prefix expected by the model
    input_ids = tokenizer(
        f'paraphrase: {question}',
        return_tensors="pt",
        padding="longest",
        max_length=max_length,
        truncation=True,
    ).input_ids.to(device)

    # Diverse beam search (num_beam_groups > 1) to produce varied paraphrases;
    # temperature only takes effect when sampling is enabled
    outputs = model.generate(
        input_ids,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        max_length=max_length,
        diversity_penalty=diversity_penalty,
    )

    # Decode all generated sequences into plain strings
    res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return res

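# Illustrative usage (not part of the app's request flow): with the defaults
# above, paraphrase("Some input sentence") returns a single-element list
# holding one paraphrased string; num_return_sequences (up to num_beams)
# controls how many candidates come back.
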
def translate(myinput):
    # The NLLB pipeline above would need explicit language codes:
    #myout = translator(myinput,src_lang="eng_Latn",tgt_lang="fra_Latn")
    # With the T5 "translation_en_to_fr" task the language pair is fixed,
    # so no src_lang/tgt_lang arguments are needed
    myout = translator(myinput)
    return myout

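# Illustrative return shape (input text is a placeholder): translate(["Hello"])
# yields a list like [{"translation_text": "..."}], one dict per input string.
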
def predict(mytextInput):
    # Paraphrase the input, then translate the paraphrase into French
    out = translate(paraphrase(mytextInput))
    #out = paraphrase(mytextInput)
    # The translation pipeline returns a list of dicts; return just the text
    return out[0]["translation_text"]

def greet(name):
    # Simple greeting helper; not wired into the interface below
    return "Hello " + name

iface = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs="text",
)
iface.launch(share=True)