Spaces:
Runtime error
Runtime error
File size: 6,063 Bytes
2de3774 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
from shared import state, path_manager
import shared
from pathlib import Path
import re
# Faceswap support is optional. Record in the shared state whether its
# pipeline module could be imported; update() checks state["faceswap_loaded"]
# before routing a request to it.
try:
    import modules.faceswapper_pipeline as faceswapper_pipeline

    print("INFO: Faceswap enabled")
    state["faceswap_loaded"] = True
except Exception as e:
    # Catch Exception (not everything) so Ctrl-C / SystemExit at startup are
    # not swallowed, and say why faceswap is unavailable instead of hiding it.
    print(f"INFO: Faceswap disabled ({type(e).__name__}: {e})")
    state["faceswap_loaded"] = False
import modules.sdxl_pipeline as sdxl_pipeline
import modules.template_pipeline as template_pipeline
import modules.upscale_pipeline as upscale_pipeline
import modules.search_pipeline as search_pipeline
import modules.huggingface_dl_pipeline as huggingface_dl_pipeline
import modules.diffusers_pipeline as diffusers_pipeline
import modules.rembg_pipeline as rembg_pipeline
import modules.llama_pipeline as llama_pipeline
import modules.hunyuan_video_pipeline as hunyuan_video_pipeline
import modules.wan_video_pipeline as wan_video_pipeline
import modules.hashbang_pipeline as hashbang_pipeline
import modules.controlnet as controlnet
class NoPipeLine:
    """Placeholder installed while no real pipeline has been chosen yet.

    ``pipeline_type`` is an empty list, so no pipeline kind is ever "in" it
    and its length is zero.
    """

    pipeline_type: list = []
def _ensure_pipeline(kind, factory):
    """Install ``factory()`` as the active pipeline unless it already handles ``kind``.

    The current ``state["pipeline"]`` is kept when it exists and its
    ``pipeline_type`` contains ``kind``, so it is not rebuilt on every request.
    """
    current = state["pipeline"]
    if current is None or kind not in current.pipeline_type:
        state["pipeline"] = factory()


def _first_path_part(name):
    """Return the first component of path ``name``, or "" when it has none.

    ``Path("").parts`` is an empty tuple, so indexing ``[0]`` would raise.
    """
    parts = Path(name).parts
    return parts[0] if parts else ""


def update(gen_data):
    """Select the pipeline for ``gen_data`` and store it in ``state["pipeline"]``.

    The first matching rule wins:

    * ``gen_data["task_type"] == "llama"``        -> llama pipeline
    * prompt equal to "ruinedfooocuslogo" (any case) -> template pipeline
    * prompt starting with "#!"                     -> hashbang pipeline
    * prompt starting with "search:" (any case)     -> search pipeline
    * prompt starting with "hf:" (leading whitespace allowed)
                                                    -> huggingface_dl pipeline
    * controlnet type "upscale" / "faceswap" (only if loaded) / "rembg"
                                                    -> matching pipeline
    * otherwise the base model decides: names starting with "🤗" -> diffusers;
      Hunyuan Video / Wan Video models (by model base, top folder or file
      name prefix) -> their video pipelines; any other resolved checkpoint
      -> sdxl.

    If afterwards there is no pipeline, or only one with an empty
    ``pipeline_type`` (e.g. the ``NoPipeLine`` placeholder), the sdxl pipeline
    is used as the default.

    Args:
        gen_data: Request dict; reads "prompt", "task_type" and
            "base_model_name" when present, and is passed to
            ``controlnet.get_settings``.

    Returns:
        The active pipeline (also stored in ``state["pipeline"]``). If
        selection raises, the template pipeline is installed and returned.
        Errors from ``controlnet.get_settings`` itself are not caught.
    """
    prompt = gen_data["prompt"] if "prompt" in gen_data else ""
    cn_settings = controlnet.get_settings(gen_data)
    cn_type = cn_settings["type"] if "type" in cn_settings else ""

    try:
        if "task_type" in gen_data and gen_data["task_type"] == "llama":
            _ensure_pipeline("llama", llama_pipeline.pipeline)
        elif prompt.lower() == "ruinedfooocuslogo":
            _ensure_pipeline("template", template_pipeline.pipeline)
        elif prompt.startswith("#!"):
            _ensure_pipeline("hashbang", hashbang_pipeline.pipeline)
        elif prompt.lower().startswith("search:"):
            _ensure_pipeline("search", search_pipeline.pipeline)
        elif re.match(r"^\s*hf:", prompt):
            _ensure_pipeline("huggingface_dl", huggingface_dl_pipeline.pipeline)
        elif cn_type.lower() == "upscale":
            _ensure_pipeline("upscale", upscale_pipeline.pipeline)
        elif cn_type.lower() == "faceswap" and state["faceswap_loaded"]:
            # faceswapper_pipeline only exists when its import succeeded,
            # which is exactly when faceswap_loaded is True.
            _ensure_pipeline("faceswap", faceswapper_pipeline.pipeline)
        elif cn_type.lower() == "rembg":
            _ensure_pipeline("rembg", rembg_pipeline.pipeline)
        else:
            # Defaults keep the checks below valid when the request names no
            # base model; without them the names would be unbound.
            base_model = None
            file = ""
            base_model_name = ""
            if "base_model_name" in gen_data:
                base_model_name = gen_data["base_model_name"]
                file = shared.models.get_file("checkpoints", base_model_name)
                if file is None:
                    # Unknown checkpoint: mark it so it still reaches sdxl.
                    file = ""
                    base_model = "None"
                else:
                    path = shared.models.get_models_by_path("checkpoints", file)
                    base_model = shared.models.get_model_base(path)

            if state["pipeline"] is None:
                state["pipeline"] = NoPipeLine()

            if base_model_name.startswith("🤗"):
                _ensure_pipeline("diffusers", diffusers_pipeline.pipeline)
            elif (
                base_model == "Hunyuan Video"
                or _first_path_part(base_model_name) == "Hunyuan Video"
                or Path(file).name.startswith(
                    ("hunyuan-video-t2v-", "fast-hunyuan-video-t2v-")
                )
            ):
                _ensure_pipeline("hunyuan_video", hunyuan_video_pipeline.pipeline)
            elif (
                base_model == "Wan Video"
                or _first_path_part(base_model_name) == "Wan Video"
                or Path(file).name.startswith(
                    ("wan2.1-t2v-", "wan2.1_t2v_", "wan2.1-i2v-", "wan2.1_i2v_")
                )
            ):
                _ensure_pipeline("wan_video", wan_video_pipeline.pipeline)
            elif base_model is not None:
                # Try with the sdxl/default pipeline if a base model is set.
                _ensure_pipeline("sdxl", sdxl_pipeline.pipeline)

        # Nothing usable selected (no pipeline, or one with an empty type list).
        if state["pipeline"] is None or not state["pipeline"].pipeline_type:
            print("Using default pipeline.")
            state["pipeline"] = sdxl_pipeline.pipeline()
        return state["pipeline"]
    except Exception as e:
        # If things fail, use the template pipeline that only returns a logo,
        # but report the cause instead of silently swallowing it.
        print(
            "Something went wrong. Falling back to template pipeline. "
            f"({type(e).__name__}: {e})"
        )
        state["pipeline"] = template_pipeline.pipeline()
        return state["pipeline"]
|