# text_generation/app.py
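"""Streamlit demo: extend a text prompt with GPT-2 Medium or GPT-Neo 1.3B.

Sidebar sliders expose the usual generation knobs (max length, beams,
top-k/top-p, temperature, repetition penalty); the app regenerates on
every widget change.
"""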
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Display name -> Hugging Face model id.
model_names = {
    "gpt2-medium": "gpt2-medium",
    "eleuther-1.3b": "EleutherAI/gpt-neo-1.3B",
}

def generate_texts(generator, input_text, **generator_args):
    # Named `generator` (not `pipeline`) so it does not shadow the imported factory.
    output_sequences = generator(input_text, **generator_args)
    return output_sequences
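
# Cache the heavyweight loads across Streamlit reruns so that moving a slider
# does not re-download or re-instantiate the weights. (On Streamlit >= 1.18,
# st.cache is deprecated in favor of st.cache_resource.)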
@st.cache(allow_output_mutation=True)
def load_tokenizer(model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer

@st.cache(allow_output_mutation=True)
def load_model(model_name, eos_token_id):
    # Pad with the EOS token so open-ended generation does not warn about a missing pad token.
    model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=eos_token_id)
    return model
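
# Build one pipeline per model up front; the list order follows model_names
# (dicts preserve insertion order), so the sidebar index below selects the
# matching pipeline.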
tokenizers = {m: load_tokenizer(m) for m in model_names.values()}
pipelines = [
    pipeline("text-generation", model=load_model(m, tokenizers[m].eos_token_id), tokenizer=tokenizers[m])
    for m in model_names.values()
]
print("loaded the pipelines")

default_value = "But not just any roof cleaning will do."

# Prompt UI
st.title("Text Extension or Generation")
st.write("Press Ctrl/Cmd + Enter in the text box to apply the prompt...")
sent = st.text_area("Text", default_value, height=150)
#generate_button = st.button("Generate")
model_index = st.sidebar.selectbox(
    "Select Model", range(len(model_names)), format_func=lambda x: list(model_names.keys())[x]
)
max_length = st.sidebar.slider("Max Length", value=100, min_value=30, max_value=256)
# Temperature must stay strictly positive if sampling is ever enabled.
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)
num_return_sequences = st.sidebar.slider("Num Return Sequences", min_value=1, max_value=4, value=1)
num_beams = st.sidebar.slider("Num Beams", min_value=2, max_value=6, value=4)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=90)
top_p = st.sidebar.slider("Top-p", min_value=0.4, max_value=1.0, step=0.05, value=0.9)
repetition_penalty = st.sidebar.slider("Repetition-Penalty", min_value=0.45, max_value=2.0, step=0.1, value=1.2)
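
# No explicit "Generate" button: Streamlit reruns the script on every widget
# change, so generation fires whenever the prompt is long enough.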
if len(sent) < 10:
    st.write("Input prompt is too short to generate from.")
else:
    print(sent)
    st.write(f"Generating for prompt: {sent} ...")
    # Beam search requires num_beams >= num_return_sequences, so clamp the beam count.
    num_beams = max(num_beams, num_return_sequences)
    output_sequences = generate_texts(
        pipelines[model_index],
        sent,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        no_repeat_ngram_size=2,
        repetition_penalty=repetition_penalty,
        early_stopping=False,
        top_p=top_p,
    )
    st.write(output_sequences)
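
# Run locally with: streamlit run app.py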