from constants import (
    TAESD_MODEL,
    TAESDXL_MODEL,
    TAESD_MODEL_OPENVINO,
    TAESDXL_MODEL_OPENVINO,
)

# Lookup table from pipeline class name to the tiny-decoder VAE model constant
# that should be used with it. A dict keeps the mapping in one place and makes
# adding a new pipeline a one-line change.
# NOTE(review): there is no entry for an OpenVINO SDXL img2img pipeline, while
# the non-OpenVINO SDXL img2img pipeline is mapped — confirm this is intended.
_TINY_DECODER_BY_PIPELINE = {
    "LatentConsistencyModelPipeline": TAESD_MODEL,
    "StableDiffusionPipeline": TAESD_MODEL,
    "StableDiffusionImg2ImgPipeline": TAESD_MODEL,
    "StableDiffusionXLPipeline": TAESDXL_MODEL,
    "StableDiffusionXLImg2ImgPipeline": TAESDXL_MODEL,
    "OVStableDiffusionPipeline": TAESD_MODEL_OPENVINO,
    "OVStableDiffusionImg2ImgPipeline": TAESD_MODEL_OPENVINO,
    "OVStableDiffusionXLPipeline": TAESDXL_MODEL_OPENVINO,
}


def get_tiny_decoder_vae_model(pipeline_class) -> str:
    """Return the tiny-decoder VAE model matching a pipeline class name.

    Args:
        pipeline_class: Name of the pipeline class, matched exactly against
            the known names (e.g. ``"StableDiffusionPipeline"``).

    Returns:
        One of ``TAESD_MODEL``, ``TAESDXL_MODEL``, ``TAESD_MODEL_OPENVINO`` or
        ``TAESDXL_MODEL_OPENVINO`` from the ``constants`` module.

    Raises:
        ValueError: If ``pipeline_class`` is not a recognised pipeline name.
            ``ValueError`` subclasses ``Exception``, so callers that catch
            ``Exception`` continue to work.
    """
    print(f"Pipeline class : {pipeline_class}")
    model = _TINY_DECODER_BY_PIPELINE.get(pipeline_class)
    if model is None:
        # Include the offending name so the failure is diagnosable.
        raise ValueError(
            f"No valid pipeline class found! Got: {pipeline_class!r}"
        )
    return model