grouped-sampling-demo / hanlde_form_submit.py
yonikremer's picture
changed end_of_sentence_stop to false
1fac618
raw
history blame
3.11 kB
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names
# Names of the models this demo accepts. Computed once at import time and used
# by on_form_submit to reject unsupported model names.
SUPPORTED_MODEL_NAMES = get_supported_model_names()
def create_pipeline(
    model_name: str,
    group_size: int,
    *,
    end_of_sentence_stop: bool = False,
    temp: float = 0.5,
    top_p: float = 0.6,
) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param end_of_sentence_stop: Forwarded to GroupedSamplingPipeLine.
        Keyword-only; defaults to False, the value this demo always used.
    :param temp: Sampling temperature forwarded to GroupedSamplingPipeLine.
        Keyword-only; defaults to 0.5, the value this demo always used.
    :param top_p: top_p value forwarded to GroupedSamplingPipeLine.
        Keyword-only; defaults to 0.6, the value this demo always used.
    :return: A pipeline with the given model name, group size and sampling settings.
    """
    # The sampling settings used to be hard-coded here. They are now keyword-only
    # parameters whose defaults keep the old behaviour for existing callers.
    # NOTE(review): the log messages assume the constructor fetches the model
    # from the internet; it may instead load from a local cache.
    print(f"Starts downloading model: {model_name} from the internet.")
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=end_of_sentence_stop,
        temp=temp,
        top_p=top_p,
    )
    print(f"Finished downloading model: {model_name} from the internet.")
    return pipeline
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Rewrites the prompt with rewrite_prompt and runs the pipeline on it.
    :param pipeline: The GroupedSamplingPipeLine that performs the generation.
    :param prompt: The raw user prompt; it is passed through rewrite_prompt first.
    :param output_length: Maximum number of new tokens to generate (expected int > 0).
    :return: The "generated_text" entry of the pipeline's output. The pipeline is
        called with return_full_text=False, presumably so the prompt itself is
        not echoed back.
    """
    generation_options = {
        "max_new_tokens": output_length,
        "return_text": True,
        "return_full_text": False,
    }
    rewritten_prompt = rewrite_prompt(prompt)
    result = pipeline(prompt_s=rewritten_prompt, **generation_options)
    return result["generated_text"]
@st.cache
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
    """
    Called when the user submits the form.

    Validates the inputs, creates a pipeline for the chosen model and returns
    the generated text.
    :param model_name: The name of the model to use. Must be in SUPPORTED_MODEL_NAMES.
    :param output_length: The number of new tokens to generate. The same value
        is also used as the pipeline's group size. Must be an int > 0.
    :param prompt: The prompt to use. Must be a non-empty str of at most
        16384 characters.
    :return: The output of the model.
    :raises ValueError: If the model name is not supported, the output length is
        not an integer or is <= 0, or the prompt is not a string, is empty or is
        longer than 16384 characters.
    Any error raised while creating the pipeline or generating text (e.g. when
    the model cannot be loaded) propagates unchanged.
    """
    max_prompt_chars = 16384
    if model_name not in SUPPORTED_MODEL_NAMES:
        raise ValueError(f"The selected model {model_name} is not supported. "
                         f"Supported models are all the models in:"
                         f" https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch")
    # Each type check runs before the value check that depends on it.
    # Otherwise `<=` or `len()` would raise an undocumented TypeError.
    if not isinstance(output_length, int):
        raise ValueError("The output length must be an integer.")
    if output_length <= 0:
        raise ValueError(f"The output length {output_length} must be > 0.")
    if not isinstance(prompt, str):
        raise ValueError("The prompt must be a string.")
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    if len(prompt) > max_prompt_chars:
        raise ValueError(
            f"The prompt must be at most {max_prompt_chars} characters long,"
            f" got {len(prompt)}."
        )
    # The requested output length doubles as the group size.
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    return generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )