""" A script that is run when the server starts. """ from concurrent.futures import ThreadPoolExecutor from transformers import logging as transformers_logging from huggingface_hub import logging as huggingface_hub_logging from available_models import AVAILABLE_MODELS from download_repo import download_pytorch_model def disable_progress_bar(): """ Disables the progress bar when downloading models. """ transformers_logging.disable_progress_bar() huggingface_hub_logging.disable_propagation() def download_useful_models(): """ Downloads the models that are useful for this project. So that the user doesn't have to wait for the models to download when they first use the app. """ print("Downloading useful models. It might take a while...") with ThreadPoolExecutor() as executor: executor.map(download_pytorch_model, AVAILABLE_MODELS) def main(): # disable_progress_bar() download_useful_models() if __name__ == "__main__": main()