# RuinedFooocus: modules/sdxl_pipeline.py
import gc
import numpy as np
import os
import torch
import traceback
import sys
import modules.controlnet
import modules.async_worker as worker
import modules.prompt_processing as pp
from PIL import Image, ImageOps
from comfy.model_base import BaseModel, SDXL, SD3, Flux, Lumina2
from shared import path_manager, settings
import shared
from pathlib import Path
import json
import random
import comfy.utils
import comfy.model_management
from comfy.sd import load_checkpoint_guess_config, load_state_dict_guess_config
from tqdm import tqdm
from comfy_extras.nodes_model_advanced import ModelSamplingAuraFlow
from nodes import (
CLIPTextEncode,
CLIPSetLastLayer,
ControlNetApplyAdvanced,
EmptyLatentImage,
VAEDecode,
VAEEncode,
VAEEncodeForInpaint,
CLIPLoader,
VAELoader,
)
from comfy.sampler_helpers import (
cleanup_additional_models,
convert_cond,
get_additional_models,
prepare_mask,
)
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from node_helpers import conditioning_set_values
from comfy.samplers import KSampler
from comfy_extras.nodes_post_processing import ImageScaleToTotalPixels
from comfy_extras.nodes_canny import Canny
from comfy_extras.nodes_freelunch import FreeU
from comfy.model_patcher import ModelPatcher
from comfy.sd import CLIP, VAE
from comfy.utils import load_torch_file
from comfy.sd import save_checkpoint
from modules.pipleline_utils import (  # sic: module filename is spelled this way
get_previewer,
clean_prompt_cond_caches,
set_timestep_range,
)
#from comfyui_gguf.nodes import gguf_sd_loader, DualCLIPLoaderGGUF, GGUFModelPatcher
#from comfyui_gguf.ops import GGMLOps
from calcuis_gguf.pig import load_gguf_sd, GGMLOps, GGUFModelPatcher
from calcuis_gguf.pig import DualClipLoaderGGUF as DualCLIPLoaderGGUF
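# A minimal usage sketch (assumptions: the gen_data keys shown here are the
# ones read by process() below; the real dict is built by modules.async_worker
# and may carry more fields):
#
#   pipe = pipeline()
#   pipe.load_base_model("sd_xl_base_1.0_0.9vae.safetensors")
#   pipe.load_loras([("my_lora.safetensors", 0.8)])  # hypothetical LoRA name
#   images = pipe.process(
#       gen_data={
#           "positive_prompt": "a photo of a cat",
#           "negative_prompt": "",
#           "input_image": None,
#           "cfg": 7.0, "sampler_name": "dpmpp_2m", "scheduler": "karras",
#           "clip_skip": 1, "seed": 42, "steps": 30,
#           "width": 1024, "height": 1024,
#       },
#       callback=None,
#   )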
class pipeline:
pipeline_type = ["sdxl", "ssd", "sd3", "flux", "lumina2"]
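    # Keep ComfyUI's smart memory management enabled and reserve ~800 MiB of
    # VRAM (800 * 1024 * 1024 bytes) as headroom on top of what the models need.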
comfy.model_management.DISABLE_SMART_MEMORY = False
comfy.model_management.EXTRA_RESERVED_VRAM = 800 * 1024 * 1024
class StableDiffusionModel:
def __init__(self, unet, vae, clip, clip_vision):
self.unet = unet
self.vae = vae
self.clip = clip
self.clip_vision = clip_vision
def to_meta(self):
if self.unet is not None:
self.unet.model.to("meta")
if self.clip is not None:
self.clip.cond_stage_model.to("meta")
if self.vae is not None:
self.vae.first_stage_model.to("meta")
xl_base: StableDiffusionModel = None
xl_base_hash = ""
xl_base_patched: StableDiffusionModel = None
xl_base_patched_hash = ""
xl_base_patched_extra = set()
xl_controlnet: StableDiffusionModel = None
xl_controlnet_hash = ""
models = []
inference_memory = None
ggml_ops = GGMLOps()
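    # Map a short text-encoder name to a concrete filename, preferring an
    # override from settings; e.g. get_clip_name("clip_t5") returns
    # "t5-v1_1-xxl-encoder-Q3_K_S.gguf" unless default_settings has a
    # "clip_t5" entry.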
def get_clip_name(self, shortname):
# List of short names and default names for different text encoders
defaults = {
"clip_g": "clip_g.safetensors",
"clip_gemma": "gemma_2_2b_fp16.safetensors",
"clip_l": "clip_l.safetensors",
"clip_t5": "t5-v1_1-xxl-encoder-Q3_K_S.gguf",
}
        return settings.default_settings.get(shortname, defaults.get(shortname))
# FIXME move this to separate file
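    # A .merge file is a JSON recipe. A sketch of the fields this method reads
    # (field names taken from the accesses below; example values are made up):
    #
    #   {
    #     "comment": "my mix",
    #     "base":   {"name": "base_model.safetensors", "weight": 1.0},
    #     "models": [{"name": "other_model.safetensors", "weight": 0.5}],
    #     "normalize": 1.0,
    #     "loras":  [{"name": "my_lora.safetensors", "weight": 0.8}],
    #     "cache":  true
    #   }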
def merge_models(self, name):
print(f"Loading merge: {name}")
self.xl_base_patched = None
self.xl_base_patched_hash = ""
self.xl_base_patched_extra = set()
self.conditions = None
filename = shared.models.get_file("checkpoints", name)
cache_name = str(Path(path_manager.model_paths["cache_path"] / "merges" / Path(name).name).with_suffix(".safetensors"))
if Path(cache_name).exists() and Path(cache_name).stat().st_mtime >= Path(filename).stat().st_mtime:
print(f"Loading cached version:")
self.load_base_model(cache_name)
return
try:
with filename.open() as f:
merge_data = json.load(f)
if 'comment' in merge_data:
print(f" {merge_data['comment']}")
filename = shared.models.get_file("checkpoints", merge_data["base"]["name"])
norm = 1.0
if "models" in merge_data and len(merge_data["models"]) > 0:
weights = sum([merge_data["base"]["weight"]] + [x.get("weight") for x in merge_data["models"]])
if "normalize" in merge_data:
norm = float(merge_data["normalize"]) / weights
else:
norm = 1.0 / weights
print(f"Loading base {merge_data['base']['name']} ({round(merge_data['base']['weight'] * norm * 100)}%)")
            with torch.inference_mode():
unet, clip, vae, clip_vision = load_checkpoint_guess_config(str(filename))
self.xl_base = self.StableDiffusionModel(
unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
)
if self.xl_base is not None:
self.xl_base_hash = name
self.xl_base_patched = self.xl_base
self.xl_base_patched_hash = ""
except Exception as e:
self.xl_base = None
print(f"ERROR: {e}")
return
if "models" in merge_data and len(merge_data["models"]) > 0:
device = comfy.model_management.get_torch_device()
mp = ModelPatcher(self.xl_base_patched.unet, device, "cpu", size=1)
w = float(merge_data["base"]["weight"]) * norm
for m in merge_data["models"]:
print(f"Merging {m['name']} ({round(m['weight'] * norm * 100)}%)")
filename = str(shared.models.get_file("checkpoints", m["name"]))
                # FIXME add error check?
                with torch.inference_mode():
m_unet, m_clip, m_vae, m_clip_vision = load_checkpoint_guess_config(str(filename))
del m_clip
del m_vae
del m_clip_vision
kp = m_unet.get_key_patches("diffusion_model.")
for k in kp:
mp.model.add_patches({k: kp[k]}, strength_patch=float(m['weight'] * norm), strength_model=w)
del m_unet
w = 1.0
self.xl_base = self.StableDiffusionModel(
unet=mp.model, clip=clip, vae=vae, clip_vision=clip_vision
)
if "loras" in merge_data and len(merge_data["loras"]) > 0:
loras = [(x.get("name"), x.get("weight")) for x in merge_data["loras"]]
self.load_loras(loras)
self.xl_base = self.xl_base_patched
        if merge_data.get("cache") is True:
filename = str(Path(path_manager.model_paths["cache_path"] / "merges" / Path(name).name).with_suffix(".safetensors"))
print(f"Saving merged model: {filename}")
            with torch.inference_mode():
save_checkpoint(
filename,
self.xl_base.unet,
clip=self.xl_base.clip,
vae=self.xl_base.vae,
clip_vision=self.xl_base.clip_vision,
metadata={"rf_merge_data": str(merge_data)}
)
return
def load_base_model(self, name, unet_only=False, input_unet=None):
if self.xl_base_hash == name and self.xl_base_patched_extra == set():
return
filename = shared.models.get_file("checkpoints", name)
# If we don't have a filename, get the default.
if filename is None:
base_model = settings.default_settings.get("base_model", "sd_xl_base_1.0_0.9vae.safetensors")
filename = path_manager.get_folder_file_path(
"checkpoints",
base_model,
)
if Path(filename).suffix == '.merge':
self.merge_models(name)
return
if input_unet is None: # Be quiet if we already loaded a unet
print(f"Loading base {'unet' if unet_only else 'model'}: {name}")
self.xl_base = None
self.xl_base_hash = ""
self.xl_base_patched = None
self.xl_base_patched_hash = ""
self.xl_base_patched_extra = set()
self.conditions = None
gc.collect(generation=2)
comfy.model_management.cleanup_models()
comfy.model_management.soft_empty_cache()
unet = None
filename = str(filename) # FIXME use Path and suffix instead?
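        # GGUF files and bare unets cannot be loaded as all-in-one checkpoints;
        # load the diffusion model alone, then pick matching text encoders and
        # a VAE for the detected architecture below.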
if filename.endswith(".gguf") or unet_only:
            with torch.inference_mode():
try:
if input_unet is not None:
if isinstance(input_unet, ModelPatcher):
unet = input_unet
else:
unet = comfy.sd.load_diffusion_model_state_dict(
input_unet, model_options={"custom_operations": self.ggml_ops}
)
unet = GGUFModelPatcher.clone(unet)
unet.patch_on_device = True
elif filename.endswith(".gguf"):
sd = load_gguf_sd(filename)
unet = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": self.ggml_ops}
)
unet = GGUFModelPatcher.clone(unet)
unet.patch_on_device = True
else:
model_options = {}
model_options["dtype"] = torch.float8_e4m3fn # FIXME should be a setting
unet = comfy.sd.load_diffusion_model(filename, model_options=model_options)
# Get text encoders (clip) and vae to match the unet
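                    # clip_l + t5 for Flux, clip_l + clip_g + t5 for SD3, gemma
                    # for Lumina2, clip_l + clip_g for SDXL and anything unknown.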
clip_names = []
if isinstance(unet.model, Flux):
clip_names.append(self.get_clip_name("clip_l"))
clip_names.append(self.get_clip_name("clip_t5"))
clip_type = comfy.sd.CLIPType.FLUX
vae_name = settings.default_settings.get("vae_flux", "ae.safetensors")
elif isinstance(unet.model, SD3):
clip_names.append(self.get_clip_name("clip_l"))
clip_names.append(self.get_clip_name("clip_g"))
clip_names.append(self.get_clip_name("clip_t5"))
clip_type = comfy.sd.CLIPType.SD3
vae_name = settings.default_settings.get("vae_sd3", "sd3_vae.safetensors")
elif isinstance(unet.model, Lumina2):
clip_names.append(self.get_clip_name("clip_gemma"))
clip_type = comfy.sd.CLIPType.LUMINA2
vae_name = settings.default_settings.get("vae_lumina2", "lumina2_vae_fp32.safetensors")
unet = ModelSamplingAuraFlow().patch_aura(
model=unet,
shift=settings.default_settings.get("lumina2_shift", 3.0),
)[0]
else: # SDXL
clip_names.append(self.get_clip_name("clip_l"))
clip_names.append(self.get_clip_name("clip_g"))
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
vae_name = settings.default_settings.get("vae_sdxl", "sdxl_vae.safetensors")
clip_paths = []
for clip_name in clip_names:
clip_paths.append(
str(
path_manager.get_folder_file_path(
"clip",
clip_name,
default = os.path.join(path_manager.model_paths["clip_path"], clip_name)
)
)
)
clip_loader = DualCLIPLoaderGGUF()
print(f"Loading CLIP: {clip_names}")
clip = clip_loader.load_patcher(
clip_paths,
clip_type,
clip_loader.load_data(clip_paths)
)
vae_path = path_manager.get_folder_file_path(
"vae",
vae_name,
default = os.path.join(path_manager.model_paths["vae_path"], vae_name)
)
print(f"Loading VAE: {vae_name}")
sd = comfy.utils.load_torch_file(str(vae_path))
vae = comfy.sd.VAE(sd=sd)
clip_vision = None
except Exception as e:
unet = None
traceback.print_exc()
else:
sd = None
unet = None
try:
                with torch.inference_mode():
sd = comfy.utils.load_torch_file(filename)
except Exception as e:
# Failed loading
print(f"ERROR: Failed loading {filename}: {e}")
if sd is not None:
aio = load_state_dict_guess_config(sd)
if isinstance(aio, tuple):
unet, clip, vae, clip_vision = aio
if (
isinstance(unet, ModelPatcher) and
isinstance(clip, CLIP) and
isinstance(vae, VAE)
):
# If we got here, we have all models. Dump sd since we don't need it
sd = None
else:
if isinstance(unet, ModelPatcher):
sd = unet
if sd is not None:
# We got something, assume it was a unet
self.load_base_model(
filename,
unet_only=True,
input_unet=sd,
)
return
else:
unet = None
        if unet is None:
print(f"Failed to load {name}")
self.xl_base = None
self.xl_base_hash = ""
self.xl_base_patched = None
self.xl_base_patched_hash = ""
else:
self.xl_base = self.StableDiffusionModel(
unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
)
            if not isinstance(
                self.xl_base.unet.model,
                (BaseModel, SDXL, SD3, Flux, Lumina2),
            ):
                print(
                    f"Model {type(self.xl_base.unet.model)} not supported. RuinedFooocus only supports SD1.x/SDXL/SD3/Flux/Lumina2 models as the base model."
                )
self.xl_base = None
if self.xl_base is not None:
self.xl_base_hash = name
self.xl_base_patched = self.xl_base
self.xl_base_patched_hash = ""
# self.xl_base_patched.unet.model.to("cuda")
#print(f"Base model loaded: {self.xl_base_hash}")
return
def freeu(self, model, b1, b2, s1, s2):
freeu_model = FreeU()
unet = freeu_model.patch(model=model.unet, b1=b1, b2=b2, s1=s1, s2=s2)[0]
return self.StableDiffusionModel(
unet=unet, clip=model.clip, vae=model.vae, clip_vision=model.clip_vision
)
def load_loras(self, loras):
loaded_loras = []
model = self.xl_base
for name, weight in loras:
if name == "None" or weight == 0:
continue
filename = str(shared.models.get_file("loras", name))
print(f"Loading LoRAs: {name}")
try:
lora = comfy.utils.load_torch_file(filename, safe_load=True)
unet, clip = comfy.sd.load_lora_for_models(
model.unet, model.clip, lora, weight, weight
)
model = self.StableDiffusionModel(
unet=unet,
clip=clip,
vae=model.vae,
clip_vision=model.clip_vision,
)
loaded_loras += [(name, weight)]
            except Exception:
                print(f"Error loading LoRA: {filename}")
self.xl_base_patched = model
        # Uncomment below to enable FreeU
# self.xl_base_patched = self.freeu(model, 1.01, 1.02, 0.99, 0.95)
# self.xl_base_patched_hash = str(loras + [1.01, 1.02, 0.99, 0.95])
self.xl_base_patched_hash = str(loras)
print(f"LoRAs loaded: {loaded_loras}")
return
def refresh_controlnet(self, name=None):
        if self.xl_controlnet_hash == name:
return
filename = modules.controlnet.get_model(name)
if filename is not None and self.xl_controlnet_hash != name:
self.xl_controlnet = comfy.controlnet.load_controlnet(str(filename))
self.xl_controlnet_hash = name
print(f"ControlNet model loaded: {self.xl_controlnet_hash}")
if self.xl_controlnet_hash != name:
self.xl_controlnet = None
self.xl_controlnet_hash = None
print(f"Controlnet model unloaded")
conditions = None
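    # Cache encoded prompts keyed by "<text> <clip_skip>" so unchanged prompts
    # skip re-encoding; returns True when the cached conditioning was updated.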
def textencode(self, id, text, clip_skip):
update = False
hash = f"{text} {clip_skip}"
if hash != self.conditions[id]["text"]:
if clip_skip > 1:
self.xl_base_patched.clip = CLIPSetLastLayer().set_last_layer(
self.xl_base_patched.clip, clip_skip * -1
)[0]
self.conditions[id]["cache"] = CLIPTextEncode().encode(
clip=self.xl_base_patched.clip, text=text
)[0]
self.conditions[id]["text"] = hash
update = True
return update
@torch.inference_mode()
def process(
self,
gen_data=None,
callback=None,
):
try:
            if self.xl_base_patched is None or not isinstance(
                self.xl_base_patched.unet.model,
                (BaseModel, SDXL, SD3, Flux, Lumina2),
            ):
                print("ERROR: Can only use SD1.x, SDXL, SD3, Flux or Lumina2 models")
worker.interrupt_ruined_processing = True
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Can only use SDXL, SD3 or Flux models ...", "html/error.png")
)
return []
except Exception as e:
# Something went very wrong
print(f"ERROR: {e}")
worker.interrupt_ruined_processing = True
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Error when trying to use model ...", "html/error.png")
)
return []
positive_prompt = gen_data["positive_prompt"]
negative_prompt = gen_data["negative_prompt"]
input_image = gen_data["input_image"]
controlnet = modules.controlnet.get_settings(gen_data)
cfg = gen_data["cfg"]
sampler_name = gen_data["sampler_name"]
scheduler = gen_data["scheduler"]
clip_skip = gen_data["clip_skip"]
img2img_mode = False
input_image_pil = None
seed = gen_data["seed"] if isinstance(gen_data["seed"], int) else random.randint(1, 2**32)
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Processing text encoding ...", None)
)
updated_conditions = False
if self.conditions is None:
self.conditions = clean_prompt_cond_caches()
if self.textencode("+", positive_prompt, clip_skip):
updated_conditions = True
if self.textencode("-", negative_prompt, clip_skip):
updated_conditions = True
switched_prompt = []
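        # [prompt|switching]: e.g. "a [cat|dog]" alternates sub-prompts per
        # step; each expanded prompt is encoded and limited to its slice of the
        # 0..1 timestep range so the sampler swaps conditioning mid-run.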
if "[" in positive_prompt and "]" in positive_prompt:
if controlnet is not None and input_image is not None:
print("ControlNet and [prompt|switching] do not work well together.")
print("ControlNet will only be applied to the first prompt.")
prompt_per_step = pp.prompt_switch_per_step(positive_prompt, gen_data["steps"])
perc_per_step = round(100 / gen_data["steps"], 2)
for i in range(len(prompt_per_step)):
if self.textencode("switch", prompt_per_step[i], clip_skip):
updated_conditions = True
positive_switch = self.conditions["switch"]["cache"]
start_perc = round((perc_per_step * i) / 100, 2)
end_perc = round((perc_per_step * (i + 1)) / 100, 2)
if end_perc >= 0.99:
end_perc = 1
positive_switch = set_timestep_range(
positive_switch, start_perc, end_perc
)
switched_prompt += positive_switch
device = comfy.model_management.get_torch_device()
if controlnet is not None and "type" in controlnet and input_image is not None:
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Powering up ...", None)
)
input_image_pil = input_image.convert("RGB")
input_image = np.array(input_image_pil).astype(np.float32) / 255.0
input_image = torch.from_numpy(input_image)[None,]
input_image = ImageScaleToTotalPixels().upscale(
image=input_image, upscale_method="bicubic", megapixels=1.0
)[0]
self.refresh_controlnet(name=controlnet["type"])
match controlnet["type"].lower():
case "canny":
input_image = Canny().detect_edge(
image=input_image,
low_threshold=float(controlnet["edge_low"]),
high_threshold=float(controlnet["edge_high"]),
)[0]
updated_conditions = True
case "depth":
updated_conditions = True
if self.xl_controlnet:
(
self.conditions["+"]["cache"],
self.conditions["-"]["cache"],
) = ControlNetApplyAdvanced().apply_controlnet(
positive=self.conditions["+"]["cache"],
negative=self.conditions["-"]["cache"],
control_net=self.xl_controlnet,
image=input_image,
strength=float(controlnet["strength"]),
start_percent=float(controlnet["start"]),
end_percent=float(controlnet["stop"]),
)
self.conditions["+"]["text"] = None
self.conditions["-"]["text"] = None
if controlnet["type"].lower() == "img2img":
latent = VAEEncode().encode(
vae=self.xl_base_patched.vae, pixels=input_image
)[0]
force_full_denoise = False
denoise = float(controlnet.get("denoise", controlnet.get("strength")))
img2img_mode = True
if not img2img_mode:
if (
isinstance(self.xl_base.unet.model, SD3) or
isinstance(self.xl_base.unet.model, Flux) or
isinstance(self.xl_base.unet.model, Lumina2)
):
latent = EmptySD3LatentImage().generate(
width=gen_data["width"], height=gen_data["height"], batch_size=1
)[0]
else: # SDXL and unknown
latent = EmptyLatentImage().generate(
width=gen_data["width"], height=gen_data["height"], batch_size=1
)[0]
force_full_denoise = False
denoise = None
if "inpaint_toggle" in gen_data and gen_data["inpaint_toggle"]:
# This is a _very_ ugly workaround since we had to shrink the inpaint image
# to not break the ui.
main_image = Image.open(gen_data["main_view"])
image = np.asarray(main_image)
# image = image[..., :-1]
image = torch.from_numpy(image)[None,] / 255.0
inpaint_view = Image.fromarray(gen_data["inpaint_view"]["layers"][0])
red, green, blue, mask = inpaint_view.split()
mask = mask.resize((main_image.width, main_image.height), Image.Resampling.LANCZOS)
mask = np.asarray(mask)
# mask = mask[:, :, 0]
mask = torch.from_numpy(mask)[None,] / 255.0
latent = VAEEncodeForInpaint().encode(
vae=self.xl_base_patched.vae,
pixels=image,
mask=mask,
grow_mask_by=20,
)[0]
latent_image = latent["samples"]
        batch_inds = latent.get("batch_index")
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
        noise_mask = latent.get("noise_mask")
previewer = get_previewer(device, self.xl_base_patched.unet.model.latent_format)
pbar = comfy.utils.ProgressBar(gen_data["steps"])
def callback_function(step, x0, x, total_steps):
y = None
if previewer:
y = previewer.preview(x0, step, total_steps)
if callback is not None:
callback(step, x0, x, total_steps, y)
pbar.update_absolute(step + 1, total_steps, None)
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device)
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Prepare models ...", None)
)
if updated_conditions:
conds = {
0: self.conditions["+"]["cache"],
1: self.conditions["-"]["cache"],
}
self.models, self.inference_memory = get_additional_models(
conds,
self.xl_base_patched.unet.model_dtype(),
)
comfy.model_management.load_models_gpu([self.xl_base_patched.unet])
comfy.model_management.load_models_gpu(self.models)
noise = noise.to(device)
latent_image = latent_image.to(device)
# Use FluxGuidance for Flux
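        # Flux is guidance-distilled: the cfg value is embedded in the positive
        # conditioning as "guidance" and classifier-free guidance is disabled
        # by setting cfg to 1.0.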
positive_cond = switched_prompt if switched_prompt else self.conditions["+"]["cache"]
if isinstance(self.xl_base.unet.model, Flux):
positive_cond = conditioning_set_values(positive_cond, {"guidance": cfg})
cfg = 1.0
kwargs = {
"cfg": cfg,
"latent_image": latent_image,
"start_step": 0,
"last_step": gen_data["steps"],
"force_full_denoise": force_full_denoise,
"denoise_mask": noise_mask,
"sigmas": None,
"disable_pbar": False,
"seed": seed,
"callback": callback_function,
}
sampler = KSampler(
self.xl_base_patched.unet,
steps=gen_data["steps"],
device=device,
sampler=sampler_name,
scheduler=scheduler,
denoise=denoise,
model_options=self.xl_base_patched.unet.model_options,
)
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Start sampling ...", None)
)
samples = sampler.sample(
noise,
positive_cond,
self.conditions["-"]["cache"],
**kwargs,
)
cleanup_additional_models(self.models)
sampled_latent = latent.copy()
sampled_latent["samples"] = samples
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"VAE decoding ...", None)
)
decoded_latent = VAEDecode().decode(
samples=sampled_latent, vae=self.xl_base_patched.vae
)[0]
images = [
np.clip(255.0 * y.cpu().numpy(), 0, 255).astype(np.uint8)
for y in decoded_latent
]
if callback is not None:
callback(gen_data["steps"], 0, 0, gen_data["steps"], images[0])
return images