"""
The Streamlit app for the project demo.

In the demo, the user can write a prompt and the model will generate
a response using the grouped sampling algorithm.
"""
import os
from time import time

import streamlit as st
from torch.cuda import CudaError
from huggingface_hub import logging as hf_hub_logging

from grouped_sampling import GroupedSamplingPipeLine

from available_models import AVAILABLE_MODELS
from hanlde_form_submit import on_form_submit

def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    st.write(f"Starting to create a pipeline with model: {model_name}")
    pipeline_start_time = time()
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=False,
        top_k=50,
    )
    pipeline_end_time = time()
    pipeline_time = pipeline_end_time - pipeline_start_time
    st.write(
        f"Finished creating the pipeline with model: {model_name} "
        f"in {pipeline_time:,.2f} seconds."
    )
    return pipeline
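

# Usage sketch for create_pipeline. The model name "gpt2" below is purely
# illustrative and is not necessarily part of AVAILABLE_MODELS:
#
#     demo_pipeline = create_pipeline("gpt2", group_size=1024)
#
# The pipelines built this way are handed to on_form_submit further down,
# together with the requested output length and the user's prompt.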


# Silence huggingface_hub's informational download logs; only errors are shown.
hf_hub_logging.set_verbosity_error()

st.set_page_config(
    page_title="Grouped Sampling - Efficient Use of Causal Language Models",
    layout="wide",
)

# Build a pipeline for every model in AVAILABLE_MODELS up front, so that form
# submissions reuse preloaded pipelines instead of loading weights per request.
pipelines = {
    model_name: create_pipeline(model_name, 1024)
    for model_name in AVAILABLE_MODELS
}

with st.form("request_form"):
    selected_model_name: str = st.selectbox(
        label="Choose a model",
        options=AVAILABLE_MODELS,
        help="llama-30b-hf generates better texts but is slower",
    )

    output_length: int = st.number_input(
        label="The maximum number of words in the output - between 1 and 1024",
        min_value=1,
        max_value=1024,
        value=5,
    )

    submitted_prompt: str = st.text_area(
        label="The input to the algorithm (in English only)",
        value="Instruction: Answer in yes or no.\n"
              "Question: Is the sky blue?\n"
              "Answer:",
        max_chars=2048,
    )

    submitted: bool = st.form_submit_button(
        label="Generate text",
        disabled=False,
    )


if submitted:
    try:
        output = on_form_submit(
            pipelines[selected_model_name],
            output_length,
            submitted_prompt,
        )
    except CudaError:
        # A CUDA error here almost always means the GPU ran out of memory.
        st.error(
            "Out of memory. Please try a smaller model, a shorter prompt, "
            "or a smaller output length."
        )
    except (ValueError, TypeError, RuntimeError) as e:
        st.error(str(e))
    else:
        st.write(f"Generated text: {output}")


# Render the Hebrew user instructions below the demo form.
user_instructions_file = os.path.join(
    os.path.dirname(__file__),
    "user_instructions_hebrew.md",
)
with open(user_instructions_file, "r", encoding="utf-8") as fh:
    long_description = fh.read()
st.markdown(long_description)
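

# To run the demo locally with the standard Streamlit CLI (the file name
# "app.py" is an assumption; substitute this file's actual name):
#
#     streamlit run app.py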