# RuinedFooocus / modules/pipelines.py
# Uploaded by malizec using huggingface_hub — commit 2de3774 (verified)
import os
from shared import state, path_manager
import shared
from pathlib import Path
import re
# Faceswap support is optional: any failure while importing its pipeline
# disables the feature (recorded in state["faceswap_loaded"]) instead of
# aborting startup. `Exception` rather than a bare `except:` so Ctrl-C /
# SystemExit during import still propagate.
try:
    import modules.faceswapper_pipeline as faceswapper_pipeline

    print("INFO: Faceswap enabled")
    state["faceswap_loaded"] = True
except Exception as err:
    # Surface the reason instead of silently dropping the feature.
    print(f"INFO: Faceswap disabled ({type(err).__name__}: {err})")
    state["faceswap_loaded"] = False
import modules.sdxl_pipeline as sdxl_pipeline
import modules.template_pipeline as template_pipeline
import modules.upscale_pipeline as upscale_pipeline
import modules.search_pipeline as search_pipeline
import modules.huggingface_dl_pipeline as huggingface_dl_pipeline
import modules.diffusers_pipeline as diffusers_pipeline
import modules.rembg_pipeline as rembg_pipeline
import modules.llama_pipeline as llama_pipeline
import modules.hunyuan_video_pipeline as hunyuan_video_pipeline
import modules.wan_video_pipeline as wan_video_pipeline
import modules.hashbang_pipeline as hashbang_pipeline
import modules.controlnet as controlnet
class NoPipeLine:
    """Stand-in stored in ``state["pipeline"]`` before a real pipeline is chosen.

    It advertises no pipeline kinds, so every ``"<kind>" in pipeline_type``
    test against it is False and ``update()`` treats it as "nothing loaded".
    """

    # Empty on purpose: no kind name ever matches the placeholder.
    pipeline_type = list()
# Filename prefixes that identify video checkpoints when the model database
# does not report a base model for them.
_HUNYUAN_PREFIXES = ("hunyuan-video-t2v-", "fast-hunyuan-video-t2v-")
_WAN_PREFIXES = ("wan2.1-t2v-", "wan2.1_t2v_", "wan2.1-i2v-", "wan2.1_i2v_")


def _ensure_pipeline(kind, factory):
    """Make sure ``state["pipeline"]`` is a pipeline of the given ``kind``.

    The current pipeline is reused when its ``pipeline_type`` contains
    ``kind``; otherwise ``factory()`` is called and its result stored.

    Args:
        kind: pipeline kind name looked up in ``pipeline_type``.
        factory: zero-argument callable building a new pipeline.
    """
    current = state["pipeline"]
    if current is None or kind not in current.pipeline_type:
        state["pipeline"] = factory()


def _first_path_part(name):
    """Return the first component of ``name`` as a path, or "" if it has none.

    ``Path("").parts`` is an empty tuple, so indexing ``[0]`` directly raises
    IndexError for an empty model name.
    """
    parts = Path(name).parts
    return parts[0] if parts else ""


