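# GenAI-Arena / model/pre_download.py
# Last commit shown by the Hugging Face file viewer: 4d8824e by vinesmsuic
#   — Revert "load all image models"
# (Page residue retained as a comment: raw · history · blame links; file size 2.99 kB.)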
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
def pre_download_all_models():
"""
Pre-download all models to avoid download delay during the first user request
"""
imagen_dl_error = pre_download_image_models()
imagedit_dl_error = pre_download_image_models()
videogen_dl_error = pre_download_video_models()
print("All models downloaded.")
print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
def pre_download_image_models():
"""
Pre-download image models to avoid download delay during the first user request
"""
import imagen_hub
errored_models = []
for model_string in IMAGE_GENERATION_MODELS:
model_lib, model_name, model_type = model_string.split("_")
if model_lib == "imagenhub":
try:
print("Loading image generation model:", model_name)
temp_model = imagen_hub.get_model(model_name) # Forcing model to download weight files
del temp_model
except Exception as e:
print(f"Failed to load model {model_name} \n {e}")
errored_models.append(model_string)
continue
else:
pass
return errored_models
def pre_download_image_models():
"""
Pre-download image models to avoid download delay during the first user request
"""
import imagen_hub
errored_models = []
for model_string in IMAGE_EDITION_MODELS:
model_lib, model_name, model_type = model_string.split("_")
if model_lib == "imagenhub":
try:
print("Loading image edition model:", model_name)
temp_model = imagen_hub.get_model(model_name) # Forcing model to download weight files
del temp_model
except Exception as e:
print(f"Failed to load model {model_name} \n {e}")
errored_models.append(model_string)
continue
else:
pass
return errored_models
def pre_download_video_models():
"""
Pre-download video models to avoid download delay during the first user request
"""
import videogen_hub
errored_models = []
for model_string in VIDEO_GENERATION_MODELS:
model_lib, model_name, model_type = model_string.split("_")
if model_lib == "videogenhub":
try:
print("Loading video generation model:", model_name)
temp_model = videogen_hub.get_model(model_name) # Forcing model to download weight files
del temp_model
except Exception as e:
print(f"Failed to load model {model_name} \n {e}")
errored_models.append(model_string)
continue
else:
pass
return errored_models