|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import streamlit as st |
|
# Page header shown at the top of the Streamlit app.
st.title("Paraphrase")
|
|
|
# NOTE(review): st.cache is deprecated in Streamlit >= 1.18; migrate to
# st.cache_resource once the app's Streamlit version allows it.
@st.cache(allow_output_mutation=True)
def get_model(model_name="facebook/bart-large-cnn"):
    """Load and cache a seq2seq model and its matching tokenizer.

    Parameters
    ----------
    model_name : str, optional
        Hugging Face checkpoint identifier. Defaults to the BART CNN
        checkpoint the app originally shipped with, so existing callers
        are unaffected.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` ready for generation (model still on CPU;
        the caller moves it to the desired device).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer
|
|
|
# Fetch the cached model/tokenizer pair and move the model to the best
# available device (GPU when present, CPU otherwise).
model, tokenizer = get_model()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Sidebar controls: sampling temperature and how many paraphrases to emit.
# Both globals are read by translate_to_english below.
temp = st.sidebar.slider("Temperature", 0.7, 1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", 1, 10)
|
|
|
def translate_to_english(model, tokenizer, text):
    """Generate paraphrase candidates for *text* with sampling-based decoding.

    Reads the module-level globals ``device``, ``temp`` and
    ``number_of_outputs`` that are set by the Streamlit sidebar.

    Parameters
    ----------
    model : transformers seq2seq model, already moved to ``device``.
    tokenizer : tokenizer matching *model*.
    text : str
        The user-supplied input to paraphrase.

    Returns
    -------
    list[str]
        ``number_of_outputs`` decoded candidate strings.
    """
    # Fix: do not append a literal " </s>" — that is a T5-era idiom. The
    # BART tokenizer inserts its own special tokens during encoding, so the
    # manual suffix produced a duplicated end-of-sequence marker.
    # `padding=True` replaces the deprecated `pad_to_max_length` flag
    # (a no-op for a single sequence, kept for parity with the original).
    encoding = tokenizer.encode_plus(text, padding=True, return_tensors="pt")
    input_ids = encoding["input_ids"].to(device)
    attention_masks = encoding["attention_mask"].to(device)

    # Inference only: disable autograd to avoid tracking gradients through
    # the generation loop.
    with torch.no_grad():
        beam_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            do_sample=True,
            max_length=256,
            temperature=temp,
            top_k=120,
            top_p=0.98,
            early_stopping=True,
            num_return_sequences=number_of_outputs,
        )

    # Decode every sampled sequence (debug print removed).
    return [
        tokenizer.decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for out in beam_outputs
    ]
|
|
|
# --- Main UI flow -------------------------------------------------------
text = st.text_input("Okay")

# Echo the user's input back.
st.text("What you wrote: ")
st.write(text)

st.text("Output: ")
if text:
    # Only run generation when the input box is non-empty.
    results = translate_to_english(model, tokenizer, text)
    st.write(results if results else "No translation found")
|
|