def _select_model_pipeline(gen_data):
    """Choose the pipeline for a regular generation based on the checkpoint.

    Handles diffusers ("🤗"-prefixed names), Hunyuan Video and Wan Video
    checkpoints, and otherwise the sdxl/default pipeline. Result is stored in
    ``state["pipeline"]``.
    """
    has_model = "base_model_name" in gen_data
    model_name = gen_data["base_model_name"] if has_model else ""
    file = ""
    base_model = None
    if has_model:
        file = shared.models.get_file("checkpoints", model_name)
        if file is None:
            # Unknown checkpoint: the *string* "None" (not the None object)
            # still routes to the sdxl/default pipeline below.
            file = ""
            base_model = "None"
        else:
            path = shared.models.get_models_by_path("checkpoints", file)
            base_model = shared.models.get_model_base(path)

    if state["pipeline"] is None:
        # Placeholder with an empty pipeline_type so the membership checks
        # below work without None guards.
        state["pipeline"] = NoPipeLine()

    top_folder = _first_path_part(model_name)
    file_name = Path(file).name

    if model_name.startswith("🤗"):
        _ensure_pipeline("diffusers", diffusers_pipeline.pipeline)
    elif (
        base_model == "Hunyuan Video"
        or top_folder == "Hunyuan Video"
        or file_name.startswith(_HUNYUAN_PREFIXES)
    ):
        _ensure_pipeline("hunyuan_video", hunyuan_video_pipeline.pipeline)
    elif (
        base_model == "Wan Video"
        or top_folder == "Wan Video"
        or file_name.startswith(_WAN_PREFIXES)
    ):
        _ensure_pipeline("wan_video", wan_video_pipeline.pipeline)
    elif base_model is not None:
        # Try with the sdxl/default pipeline if base_model is set.
        _ensure_pipeline("sdxl", sdxl_pipeline.pipeline)

    # Still the placeholder (or a pipeline without kinds): use the default.
    if state["pipeline"] is None or len(state["pipeline"].pipeline_type) == 0:
        print("Using default pipeline.")
        state["pipeline"] = sdxl_pipeline.pipeline()


def update(gen_data):
    """Select, and construct if needed, the pipeline for a generation request.

    Selection order:
    1. ``task_type == "llama"``.
    2. Prompt-based special modes: logo, ``#!`` hashbang, ``search:``, ``hf:``.
    3. ControlNet type: upscale, faceswap (only if it imported), rembg.
    4. Otherwise, the checkpoint named by ``base_model_name``.

    The chosen pipeline is cached in ``state["pipeline"]`` and reused while it
    still matches.

    Args:
        gen_data: generation request; reads the optional keys "prompt",
            "task_type" and "base_model_name", and is passed to
            ``controlnet.get_settings`` (assumed to return a dict with an
            optional "type" key — TODO confirm).

    Returns:
        The active pipeline object. If selection raises any ``Exception``,
        the template (logo-only) pipeline is returned instead.
    """
    prompt = gen_data["prompt"] if "prompt" in gen_data else ""
    cn_settings = controlnet.get_settings(gen_data)
    cn_type = cn_settings["type"] if "type" in cn_settings else ""

    try:
        if "task_type" in gen_data and gen_data["task_type"] == "llama":
            _ensure_pipeline("llama", llama_pipeline.pipeline)
        elif prompt.lower() == "ruinedfooocuslogo":
            _ensure_pipeline("template", template_pipeline.pipeline)
        elif prompt.startswith("#!"):
            _ensure_pipeline("hashbang", hashbang_pipeline.pipeline)
        elif prompt.lower().startswith("search:"):
            _ensure_pipeline("search", search_pipeline.pipeline)
        elif re.match(r"^\s*hf:", prompt):
            _ensure_pipeline("huggingface_dl", huggingface_dl_pipeline.pipeline)
        elif cn_type.lower() == "upscale":
            _ensure_pipeline("upscale", upscale_pipeline.pipeline)
        elif cn_type.lower() == "faceswap" and state["faceswap_loaded"]:
            # faceswapper_pipeline only exists when its import succeeded,
            # which is exactly what faceswap_loaded records.
            _ensure_pipeline("faceswap", faceswapper_pipeline.pipeline)
        elif cn_type.lower() == "rembg":
            _ensure_pipeline("rembg", rembg_pipeline.pipeline)
        else:
            _select_model_pipeline(gen_data)
        return state["pipeline"]
    except Exception as err:
        # If things fail, use the template pipeline that only returns a logo.
        # `Exception` (not a bare except) so Ctrl-C / SystemExit propagate;
        # the error is printed so the failure is not silent.
        print("Something went wrong. Falling back to template pipeline.")
        print(f"  {type(err).__name__}: {err}")
        state["pipeline"] = template_pipeline.pipeline()
        return state["pipeline"]