Spaces:
Running
Running
from .imagenhub_models import load_imagenhub_model | |
from .playground_api import load_playground_model | |
# Registered model identifiers, each of the form "{source}_{name}_{type}".
# The source selects the loader; the type is "generation" or "edition".

# Text-to-image generation models.
IMAGE_GENERATION_MODELS = [
    "imagenhub_LCM_generation",
    "imagenhub_SDXLTurbo_generation",
    "imagenhub_SDXL_generation",
    "imagenhub_PixArtAlpha_generation",
    "imagenhub_OpenJourney_generation",
    "imagenhub_SDXLLightning_generation",
    "imagenhub_StableCascade_generation",
    "playground_PlayGroundV2_generation",
    "playground_PlayGroundV2.5_generation",
]

# Image editing models (all served through imagenhub).
IMAGE_EDITION_MODELS = [
    "imagenhub_CycleDiffusion_edition",
    "imagenhub_Pix2PixZero_edition",
    "imagenhub_Prompt2prompt_edition",
    "imagenhub_SDEdit_edition",
    "imagenhub_InstructPix2Pix_edition",
    "imagenhub_MagicBrush_edition",
    "imagenhub_PNP_edition",
]
def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name.

    Args:
        model_name (str): The name of the model to load, of the form
            {source}_{name}_{type}:
            - source: either "imagenhub" or "playground"
            - name: the model name handed to the source-specific loader
              (may itself contain underscores)
            - type: the type of the model, either "generation" or "edition"

    Returns:
        Whatever the source-specific loader returns
        (load_imagenhub_model or load_playground_model).

    Raises:
        ValueError: If model_name is not of the form {source}_{name}_{type}
            (any component missing or empty), or if the source is not
            supported.
    """
    # Split on the FIRST and LAST underscore only. The supported sources and
    # types contain no underscores, so this keeps any underscores inside the
    # middle name intact instead of failing with an opaque
    # "too many values to unpack" error.
    model_source, _, remainder = model_name.partition("_")
    name, _, model_type = remainder.rpartition("_")
    # A missing separator always leaves at least one component empty, so
    # checking the components alone is enough to reject malformed names.
    if not (model_source and name and model_type):
        raise ValueError(
            f"Model name {model_name!r} must be of the form "
            "'{source}_{name}_{type}'"
        )

    if model_source == "imagenhub":
        return load_imagenhub_model(name, model_type)
    if model_source == "playground":
        # The playground loader only takes the model name; the type is unused.
        return load_playground_model(name)
    raise ValueError(f"Model source {model_source} not supported")