import gc
import numpy as np
import os
import torch
import traceback
import sys

import modules.controlnet
import modules.async_worker as worker
import modules.prompt_processing as pp

from PIL import Image, ImageOps
from comfy.model_base import BaseModel, SDXL, SD3, Flux, Lumina2
from shared import path_manager, settings
import shared
from pathlib import Path
import json
import random

import comfy.utils
import comfy.model_management

from comfy.sd import load_checkpoint_guess_config, load_state_dict_guess_config
from tqdm import tqdm

from comfy_extras.nodes_model_advanced import ModelSamplingAuraFlow
from nodes import (
    CLIPTextEncode,
    CLIPSetLastLayer,
    ControlNetApplyAdvanced,
    EmptyLatentImage,
    VAEDecode,
    VAEEncode,
    VAEEncodeForInpaint,
    CLIPLoader,
    VAELoader,
)
from comfy.sampler_helpers import (
    cleanup_additional_models,
    convert_cond,
    get_additional_models,
    prepare_mask,
)
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from node_helpers import conditioning_set_values
from comfy.samplers import KSampler
from comfy_extras.nodes_post_processing import ImageScaleToTotalPixels
from comfy_extras.nodes_canny import Canny
from comfy_extras.nodes_freelunch import FreeU
from comfy.model_patcher import ModelPatcher
from comfy.sd import CLIP, VAE
from comfy.utils import load_torch_file
from comfy.sd import save_checkpoint

from modules.pipleline_utils import (
    get_previewer,
    clean_prompt_cond_caches,
    set_timestep_range,
)

# from comfyui_gguf.nodes import gguf_sd_loader, DualCLIPLoaderGGUF, GGUFModelPatcher
# from comfyui_gguf.ops import GGMLOps
from calcuis_gguf.pig import load_gguf_sd, GGMLOps, GGUFModelPatcher
from calcuis_gguf.pig import DualClipLoaderGGUF as DualCLIPLoaderGGUF


class pipeline:
    pipeline_type = ["sdxl", "ssd", "sd3", "flux", "lumina2"]

    comfy.model_management.DISABLE_SMART_MEMORY = False
    comfy.model_management.EXTRA_RESERVED_VRAM = 800 * 1024 * 1024

    class StableDiffusionModel:
        def __init__(self, unet, vae, clip, clip_vision):
            self.unet = unet
            self.vae = vae
            self.clip = clip
            self.clip_vision = clip_vision

        def to_meta(self):
            if self.unet is not None:
                self.unet.model.to("meta")
            if self.clip is not None:
                self.clip.cond_stage_model.to("meta")
            if self.vae is not None:
                self.vae.first_stage_model.to("meta")

    xl_base: StableDiffusionModel = None
    xl_base_hash = ""
    xl_base_patched: StableDiffusionModel = None
    xl_base_patched_hash = ""
    xl_base_patched_extra = set()
    xl_controlnet: StableDiffusionModel = None
    xl_controlnet_hash = ""

    models = []
    inference_memory = None
    ggml_ops = GGMLOps()

    def get_clip_name(self, shortname):
        # Map short names to the default filenames for the different text encoders.
        defaults = {
            "clip_g": "clip_g.safetensors",
            "clip_gemma": "gemma_2_2b_fp16.safetensors",
            "clip_l": "clip_l.safetensors",
            "clip_t5": "t5-v1_1-xxl-encoder-Q3_K_S.gguf",
        }
        return settings.default_settings.get(shortname, defaults.get(shortname))
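    # Illustrative ".merge" recipe layout, inferred from the keys that
    # merge_models() reads below. The filenames and weights here are made-up
    # placeholders, not shipped examples:
    #
    # {
    #     "comment": "my blended checkpoint",
    #     "base":   {"name": "sd_xl_base_1.0_0.9vae.safetensors", "weight": 1.0},
    #     "models": [{"name": "some_other_sdxl_model.safetensors", "weight": 0.5}],
    #     "loras":  [{"name": "some_lora.safetensors", "weight": 0.7}],
    #     "normalize": 1.0,
    #     "cache": true
    # }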
    # FIXME: move this to a separate file
    def merge_models(self, name):
        print(f"Loading merge: {name}")
        self.xl_base_patched = None
        self.xl_base_patched_hash = ""
        self.xl_base_patched_extra = set()
        self.conditions = None

        filename = shared.models.get_file("checkpoints", name)
        cache_name = str(
            Path(
                path_manager.model_paths["cache_path"] / "merges" / Path(name).name
            ).with_suffix(".safetensors")
        )
        # Reuse the cached merge if it is at least as new as the recipe file.
        if (
            Path(cache_name).exists()
            and Path(cache_name).stat().st_mtime >= Path(filename).stat().st_mtime
        ):
            print(f"Loading cached version: {cache_name}")
            self.load_base_model(cache_name)
            return

        try:
            with filename.open() as f:
                merge_data = json.load(f)
            if "comment" in merge_data:
                print(f"  {merge_data['comment']}")
            filename = shared.models.get_file("checkpoints", merge_data["base"]["name"])

            norm = 1.0
            if "models" in merge_data and len(merge_data["models"]) > 0:
                weights = sum(
                    [merge_data["base"]["weight"]]
                    + [x.get("weight") for x in merge_data["models"]]
                )
                if "normalize" in merge_data:
                    norm = float(merge_data["normalize"]) / weights
                else:
                    norm = 1.0 / weights

            print(
                f"Loading base {merge_data['base']['name']} "
                f"({round(merge_data['base']['weight'] * norm * 100)}%)"
            )
            with torch.inference_mode():
                unet, clip, vae, clip_vision = load_checkpoint_guess_config(str(filename))
                self.xl_base = self.StableDiffusionModel(
                    unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
                )
            if self.xl_base is not None:
                self.xl_base_hash = name
                self.xl_base_patched = self.xl_base
                self.xl_base_patched_hash = ""
        except Exception as e:
            self.xl_base = None
            print(f"ERROR: {e}")
            return

        if "models" in merge_data and len(merge_data["models"]) > 0:
            device = comfy.model_management.get_torch_device()
            mp = ModelPatcher(self.xl_base_patched.unet, device, "cpu", size=1)
            w = float(merge_data["base"]["weight"]) * norm
            for m in merge_data["models"]:
                print(f"Merging {m['name']} ({round(m['weight'] * norm * 100)}%)")
                filename = str(shared.models.get_file("checkpoints", m["name"]))
                # FIXME: add error check?
                with torch.inference_mode():
                    m_unet, m_clip, m_vae, m_clip_vision = load_checkpoint_guess_config(
                        str(filename)
                    )
                    del m_clip
                    del m_vae
                    del m_clip_vision
                    kp = m_unet.get_key_patches("diffusion_model.")
                    for k in kp:
                        mp.model.add_patches(
                            {k: kp[k]},
                            strength_patch=float(m["weight"] * norm),
                            strength_model=w,
                        )
                    del m_unet
                # Only the first add_patches() call should scale the base weights.
                w = 1.0
            self.xl_base = self.StableDiffusionModel(
                unet=mp.model, clip=clip, vae=vae, clip_vision=clip_vision
            )

        if "loras" in merge_data and len(merge_data["loras"]) > 0:
            loras = [(x.get("name"), x.get("weight")) for x in merge_data["loras"]]
            self.load_loras(loras)
            self.xl_base = self.xl_base_patched

        if "cache" in merge_data and merge_data["cache"] == True:
            filename = str(
                Path(
                    path_manager.model_paths["cache_path"] / "merges" / Path(name).name
                ).with_suffix(".safetensors")
            )
            print(f"Saving merged model: {filename}")
            with torch.inference_mode():
                save_checkpoint(
                    filename,
                    self.xl_base.unet,
                    clip=self.xl_base.clip,
                    vae=self.xl_base.vae,
                    clip_vision=self.xl_base.clip_vision,
                    metadata={"rf_merge_data": str(merge_data)},
                )
        return
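    # Worked example of the weight normalization above (illustrative numbers):
    # with base weight 1.0, two merge models at 0.5 each and "normalize": 1.0,
    # weights = 1.0 + 0.5 + 0.5 = 2.0 and norm = 1.0 / 2.0, so the base
    # contributes 50% and each merge model 25% of the final unet.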
    def load_base_model(self, name, unet_only=False, input_unet=None):
        if self.xl_base_hash == name and self.xl_base_patched_extra == set():
            return

        filename = shared.models.get_file("checkpoints", name)

        # If we don't have a filename, fall back to the default model.
        if filename is None:
            base_model = settings.default_settings.get(
                "base_model", "sd_xl_base_1.0_0.9vae.safetensors"
            )
            filename = path_manager.get_folder_file_path(
                "checkpoints",
                base_model,
            )

        if Path(filename).suffix == ".merge":
            self.merge_models(name)
            return

        if input_unet is None:
            # Be quiet if we already loaded a unet.
            print(f"Loading base {'unet' if unet_only else 'model'}: {name}")

        self.xl_base = None
        self.xl_base_hash = ""
        self.xl_base_patched = None
        self.xl_base_patched_hash = ""
        self.xl_base_patched_extra = set()
        self.conditions = None

        gc.collect(generation=2)
        comfy.model_management.cleanup_models()
        comfy.model_management.soft_empty_cache()

        unet = None
        filename = str(filename)  # FIXME: use Path and suffix instead?

        if filename.endswith(".gguf") or unet_only:
            with torch.inference_mode():
                try:
                    if input_unet is not None:
                        if isinstance(input_unet, ModelPatcher):
                            unet = input_unet
                        else:
                            unet = comfy.sd.load_diffusion_model_state_dict(
                                input_unet,
                                model_options={"custom_operations": self.ggml_ops},
                            )
                        unet = GGUFModelPatcher.clone(unet)
                        unet.patch_on_device = True
                    elif filename.endswith(".gguf"):
                        sd = load_gguf_sd(filename)
                        unet = comfy.sd.load_diffusion_model_state_dict(
                            sd, model_options={"custom_operations": self.ggml_ops}
                        )
                        unet = GGUFModelPatcher.clone(unet)
                        unet.patch_on_device = True
                    else:
                        model_options = {}
                        model_options["dtype"] = torch.float8_e4m3fn  # FIXME: should be a setting
                        unet = comfy.sd.load_diffusion_model(filename, model_options=model_options)

                    # Pick text encoders (clip) and a VAE that match the unet.
                    clip_names = []
                    if isinstance(unet.model, Flux):
                        clip_names.append(self.get_clip_name("clip_l"))
                        clip_names.append(self.get_clip_name("clip_t5"))
                        clip_type = comfy.sd.CLIPType.FLUX
                        vae_name = settings.default_settings.get("vae_flux", "ae.safetensors")
                    elif isinstance(unet.model, SD3):
                        clip_names.append(self.get_clip_name("clip_l"))
                        clip_names.append(self.get_clip_name("clip_g"))
                        clip_names.append(self.get_clip_name("clip_t5"))
                        clip_type = comfy.sd.CLIPType.SD3
                        vae_name = settings.default_settings.get("vae_sd3", "sd3_vae.safetensors")
                    elif isinstance(unet.model, Lumina2):
                        clip_names.append(self.get_clip_name("clip_gemma"))
                        clip_type = comfy.sd.CLIPType.LUMINA2
                        vae_name = settings.default_settings.get(
                            "vae_lumina2", "lumina2_vae_fp32.safetensors"
                        )
                        unet = ModelSamplingAuraFlow().patch_aura(
                            model=unet,
                            shift=settings.default_settings.get("lumina2_shift", 3.0),
                        )[0]
                    else:  # SDXL
                        clip_names.append(self.get_clip_name("clip_l"))
                        clip_names.append(self.get_clip_name("clip_g"))
                        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
                        vae_name = settings.default_settings.get("vae_sdxl", "sdxl_vae.safetensors")

                    clip_paths = []
                    for clip_name in clip_names:
                        clip_paths.append(
                            str(
                                path_manager.get_folder_file_path(
                                    "clip",
                                    clip_name,
                                    default=os.path.join(
                                        path_manager.model_paths["clip_path"], clip_name
                                    ),
                                )
                            )
                        )
                    clip_loader = DualCLIPLoaderGGUF()
                    print(f"Loading CLIP: {clip_names}")
                    clip = clip_loader.load_patcher(
                        clip_paths, clip_type, clip_loader.load_data(clip_paths)
                    )

                    vae_path = path_manager.get_folder_file_path(
                        "vae",
                        vae_name,
                        default=os.path.join(path_manager.model_paths["vae_path"], vae_name),
                    )
                    print(f"Loading VAE: {vae_name}")
                    sd = comfy.utils.load_torch_file(str(vae_path))
                    vae = comfy.sd.VAE(sd=sd)

                    clip_vision = None
                except Exception as e:
                    unet = None
                    traceback.print_exc()
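        # Text encoder / VAE selection above, summarized (defaults come from
        # get_clip_name() and settings, and are overridable there):
        #   Flux    -> clip_l + t5-xxl            vae_flux    (ae.safetensors)
        #   SD3     -> clip_l + clip_g + t5-xxl   vae_sd3     (sd3_vae.safetensors)
        #   Lumina2 -> gemma 2 2b                 vae_lumina2 (lumina2_vae_fp32.safetensors)
        #   SDXL    -> clip_l + clip_g            vae_sdxl    (sdxl_vae.safetensors)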
        else:
            sd = None
            unet = None
            try:
                with torch.inference_mode():
                    sd = comfy.utils.load_torch_file(filename)
            except Exception as e:
                # Failed loading
                print(f"ERROR: Failed loading {filename}: {e}")
            if sd is not None:
                aio = load_state_dict_guess_config(sd)
                if isinstance(aio, tuple):
                    unet, clip, vae, clip_vision = aio
                    if (
                        isinstance(unet, ModelPatcher)
                        and isinstance(clip, CLIP)
                        and isinstance(vae, VAE)
                    ):
                        # If we got here, we have all models. Dump sd since we don't need it.
                        sd = None
                    else:
                        if isinstance(unet, ModelPatcher):
                            sd = unet
                if sd is not None:
                    # We got something, assume it was a unet.
                    self.load_base_model(
                        filename,
                        unet_only=True,
                        input_unet=sd,
                    )
                    return
            else:
                unet = None

        if unet is None:
            print(f"Failed to load {name}")
            self.xl_base = None
            self.xl_base_hash = ""
            self.xl_base_patched = None
            self.xl_base_patched_hash = ""
        else:
            self.xl_base = self.StableDiffusionModel(
                unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
            )
            if not isinstance(
                self.xl_base.unet.model, (BaseModel, SDXL, SD3, Flux, Lumina2)
            ):
                print(
                    f"Model {type(self.xl_base.unet.model)} not supported. "
                    f"RuinedFooocus only supports SD1.x/SDXL/SD3/Flux/Lumina2 models as the base model."
                )
                self.xl_base = None

            if self.xl_base is not None:
                self.xl_base_hash = name
                self.xl_base_patched = self.xl_base
                self.xl_base_patched_hash = ""
                # self.xl_base_patched.unet.model.to("cuda")
                # print(f"Base model loaded: {self.xl_base_hash}")

        return

    def freeu(self, model, b1, b2, s1, s2):
        freeu_model = FreeU()
        unet = freeu_model.patch(model=model.unet, b1=b1, b2=b2, s1=s1, s2=s2)[0]
        return self.StableDiffusionModel(
            unet=unet, clip=model.clip, vae=model.vae, clip_vision=model.clip_vision
        )

    def load_loras(self, loras):
        loaded_loras = []

        model = self.xl_base
        for name, weight in loras:
            if name == "None" or weight == 0:
                continue
            filename = str(shared.models.get_file("loras", name))
            print(f"Loading LoRA: {name}")
            try:
                lora = comfy.utils.load_torch_file(filename, safe_load=True)
                unet, clip = comfy.sd.load_lora_for_models(
                    model.unet, model.clip, lora, weight, weight
                )
                model = self.StableDiffusionModel(
                    unet=unet,
                    clip=clip,
                    vae=model.vae,
                    clip_vision=model.clip_vision,
                )
                loaded_loras += [(name, weight)]
            except:
                print(f"Error loading LoRA: {filename}")
        self.xl_base_patched = model
        # Uncomment below to enable FreeU
        # self.xl_base_patched = self.freeu(model, 1.01, 1.02, 0.99, 0.95)
        # self.xl_base_patched_hash = str(loras + [1.01, 1.02, 0.99, 0.95])
        self.xl_base_patched_hash = str(loras)
        print(f"LoRAs loaded: {loaded_loras}")

        return

    def refresh_controlnet(self, name=None):
        if self.xl_controlnet_hash == str(self.xl_controlnet):
            return

        filename = modules.controlnet.get_model(name)

        if filename is not None and self.xl_controlnet_hash != name:
            self.xl_controlnet = comfy.controlnet.load_controlnet(str(filename))
            self.xl_controlnet_hash = name
            print(f"ControlNet model loaded: {self.xl_controlnet_hash}")

        if self.xl_controlnet_hash != name:
            self.xl_controlnet = None
            self.xl_controlnet_hash = None
            print("ControlNet model unloaded")

    conditions = None

    def textencode(self, id, text, clip_skip):
        update = False
        hash = f"{text} {clip_skip}"
        if hash != self.conditions[id]["text"]:
            if clip_skip > 1:
                self.xl_base_patched.clip = CLIPSetLastLayer().set_last_layer(
                    self.xl_base_patched.clip, clip_skip * -1
                )[0]
            self.conditions[id]["cache"] = CLIPTextEncode().encode(
                clip=self.xl_base_patched.clip, text=text
            )[0]
            self.conditions[id]["text"] = hash
            update = True
        return update
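    # Conditioning cache sketch: textencode() keys each slot ("+", "-",
    # "switch") on the string f"{text} {clip_skip}", so re-running with an
    # unchanged prompt and clip-skip reuses the cached CLIP encoding instead
    # of re-encoding. For example, clip_skip=2 patches the clip with
    # set_last_layer(-2) before encoding.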
    @torch.inference_mode()
    def process(
        self,
        gen_data=None,
        callback=None,
    ):
        try:
            if self.xl_base_patched is None or not isinstance(
                self.xl_base_patched.unet.model, (BaseModel, SDXL, SD3, Flux, Lumina2)
            ):
                print("ERROR: Can only use SD1.x, SDXL, SD3, Flux or Lumina2 models")
                worker.interrupt_ruined_processing = True
                if callback is not None:
                    worker.add_result(
                        gen_data["task_id"],
                        "preview",
                        (
                            -1,
                            "Can only use SD1.x, SDXL, SD3, Flux or Lumina2 models ...",
                            "html/error.png",
                        ),
                    )
                return []
        except Exception as e:
            # Something went very wrong
            print(f"ERROR: {e}")
            worker.interrupt_ruined_processing = True
            if callback is not None:
                worker.add_result(
                    gen_data["task_id"],
                    "preview",
                    (-1, "Error when trying to use model ...", "html/error.png"),
                )
            return []

        positive_prompt = gen_data["positive_prompt"]
        negative_prompt = gen_data["negative_prompt"]
        input_image = gen_data["input_image"]
        controlnet = modules.controlnet.get_settings(gen_data)
        cfg = gen_data["cfg"]
        sampler_name = gen_data["sampler_name"]
        scheduler = gen_data["scheduler"]
        clip_skip = gen_data["clip_skip"]

        img2img_mode = False
        input_image_pil = None

        seed = (
            gen_data["seed"]
            if isinstance(gen_data["seed"], int)
            else random.randint(1, 2**32)
        )

        if callback is not None:
            worker.add_result(
                gen_data["task_id"], "preview", (-1, "Processing text encoding ...", None)
            )
        updated_conditions = False
        if self.conditions is None:
            self.conditions = clean_prompt_cond_caches()
        if self.textencode("+", positive_prompt, clip_skip):
            updated_conditions = True
        if self.textencode("-", negative_prompt, clip_skip):
            updated_conditions = True

        switched_prompt = []
        if "[" in positive_prompt and "]" in positive_prompt:
            if controlnet is not None and input_image is not None:
                print("ControlNet and [prompt|switching] do not work well together.")
                print("ControlNet will only be applied to the first prompt.")
            prompt_per_step = pp.prompt_switch_per_step(positive_prompt, gen_data["steps"])
            perc_per_step = round(100 / gen_data["steps"], 2)
            for i in range(len(prompt_per_step)):
                if self.textencode("switch", prompt_per_step[i], clip_skip):
                    updated_conditions = True
                positive_switch = self.conditions["switch"]["cache"]
                start_perc = round((perc_per_step * i) / 100, 2)
                end_perc = round((perc_per_step * (i + 1)) / 100, 2)
                if end_perc >= 0.99:
                    end_perc = 1
                positive_switch = set_timestep_range(positive_switch, start_perc, end_perc)
                switched_prompt += positive_switch

        device = comfy.model_management.get_torch_device()
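        # Worked example of the [prompt|switching] schedule above: with
        # steps=20, perc_per_step is 5.0, so the conditioning built for step i
        # is restricted to the timestep range [0.05*i, 0.05*(i+1)]; the last
        # slice is clamped to end at 1.0 once it reaches 0.99. (How
        # pp.prompt_switch_per_step() expands e.g. "a [cat|dog]" into one
        # prompt per step is defined in modules.prompt_processing.)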
        if controlnet is not None and "type" in controlnet and input_image is not None:
            if callback is not None:
                worker.add_result(
                    gen_data["task_id"], "preview", (-1, "Powering up ...", None)
                )
            input_image_pil = input_image.convert("RGB")
            input_image = np.array(input_image_pil).astype(np.float32) / 255.0
            input_image = torch.from_numpy(input_image)[None,]
            input_image = ImageScaleToTotalPixels().upscale(
                image=input_image, upscale_method="bicubic", megapixels=1.0
            )[0]
            self.refresh_controlnet(name=controlnet["type"])

            match controlnet["type"].lower():
                case "canny":
                    input_image = Canny().detect_edge(
                        image=input_image,
                        low_threshold=float(controlnet["edge_low"]),
                        high_threshold=float(controlnet["edge_high"]),
                    )[0]
                    updated_conditions = True
                case "depth":
                    updated_conditions = True

            if self.xl_controlnet:
                (
                    self.conditions["+"]["cache"],
                    self.conditions["-"]["cache"],
                ) = ControlNetApplyAdvanced().apply_controlnet(
                    positive=self.conditions["+"]["cache"],
                    negative=self.conditions["-"]["cache"],
                    control_net=self.xl_controlnet,
                    image=input_image,
                    strength=float(controlnet["strength"]),
                    start_percent=float(controlnet["start"]),
                    end_percent=float(controlnet["stop"]),
                )
                self.conditions["+"]["text"] = None
                self.conditions["-"]["text"] = None

            if controlnet["type"].lower() == "img2img":
                latent = VAEEncode().encode(
                    vae=self.xl_base_patched.vae, pixels=input_image
                )[0]
                force_full_denoise = False
                denoise = float(controlnet.get("denoise", controlnet.get("strength")))
                img2img_mode = True

        if not img2img_mode:
            if isinstance(self.xl_base.unet.model, (SD3, Flux, Lumina2)):
                latent = EmptySD3LatentImage().generate(
                    width=gen_data["width"], height=gen_data["height"], batch_size=1
                )[0]
            else:  # SDXL and unknown
                latent = EmptyLatentImage().generate(
                    width=gen_data["width"], height=gen_data["height"], batch_size=1
                )[0]
            force_full_denoise = False
            denoise = None

        if "inpaint_toggle" in gen_data and gen_data["inpaint_toggle"]:
            # This is a _very_ ugly workaround since we had to shrink the inpaint
            # image to not break the ui.
            main_image = Image.open(gen_data["main_view"])
            image = np.asarray(main_image)
            # image = image[..., :-1]
            image = torch.from_numpy(image)[None,] / 255.0

            inpaint_view = Image.fromarray(gen_data["inpaint_view"]["layers"][0])
            red, green, blue, mask = inpaint_view.split()
            mask = mask.resize(
                (main_image.width, main_image.height), Image.Resampling.LANCZOS
            )
            mask = np.asarray(mask)
            # mask = mask[:, :, 0]
            mask = torch.from_numpy(mask)[None,] / 255.0

            latent = VAEEncodeForInpaint().encode(
                vae=self.xl_base_patched.vae,
                pixels=image,
                mask=mask,
                grow_mask_by=20,
            )[0]

        latent_image = latent["samples"]
        batch_inds = latent["batch_index"] if "batch_index" in latent else None
        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

        noise_mask = None
        if "noise_mask" in latent:
            noise_mask = latent["noise_mask"]

        previewer = get_previewer(device, self.xl_base_patched.unet.model.latent_format)

        pbar = comfy.utils.ProgressBar(gen_data["steps"])

        def callback_function(step, x0, x, total_steps):
            y = None
            if previewer:
                y = previewer.preview(x0, step, total_steps)
            if callback is not None:
                callback(step, x0, x, total_steps, y)
            pbar.update_absolute(step + 1, total_steps, None)

        if noise_mask is not None:
            noise_mask = prepare_mask(noise_mask, noise.shape, device)

        if callback is not None:
            worker.add_result(
                gen_data["task_id"], "preview", (-1, "Prepare models ...", None)
            )
        if updated_conditions:
            conds = {
                0: self.conditions["+"]["cache"],
                1: self.conditions["-"]["cache"],
            }
            self.models, self.inference_memory = get_additional_models(
                conds,
                self.xl_base_patched.unet.model_dtype(),
            )
        comfy.model_management.load_models_gpu([self.xl_base_patched.unet])
        comfy.model_management.load_models_gpu(self.models)

        noise = noise.to(device)
        latent_image = latent_image.to(device)
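        # Flux models are guidance-distilled: instead of classifier-free
        # guidance, the requested cfg value is written into the positive
        # conditioning as "guidance" and the sampler's cfg is forced to 1.0.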
        # Use FluxGuidance for Flux
        positive_cond = switched_prompt if switched_prompt else self.conditions["+"]["cache"]
        if isinstance(self.xl_base.unet.model, Flux):
            positive_cond = conditioning_set_values(positive_cond, {"guidance": cfg})
            cfg = 1.0

        kwargs = {
            "cfg": cfg,
            "latent_image": latent_image,
            "start_step": 0,
            "last_step": gen_data["steps"],
            "force_full_denoise": force_full_denoise,
            "denoise_mask": noise_mask,
            "sigmas": None,
            "disable_pbar": False,
            "seed": seed,
            "callback": callback_function,
        }

        sampler = KSampler(
            self.xl_base_patched.unet,
            steps=gen_data["steps"],
            device=device,
            sampler=sampler_name,
            scheduler=scheduler,
            denoise=denoise,
            model_options=self.xl_base_patched.unet.model_options,
        )

        if callback is not None:
            worker.add_result(
                gen_data["task_id"], "preview", (-1, "Start sampling ...", None)
            )
        samples = sampler.sample(
            noise,
            positive_cond,
            self.conditions["-"]["cache"],
            **kwargs,
        )

        cleanup_additional_models(self.models)

        sampled_latent = latent.copy()
        sampled_latent["samples"] = samples

        if callback is not None:
            worker.add_result(
                gen_data["task_id"], "preview", (-1, "VAE decoding ...", None)
            )

        decoded_latent = VAEDecode().decode(
            samples=sampled_latent, vae=self.xl_base_patched.vae
        )[0]

        images = [
            np.clip(255.0 * y.cpu().numpy(), 0, 255).astype(np.uint8)
            for y in decoded_latent
        ]

        if callback is not None:
            callback(gen_data["steps"], 0, 0, gen_data["steps"], images[0])

        return images
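
# Minimal usage sketch (illustrative only: the gen_data keys are the ones
# process() reads above, the values are made up, and the worker/callback
# wiring plus any keys modules.controlnet.get_settings() expects are omitted):
#
#   p = pipeline()
#   p.load_base_model("sd_xl_base_1.0_0.9vae.safetensors")
#   images = p.process(gen_data={
#       "task_id": 0,
#       "positive_prompt": "a photo of a cat",
#       "negative_prompt": "",
#       "input_image": None,
#       "cfg": 7.0, "sampler_name": "dpmpp_2m", "scheduler": "karras",
#       "clip_skip": 1, "seed": 42, "steps": 20,
#       "width": 1024, "height": 1024,
#   })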