# Hugging Face Space: German paraphraser demo (page status at scrape time: Runtime error)
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# mT5-small fine-tuned for German paraphrasing; load tokenizer and model once
# at module import so every request reuses them.
MODEL_NAME = "milyiyo/paraphraser-german-mt5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def generate_v1(inputs, count):
    """Generate text using a Beam Search strategy with repetition penalty."""
    beam_outputs = model.generate(
        inputs["input_ids"],
        early_stopping=True,
        length_penalty=1.0,
        max_length=1024,
        no_repeat_ngram_size=2,
        num_beams=10,
        repetition_penalty=3.5,
        num_return_sequences=count,
    )
    # Decode each returned beam into plain text, dropping special tokens.
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in beam_outputs]
def generate_v2(inputs, count):
    """Generate text using a Beam Search strategy."""
    beam_outputs = model.generate(
        inputs["input_ids"],
        early_stopping=True,
        length_penalty=2.0,
        max_length=1024,
        no_repeat_ngram_size=2,
        num_beams=5,
        temperature=1.5,
        num_return_sequences=count,
    )
    # Decode each returned beam into plain text, dropping special tokens.
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in beam_outputs]
def generate_v3(inputs, count):
    """Generate text using a Diverse Beam Search strategy."""
    beam_outputs = model.generate(
        inputs["input_ids"],
        num_beams=5,
        max_length=1024,
        temperature=1.5,
        num_beam_groups=5,
        diversity_penalty=2.0,
        no_repeat_ngram_size=2,
        early_stopping=True,
        length_penalty=2.0,
        num_return_sequences=count,
    )
    # Decode each returned beam into plain text, dropping special tokens.
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in beam_outputs]
def generate_v4(encoding, count):
    """Generate paraphrases via top-k / top-p (nucleus) sampling.

    Args:
        encoding: tokenizer output containing "input_ids" and "attention_mask"
            tensors (as produced by `tokenizer(..., return_tensors="pt")`).
        count: number of sampled sequences to return.

    Returns:
        list[str]: decoded generated sequences, special tokens removed.
    """
    # Removed leftover debug print() calls that dumped raw tensors to stdout.
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]
    outputs = model.generate(input_ids=input_ids,
                             attention_mask=attention_mask,
                             max_length=512,
                             do_sample=True,
                             top_k=120,
                             top_p=0.95,
                             early_stopping=True,
                             num_return_sequences=count)
    return [
        tokenizer.decode(seq, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for seq in outputs
    ]
def paraphrase(sentence: str, count: str):
    """Generate paraphrases of a German sentence for the Gradio UI.

    Args:
        sentence: input sentence to paraphrase.
        count: requested number of paraphrases (arrives as a string/number
            from the UI widget; converted to int here).

    Returns:
        dict: ``{'result': {'generate_v4': [paraphrase, ...]}}`` on success,
        or ``{'result': []}`` for a blank sentence or non-positive count.
    """
    p_count = int(count)
    if p_count <= 0 or not sentence.strip():
        return {'result': []}
    # Prompt format the model was fine-tuned with.
    text = f"paraphrase: {sentence} </s>"
    encoding = tokenizer(text, return_tensors="pt")
    # BUG FIX: pass the converted integer p_count, not the raw `count`
    # (Gradio's Number component yields a float, and generate() requires an
    # int for num_return_sequences).
    outputs = model.generate(input_ids=encoding["input_ids"],
                             attention_mask=encoding["attention_mask"],
                             max_length=512,
                             do_sample=True,
                             top_k=120,
                             top_p=0.95,
                             early_stopping=True,
                             num_return_sequences=p_count)
    result_v4 = [
        tokenizer.decode(seq, skip_special_tokens=True,
                         clean_up_tokenization_spaces=True)
        for seq in outputs
    ]
    return {
        'result': {
            'generate_v4': result_v4,
        }
    }
def paraphrase_dummy(sentence: str, count: str):
    """No-op stand-in for `paraphrase`; always returns an empty result set."""
    return {'result': []}
# Gradio UI wiring. The `gr.inputs` / `gr.outputs` namespaces used previously
# were deprecated in Gradio 2.x and removed in 3.x+ (a likely cause of the
# Space's runtime error); use the top-level components instead.
iface = gr.Interface(
    fn=paraphrase,
    inputs=[
        gr.Textbox(lines=2, placeholder=None, label='Sentence'),
        gr.Number(value=3, label='Paraphrases count'),
    ],
    outputs=gr.JSON(label=None),
)
iface.launch()