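"""Pipeline selection for generation tasks.

update() inspects the incoming gen_data (prompt, ControlNet settings and the
selected base model) and stores a matching pipeline object in shared state,
reusing the current pipeline whenever it already provides the requested type.
"""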
import os
from shared import state, path_manager
import shared
from pathlib import Path
import re

# Faceswap support is optional: only enable it if the module imports cleanly.
try:
    import modules.faceswapper_pipeline as faceswapper_pipeline

    print("INFO: Faceswap enabled")
    state["faceswap_loaded"] = True
except Exception:
    state["faceswap_loaded"] = False

import modules.sdxl_pipeline as sdxl_pipeline
import modules.template_pipeline as template_pipeline
import modules.upscale_pipeline as upscale_pipeline
import modules.search_pipeline as search_pipeline
import modules.huggingface_dl_pipeline as huggingface_dl_pipeline
import modules.diffusers_pipeline as diffusers_pipeline
import modules.rembg_pipeline as rembg_pipeline
import modules.llama_pipeline as llama_pipeline
import modules.hunyuan_video_pipeline as hunyuan_video_pipeline
import modules.wan_video_pipeline as wan_video_pipeline
import modules.hashbang_pipeline as hashbang_pipeline
import modules.controlnet as controlnet
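
# Placeholder pipeline used until a real one has been selected.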
class NoPipeLine:
    pipeline_type = []


def update(gen_data):
    """Select the pipeline for this generation request and cache it in shared state."""
    prompt = gen_data.get("prompt", "")
    cn_settings = controlnet.get_settings(gen_data)
    cn_type = cn_settings.get("type", "")
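
    # Dispatch on task type, prompt prefix and ControlNet type first; only swap
    # pipelines when the current one does not already provide the needed type.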
    try:
        if "task_type" in gen_data and gen_data["task_type"] == "llama":
            if (
                state["pipeline"] is None
                or "llama" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = llama_pipeline.pipeline()
        elif prompt.lower() == "ruinedfooocuslogo":
            if (
                state["pipeline"] is None
                or "template" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = template_pipeline.pipeline()
        elif prompt.startswith("#!"):
            if (
                state["pipeline"] is None
                or "hashbang" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = hashbang_pipeline.pipeline()
        elif prompt.lower().startswith("search:"):
            if (
                state["pipeline"] is None
                or "search" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = search_pipeline.pipeline()
        elif re.match(r"^\s*hf:", prompt):
            if (
                state["pipeline"] is None
                or "huggingface_dl" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = huggingface_dl_pipeline.pipeline()
        elif cn_type.lower() == "upscale":
            if (
                state["pipeline"] is None
                or "upscale" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = upscale_pipeline.pipeline()
        elif cn_type.lower() == "faceswap" and state["faceswap_loaded"]:
            if (
                state["pipeline"] is None
                or "faceswap" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = faceswapper_pipeline.pipeline()
        elif cn_type.lower() == "rembg":
            if (
                state["pipeline"] is None
                or "rembg" not in state["pipeline"].pipeline_type
            ):
                state["pipeline"] = rembg_pipeline.pipeline()
        else:
            baseModel = None
            file = ""
            if "base_model_name" in gen_data:
                file = shared.models.get_file("checkpoints", gen_data["base_model_name"])
                if file is None:
                    file = ""
                    baseModel = "None"
                else:
                    path = shared.models.get_models_by_path("checkpoints", file)
                    baseModel = shared.models.get_model_base(path)

            baseModelName = gen_data.get("base_model_name", "")

            if state["pipeline"] is None:
                state["pipeline"] = NoPipeLine()

            if baseModelName.startswith("🤗"):
                if (
                    state["pipeline"] is None
                    or "diffusers" not in state["pipeline"].pipeline_type
                ):
                    state["pipeline"] = diffusers_pipeline.pipeline()
            elif (
                baseModel == "Hunyuan Video"
                or Path(baseModelName).parts[0] == "Hunyuan Video"
                or str(Path(file).name).startswith("hunyuan-video-t2v-")
                or str(Path(file).name).startswith("fast-hunyuan-video-t2v-")
            ):
                if (
                    state["pipeline"] is None
                    or "hunyuan_video" not in state["pipeline"].pipeline_type
                ):
                    state["pipeline"] = hunyuan_video_pipeline.pipeline()
            elif (
                baseModel == "Wan Video"
                or Path(baseModelName).parts[0] == "Wan Video"
                or str(Path(file).name).startswith("wan2.1-t2v-")
                or str(Path(file).name).startswith("wan2.1_t2v_")
                or str(Path(file).name).startswith("wan2.1-i2v-")
                or str(Path(file).name).startswith("wan2.1_i2v_")
            ):
                if (
                    state["pipeline"] is None
                    or "wan_video" not in state["pipeline"].pipeline_type
                ):
                    state["pipeline"] = wan_video_pipeline.pipeline()
            elif baseModel is not None:
                # Try the sdxl/default pipeline if baseModel is set.
                if "sdxl" not in state["pipeline"].pipeline_type:
                    state["pipeline"] = sdxl_pipeline.pipeline()

        if state["pipeline"] is None or len(state["pipeline"].pipeline_type) == 0:
            print("Using default pipeline.")
            state["pipeline"] = sdxl_pipeline.pipeline()

        return state["pipeline"]
    except Exception:
        # If anything fails, fall back to the template pipeline, which only returns a logo.
        print("Something went wrong. Falling back to template pipeline.")
        state["pipeline"] = template_pipeline.pipeline()
        return state["pipeline"]