# IndoPara-Gen — app/app.py
# Hugging Face Space by Wikidepia (initial commit 654d2fa)
import os
from typing import List
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Silence the huggingface/tokenizers warning about forking after the
# tokenizer has been used (Streamlit reruns the script in the same process).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Download (or reuse) a seq2seq model and its tokenizer from the HF hub.

    Wrapped in Streamlit's cache so the expensive download/instantiation
    happens once per ``model_name`` instead of on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    hub_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hub_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return hub_model, hub_tokenizer
def paraphrase(model, encoding, top_k=120, top_p=0.95, max_len=120, tokenizer=None) -> List[str]:
    """Generate five paraphrase candidates with nucleus (top-k/top-p) sampling.

    Args:
        model: Seq2seq model exposing a ``generate`` method.
        encoding: Tokenizer output holding ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        top_k: Sample only among the ``top_k`` most likely next tokens.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        max_len: Maximum length, in tokens, of each generated sequence.
        tokenizer: Tokenizer used to decode the generated ids. Defaults to
            the module-level ``tokenizer`` for backward compatibility: the
            original implementation read that global implicitly, which broke
            the function whenever it was called before the ``__main__`` block
            defined it.

    Returns:
        A list of 5 decoded paraphrase strings.
    """
    if tokenizer is None:
        # Preserve the original (implicit-global) behavior as a fallback.
        tokenizer = globals()["tokenizer"]
    outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        max_length=max_len,
        # early_stopping removed: it only affects beam-based decoding and is
        # ignored (with a warning in newer transformers) under pure sampling.
        num_return_sequences=5,
    )
    return [
        tokenizer.decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for output in outputs
    ]
if __name__ == "__main__":
    st.header("Indonesian Paraphrase Generation")
    # Free-text input for the sentence to be paraphrased.
    user_input = st.text_area("Original Sentence", "", height=30)

    # Sidebar sliders exposing the sampling/decoding hyperparameters.
    st.sidebar.header("Decoding Settings")
    max_len = st.sidebar.slider("Max-Length", 0, 512, 256)
    top_k = st.sidebar.slider("Top-K", 0, 512, 200)
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.95)

    # Run on button press, or automatically whenever there is input text
    # (Streamlit reruns the whole script on every widget interaction).
    if st.button("Paraphrase") or user_input:
        with st.spinner("T5 is processing your text..."):
            model, tokenizer = load_model("Wikidepia/IndoT5-base-paraphrase")
            # Task prefix used by the IndoT5 paraphrase checkpoint.
            # NOTE(review): appending a literal " </s>" looks redundant — T5
            # tokenizers append the EOS token themselves; confirm against the
            # checkpoint's training format before removing.
            text = "paraphrase: " + user_input + " </s>"
            encode_id = tokenizer(text, return_tensors="pt")
            outputs = paraphrase(
                model, encode_id, top_k=top_k, top_p=top_p, max_len=max_len
            )
            # "Hasil Parafrase" = "Paraphrase Results" (Indonesian UI text).
            st.markdown("### Hasil Parafrase")
            for i, output in enumerate(outputs):
                st.markdown(f"- {output}")