"""
The Streamlit app for the project demo.

In the demo, the user can write a prompt and the model will generate a
response using the grouped sampling algorithm.
"""

import os
from time import time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
from torch.cuda import CudaError
from huggingface_hub import logging as hf_hub_logging

from available_models import AVAILABLE_MODELS
from hanlde_form_submit import on_form_submit

# Group size passed to every pipeline this app creates (eagerly or on demand).
PIPELINE_GROUP_SIZE = 1024

OUT_OF_MEMORY_MESSAGE = (
    "Out of memory. Please try a smaller model, shorter prompt, or a smaller output length."
)


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.

    Progress messages, including how long construction took, are written
    to the page with ``st.write``.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    st.write(f"Starts creating pipeline with model: {model_name}")
    pipeline_start_time = time()
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=False,
        top_k=50,
    )
    pipeline_end_time = time()
    pipeline_time = pipeline_end_time - pipeline_start_time
    st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
    return pipeline


def _is_out_of_memory(error: BaseException) -> bool:
    """Return True if *error* looks like a CUDA out-of-memory error."""
    # PyTorch reports CUDA OOM as a RuntimeError (torch.cuda.OutOfMemoryError in
    # newer releases) whose message contains "out of memory" -- not as CudaError.
    return "out of memory" in str(error).lower()


hf_hub_logging.set_verbosity_error()

st.set_page_config(
    # "Grouped sampling - efficient use of causal language models"
    page_title="דגימה בקבוצות - שימוש יעיל במודלי שפה סיבתיים",
    layout="wide",
)

# NOTE(review): Streamlit re-executes this whole script on every interaction
# (including each form submission), so these pipelines are rebuilt on every
# rerun. Consider caching them with st.cache_resource (Streamlit >= 1.18).
# AVAILABLE_MODELS[0] is deliberately not pre-loaded here; it is created on
# demand below when the user selects it.
pipelines = {
    model_name: create_pipeline(model_name, PIPELINE_GROUP_SIZE)
    for model_name in AVAILABLE_MODELS[1:]
}

with st.form("request_form"):
    selected_model_name: str = st.selectbox(
        label="בחרו מודל",  # "Choose a model"
        options=AVAILABLE_MODELS,
        help="llama-30b-hf generates better texts but is slower",
    )

    output_length: int = st.number_input(
        # "Maximum number of words in the output - between 1 and 1024"
        label="כמות המילים המקסימלית בפלט - בין 1 ל-1024",
        min_value=1,
        max_value=1024,
        value=5,
    )

    submitted_prompt: str = st.text_area(
        label="הקלט לאלוגריתם (באנגלית בלבד)",  # "Input to the algorithm (English only)"
        value="Instruction: Answer in yes or no.\n"
              "Question: Is the sky blue?\n"
              "Answer:",
        max_chars=2048,
    )

    submitted: bool = st.form_submit_button(
        label="צור טקסט",  # "Generate text"
        disabled=False,
    )

    if submitted:
        try:
            if selected_model_name not in pipelines:
                # Every option in the selectbox must be usable, including the
                # model skipped by the eager pre-loading above.
                pipelines[selected_model_name] = create_pipeline(
                    selected_model_name, PIPELINE_GROUP_SIZE
                )
            output = on_form_submit(
                pipelines[selected_model_name],
                output_length,
                submitted_prompt,
            )
        except CudaError:
            st.error(OUT_OF_MEMORY_MESSAGE)
        except (ValueError, TypeError, RuntimeError) as e:
            if _is_out_of_memory(e):
                st.error(OUT_OF_MEMORY_MESSAGE)
            else:
                st.error(e)
        else:
            st.write(f"Generated text: {output}")

user_instructions_file = os.path.join(
    os.path.dirname(__file__),
    "user_instructions_hebrew.md",
)
# The instructions are Hebrew text: decode explicitly as UTF-8 instead of
# relying on the platform's default encoding.
with open(user_instructions_file, "r", encoding="utf-8") as fh:
    long_description = fh.read()
st.markdown(long_description)