"""
The Streamlit app for the project demo.

In the demo, the user can write a prompt and the model will generate
a response using the grouped sampling algorithm.
"""

import streamlit as st
from torch.cuda import CudaError

from hanlde_form_submit import on_form_submit
from on_server_start import main as on_server_start_main
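
# Streamlit re-runs this script from top to bottom on every interaction,
# so on_server_start_main() should be idempotent or internally cached;
# it is assumed to do one-time warm-up work (e.g. fetching model files).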
on_server_start_main()

st.title("Grouped Sampling Demo")
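
# Widgets inside st.form are batched: editing them does not trigger a
# rerun until the form's submit button is pressed, so the whole request
# is sent at once.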
with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
        help="The name of the model to use. "
             "Supported models are all the models in: "
             "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch",
    )

    output_length: int = st.number_input(
        label="Number of word pieces in the generated text, 1-4096 (default: 100)",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces).",
    )

    submitted_prompt: str = st.text_area(
        label="Input for the model. It is highly recommended to write an English prompt.",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        value="Instruction: Answer in yes or no.\n"
              "Question: Is this a prompt?\n"
              "Answer: ",
        max_chars=2048,
    )

    web_search: bool = st.checkbox(
        label="Web search",
        value=True,
        help="If checked, the model will get your prompt as well as some web search results.",
    )
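
    # How the search results are combined with the prompt is presumably
    # handled inside on_form_submit; this checkbox only toggles it.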

    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
        disabled=False,
    )
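
# on_form_submit is assumed to load the selected model and run the grouped
# sampling algorithm; CudaError is treated as running out of GPU memory.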
if submitted:
    try:
        output = on_form_submit(
            selected_model_name,
            output_length,
            submitted_prompt,
            web_search,
        )
    except CudaError:
        st.error("Out of memory. Please try a smaller model, a shorter prompt, or a smaller output length.")
    except (ValueError, TypeError, RuntimeError) as e:
        st.error(str(e))
    else:
        # Only print the output on success; if an exception was caught,
        # `output` is unbound and printing it would raise a NameError.
        st.write(f"Generated text: {output}")
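
# The user instructions file is in Hebrew; pass an explicit encoding so
# reading it does not depend on the platform default.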
with open("user_instructions_hebrew.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()
st.markdown(long_description)