import os
from functools import lru_cache
from time import time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine

from download_repo import download_pytorch_model


def is_downloaded(model_name: str) -> bool:
    """
    Checks whether the model is already in the local Hugging Face cache.
    :param model_name: The name of the model to check.
    :return: True if the model is downloaded, False otherwise.
    """
    models_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
    model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
    return os.path.isdir(model_dir)


@lru_cache(maxsize=10)
def create_pipeline(model_name: str) -> GroupedSamplingPipeLine:
    """
    Creates a grouped-sampling pipeline for the given model,
    downloading the model first if it is not cached locally.
    :param model_name: The name of the model to use.
    :return: A GroupedSamplingPipeLine for the given model.
    """
    if not is_downloaded(model_name):
        download_start_time = time()
        st.write(f"Downloading model: {model_name}...")
        download_pytorch_model(model_name)
        download_time = time() - download_start_time
        st.write(f"Finished downloading model: {model_name} in {download_time:,.2f} seconds.")
    st.write(f"Creating pipeline with model: {model_name}...")
    pipeline_start_time = time()
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=512,
        end_of_sentence_stop=False,
        top_k=50,
        load_in_8bit=False,
    )
    pipeline_time = time() - pipeline_start_time
    st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
    return pipeline


def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Generates text using the given pipeline.
    :param pipeline: The GroupedSamplingPipeLine to use.
    :param prompt: The prompt to complete.
    :param output_length: The maximum number of tokens to generate. Must be > 0.
    :return: The generated text.
    """
    return pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]


def on_form_submit(
    model_name: str,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form.
    :param model_name: The name of the model to use.
    :param output_length: The maximum number of tokens to generate.
    :param prompt: The prompt to complete.
    :return: The text generated by the model.
    :raises ValueError: If the model name is not supported, the output length is <= 0,
        or the prompt is empty or longer than 16384 characters.
    :raises TypeError: If the output length is not an integer or the prompt is not a string.
    :raises RuntimeError: If the model is not found or did not generate any text.
""" if len(prompt) == 0: raise ValueError("The prompt must not be empty.") st.write(f"Loading model: {model_name}...") loading_start_time = time() pipeline = create_pipeline( model_name=model_name, ) loading_end_time = time() loading_time = loading_end_time - loading_start_time st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.") st.write("Generating text...") generation_start_time = time() generated_text = generate_text( pipeline=pipeline, prompt=prompt, output_length=output_length, ) generation_end_time = time() generation_time = generation_end_time - generation_start_time st.write(f"Finished generating text in {generation_time:,.2f} seconds.") if not isinstance(generated_text, str): raise RuntimeError(f"The model {model_name} did not generate any text.") if len(generated_text) == 0: raise RuntimeError(f"The model {model_name} did not generate any text.") return generated_text