yonikremer's picture
created an initial app
826e275
raw
history blame
No virus
2.15 kB
"""
The Streamlit app for the project demo.
In the demo, the user can write a prompt and the model will generate a response using the grouped sampling algorithm.
"""
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
# URL of the Hugging Face Hub listing of text-generation models, sorted by downloads.
# Despite the name, this is a single URL string; it is interpolated into the
# "Model name" input's help text so users can find valid model names.
available_models_list = (
    "https://huggingface.co/models"
    "?pipeline_tag=text-generation&sort=downloads"
)
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Create a grouped-sampling pipeline for the given model and group size.

    :param model_name: The name of the model to use (the form's help text points
        users to the Hugging Face text-generation model listing).
    :param group_size: The size of the groups to use.
    :return: A GroupedSamplingPipeLine constructed with the given model name and group size.
    """
    # NOTE(review): a brand-new pipeline is constructed on every call, i.e. on every
    # form submission. If construction loads model weights, caching it (e.g. with
    # Streamlit's resource cache) would avoid repeated loads - confirm the cost first.
    return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Handle a submission of the request form.

    Builds a pipeline for the requested model and group size, runs it on the
    prompt, and extracts the generated text from the pipeline's result.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to feed to the model.
    :return: The "generated_text" entry of the pipeline's output.
    """
    generator = create_pipeline(model_name=model_name, group_size=group_size)
    result = generator(prompt)
    return result["generated_text"]
# The request form: Streamlit batches these widgets so the script only reruns
# with their values when the "Generate" button is pressed.
with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
        help=f"The name of the model to use. Must be a model from this list: {available_models_list}",
    )
    # All arguments are ints, so Streamlit returns an int here.
    output_length: int = st.number_input(
        label="Output Length in tokens",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces).",
    )
    submitted_prompt: str = st.text_area(
        label="Input for the model",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        max_chars=16384,
    )
    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
        disabled=False,
    )

# Handle the submission after the form so the result renders below it.
if submitted:
    # Trim surrounding whitespace so e.g. " gpt2 " is passed on as "gpt2".
    cleaned_model_name = selected_model_name.strip()
    if not cleaned_model_name:
        st.warning("Please enter a model name.")
    elif not submitted_prompt.strip():
        # Do not build a pipeline and run the model on an empty/blank prompt.
        st.warning("Please enter a prompt for the model.")
    else:
        # Pipeline construction and generation can be slow; show progress meanwhile.
        # NOTE(review): the requested output length is passed as the pipeline's
        # group size (on_form_submit's second parameter) - confirm this is intended.
        with st.spinner("Generating..."):
            output = on_form_submit(cleaned_model_name, output_length, submitted_prompt)
        st.write(f"Generated text: {output}")