import gc
import numpy as np
import os
import torch
import traceback
import sys

import modules.controlnet
import modules.async_worker as worker
import modules.prompt_processing as pp

from PIL import Image, ImageOps
from comfy.model_base import BaseModel, SDXL, SD3, Flux, Lumina2
from shared import path_manager, settings
import shared
from pathlib import Path
import json
import random

import comfy.utils
import comfy.model_management
import comfy.sample
import comfy.controlnet
from comfy.sd import load_checkpoint_guess_config, load_state_dict_guess_config
from tqdm import tqdm
from comfy_extras.nodes_model_advanced import ModelSamplingAuraFlow

from nodes import (
    CLIPTextEncode,
    CLIPSetLastLayer,
    ControlNetApplyAdvanced,
    EmptyLatentImage,
    VAEDecode,
    VAEEncode,
    VAEEncodeForInpaint,
    CLIPLoader,
    VAELoader,
)
from comfy.sampler_helpers import (
    cleanup_additional_models,
    convert_cond,
    get_additional_models,
    prepare_mask,
)
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from node_helpers import conditioning_set_values
from comfy.samplers import KSampler
from comfy_extras.nodes_post_processing import ImageScaleToTotalPixels
from comfy_extras.nodes_canny import Canny
from comfy_extras.nodes_freelunch import FreeU
from comfy.model_patcher import ModelPatcher
from comfy.sd import CLIP, VAE
from comfy.utils import load_torch_file
from comfy.sd import save_checkpoint

from modules.pipleline_utils import (
    get_previewer,
    clean_prompt_cond_caches,
    set_timestep_range,
)

# from comfyui_gguf.nodes import gguf_sd_loader, DualCLIPLoaderGGUF, GGUFModelPatcher
# from comfyui_gguf.ops import GGMLOps
from calcuis_gguf.pig import load_gguf_sd, GGMLOps, GGUFModelPatcher
from calcuis_gguf.pig import DualClipLoaderGGUF as DualCLIPLoaderGGUF


class pipeline:
    pipeline_type = ["sdxl", "ssd", "sd3", "flux", "lumina2"]
    comfy.model_management.DISABLE_SMART_MEMORY = False
    comfy.model_management.EXTRA_RESERVED_VRAM = 800 * 1024 * 1024  # 800 MiB

    class StableDiffusionModel:
        def __init__(self, unet, vae, clip, clip_vision):
            self.unet = unet
            self.vae = vae
            self.clip = clip
            self.clip_vision = clip_vision

        def to_meta(self):
            if self.unet is not None:
                self.unet.model.to("meta")
            if self.clip is not None:
                self.clip.cond_stage_model.to("meta")
            if self.vae is not None:
                self.vae.first_stage_model.to("meta")

    xl_base: StableDiffusionModel = None
    xl_base_hash = ""

    xl_base_patched: StableDiffusionModel = None
    xl_base_patched_hash = ""
    xl_base_patched_extra = set()

    xl_controlnet: StableDiffusionModel = None
    xl_controlnet_hash = ""

    models = []
    inference_memory = None

    ggml_ops = GGMLOps()

    def get_clip_name(self, shortname):
        # Short names and default filenames for the different text encoders.
        defaults = {
            "clip_g": "clip_g.safetensors",
            "clip_gemma": "gemma_2_2b_fp16.safetensors",
            "clip_l": "clip_l.safetensors",
            "clip_t5": "t5-v1_1-xxl-encoder-Q3_K_S.gguf",
        }
        return settings.default_settings.get(shortname, defaults.get(shortname))

    # FIXME: move this to a separate file
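    # A .merge file is a JSON description of a checkpoint merge. Illustrative
    # example only; the field names are the ones read below, the values are
    # made up:
    #
    #   {
    #       "comment": "my mix",
    #       "base": {"name": "sd_xl_base_1.0.safetensors", "weight": 1.0},
    #       "models": [{"name": "other_model.safetensors", "weight": 0.5}],
    #       "loras": [{"name": "some_lora.safetensors", "weight": 0.7}],
    #       "normalize": 1.0,
    #       "cache": true
    #   }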
    def merge_models(self, name):
        print(f"Loading merge: {name}")
        self.xl_base_patched = None
        self.xl_base_patched_hash = ""
        self.xl_base_patched_extra = set()
        self.conditions = None

        filename = shared.models.get_file("checkpoints", name)
        cache_name = str(
            Path(
                path_manager.model_paths["cache_path"] / "merges" / Path(name).name
            ).with_suffix(".safetensors")
        )
        if (
            Path(cache_name).exists()
            and Path(cache_name).stat().st_mtime >= Path(filename).stat().st_mtime
        ):
            print("Loading cached version.")
            self.load_base_model(cache_name)
            return

        try:
            with filename.open() as f:
                merge_data = json.load(f)
            if "comment" in merge_data:
                print(f"  {merge_data['comment']}")
            filename = shared.models.get_file("checkpoints", merge_data["base"]["name"])

            norm = 1.0
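            # Scale every weight so that base + merged model weights sum to
            # merge_data["normalize"] (or to 1.0 when no target is given).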
if "models" in merge_data and len(merge_data["models"]) > 0: | |
weights = sum([merge_data["base"]["weight"]] + [x.get("weight") for x in merge_data["models"]]) | |
if "normalize" in merge_data: | |
norm = float(merge_data["normalize"]) / weights | |
else: | |
norm = 1.0 / weights | |
print(f"Loading base {merge_data['base']['name']} ({round(merge_data['base']['weight'] * norm * 100)}%)") | |
with torch.torch.inference_mode(): | |
unet, clip, vae, clip_vision = load_checkpoint_guess_config(str(filename)) | |
self.xl_base = self.StableDiffusionModel( | |
unet=unet, clip=clip, vae=vae, clip_vision=clip_vision | |
) | |
if self.xl_base is not None: | |
self.xl_base_hash = name | |
self.xl_base_patched = self.xl_base | |
self.xl_base_patched_hash = "" | |
except Exception as e: | |
self.xl_base = None | |
print(f"ERROR: {e}") | |
return | |
if "models" in merge_data and len(merge_data["models"]) > 0: | |
device = comfy.model_management.get_torch_device() | |
mp = ModelPatcher(self.xl_base_patched.unet, device, "cpu", size=1) | |
w = float(merge_data["base"]["weight"]) * norm | |
for m in merge_data["models"]: | |
print(f"Merging {m['name']} ({round(m['weight'] * norm * 100)}%)") | |
filename = str(shared.models.get_file("checkpoints", m["name"])) | |
# FIXME add error check?` | |
with torch.torch.inference_mode(): | |
m_unet, m_clip, m_vae, m_clip_vision = load_checkpoint_guess_config(str(filename)) | |
del m_clip | |
del m_vae | |
del m_clip_vision | |
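                    # add_patches stores per-key deltas: strength_patch scales
                    # the incoming model's tensors, strength_model scales what
                    # is already in the merge. The base weight w is applied on
                    # the first pass only, then reset to 1.0 below.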
                    kp = m_unet.get_key_patches("diffusion_model.")
                    for k in kp:
                        mp.model.add_patches(
                            {k: kp[k]},
                            strength_patch=float(m["weight"] * norm),
                            strength_model=w,
                        )
                    del m_unet
                    w = 1.0

            self.xl_base = self.StableDiffusionModel(
                unet=mp.model, clip=clip, vae=vae, clip_vision=clip_vision
            )

        if "loras" in merge_data and len(merge_data["loras"]) > 0:
            loras = [(x.get("name"), x.get("weight")) for x in merge_data["loras"]]
            self.load_loras(loras)
            self.xl_base = self.xl_base_patched

        if merge_data.get("cache"):
            filename = str(
                Path(
                    path_manager.model_paths["cache_path"] / "merges" / Path(name).name
                ).with_suffix(".safetensors")
            )
            print(f"Saving merged model: {filename}")
            with torch.inference_mode():
                save_checkpoint(
                    filename,
                    self.xl_base.unet,
                    clip=self.xl_base.clip,
                    vae=self.xl_base.vae,
                    clip_vision=self.xl_base.clip_vision,
                    metadata={"rf_merge_data": str(merge_data)},
                )

        return

    def load_base_model(self, name, unet_only=False, input_unet=None):
        if self.xl_base_hash == name and self.xl_base_patched_extra == set():
            return

        filename = shared.models.get_file("checkpoints", name)
        # If we don't have a filename, fall back to the default model.
        if filename is None:
            base_model = settings.default_settings.get(
                "base_model", "sd_xl_base_1.0_0.9vae.safetensors"
            )
            filename = path_manager.get_folder_file_path(
                "checkpoints",
                base_model,
            )

        if Path(filename).suffix == ".merge":
            self.merge_models(name)
            return

        if input_unet is None:  # Be quiet if we already loaded a unet
            print(f"Loading base {'unet' if unet_only else 'model'}: {name}")

        self.xl_base = None
        self.xl_base_hash = ""
        self.xl_base_patched = None
        self.xl_base_patched_hash = ""
        self.xl_base_patched_extra = set()
        self.conditions = None

        gc.collect(generation=2)
        comfy.model_management.cleanup_models()
        comfy.model_management.soft_empty_cache()

        unet = None
        filename = str(filename)  # FIXME: use Path and suffix instead?
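        # GGUF files and bare unets carry no text encoder or VAE, so this
        # branch loads a matching CLIP and VAE separately further down.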
if filename.endswith(".gguf") or unet_only: | |
with torch.torch.inference_mode(): | |
try: | |
if input_unet is not None: | |
if isinstance(input_unet, ModelPatcher): | |
unet = input_unet | |
else: | |
unet = comfy.sd.load_diffusion_model_state_dict( | |
input_unet, model_options={"custom_operations": self.ggml_ops} | |
) | |
unet = GGUFModelPatcher.clone(unet) | |
unet.patch_on_device = True | |
elif filename.endswith(".gguf"): | |
sd = load_gguf_sd(filename) | |
unet = comfy.sd.load_diffusion_model_state_dict( | |
sd, model_options={"custom_operations": self.ggml_ops} | |
) | |
unet = GGUFModelPatcher.clone(unet) | |
unet.patch_on_device = True | |
else: | |
model_options = {} | |
model_options["dtype"] = torch.float8_e4m3fn # FIXME should be a setting | |
unet = comfy.sd.load_diffusion_model(filename, model_options=model_options) | |
# Get text encoders (clip) and vae to match the unet | |
clip_names = [] | |
if isinstance(unet.model, Flux): | |
clip_names.append(self.get_clip_name("clip_l")) | |
clip_names.append(self.get_clip_name("clip_t5")) | |
clip_type = comfy.sd.CLIPType.FLUX | |
vae_name = settings.default_settings.get("vae_flux", "ae.safetensors") | |
elif isinstance(unet.model, SD3): | |
clip_names.append(self.get_clip_name("clip_l")) | |
clip_names.append(self.get_clip_name("clip_g")) | |
clip_names.append(self.get_clip_name("clip_t5")) | |
clip_type = comfy.sd.CLIPType.SD3 | |
vae_name = settings.default_settings.get("vae_sd3", "sd3_vae.safetensors") | |
elif isinstance(unet.model, Lumina2): | |
clip_names.append(self.get_clip_name("clip_gemma")) | |
clip_type = comfy.sd.CLIPType.LUMINA2 | |
vae_name = settings.default_settings.get("vae_lumina2", "lumina2_vae_fp32.safetensors") | |
unet = ModelSamplingAuraFlow().patch_aura( | |
model=unet, | |
shift=settings.default_settings.get("lumina2_shift", 3.0), | |
)[0] | |
else: # SDXL | |
clip_names.append(self.get_clip_name("clip_l")) | |
clip_names.append(self.get_clip_name("clip_g")) | |
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION | |
vae_name = settings.default_settings.get("vae_sdxl", "sdxl_vae.safetensors") | |
clip_paths = [] | |
for clip_name in clip_names: | |
clip_paths.append( | |
str( | |
path_manager.get_folder_file_path( | |
"clip", | |
clip_name, | |
default = os.path.join(path_manager.model_paths["clip_path"], clip_name) | |
) | |
) | |
) | |
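                    # The GGUF dual-CLIP loader is used here with a mix of
                    # .safetensors and .gguf encoder files (see defaults above).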
                    clip_loader = DualCLIPLoaderGGUF()
                    print(f"Loading CLIP: {clip_names}")
                    clip = clip_loader.load_patcher(
                        clip_paths, clip_type, clip_loader.load_data(clip_paths)
                    )

                    vae_path = path_manager.get_folder_file_path(
                        "vae",
                        vae_name,
                        default=os.path.join(path_manager.model_paths["vae_path"], vae_name),
                    )
                    print(f"Loading VAE: {vae_name}")
                    sd = comfy.utils.load_torch_file(str(vae_path))
                    vae = comfy.sd.VAE(sd=sd)

                    clip_vision = None
                except Exception:
                    unet = None
                    traceback.print_exc()
        else:
            sd = None
            unet = None
            try:
                with torch.inference_mode():
                    sd = comfy.utils.load_torch_file(filename)
            except Exception as e:
                print(f"ERROR: Failed loading {filename}: {e}")
            if sd is not None:
                aio = load_state_dict_guess_config(sd)
                if isinstance(aio, tuple):
                    unet, clip, vae, clip_vision = aio
                    if (
                        isinstance(unet, ModelPatcher)
                        and isinstance(clip, CLIP)
                        and isinstance(vae, VAE)
                    ):
                        # We have all the models. Drop sd since we don't need it.
                        sd = None
                    else:
                        if isinstance(unet, ModelPatcher):
                            sd = unet
                        if sd is not None:
                            # We got something; assume it was a unet.
                            self.load_base_model(
                                filename,
                                unet_only=True,
                                input_unet=sd,
                            )
                            return
                else:
                    unet = None
        if unet is None:
            print(f"Failed to load {name}")
            self.xl_base = None
            self.xl_base_hash = ""
            self.xl_base_patched = None
            self.xl_base_patched_hash = ""
        else:
            self.xl_base = self.StableDiffusionModel(
                unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
            )
            if not isinstance(
                self.xl_base.unet.model, (BaseModel, SDXL, SD3, Flux, Lumina2)
            ):
                print(
                    f"Model {type(self.xl_base.unet.model)} not supported. "
                    f"RuinedFooocus only supports SD1.x/SDXL/SD3/Flux/Lumina2 models as the base model."
                )
                self.xl_base = None

            if self.xl_base is not None:
                self.xl_base_hash = name
                self.xl_base_patched = self.xl_base
                self.xl_base_patched_hash = ""
                # self.xl_base_patched.unet.model.to("cuda")
                # print(f"Base model loaded: {self.xl_base_hash}")

        return
    def freeu(self, model, b1, b2, s1, s2):
        freeu_model = FreeU()
        unet = freeu_model.patch(model=model.unet, b1=b1, b2=b2, s1=s1, s2=s2)[0]
        return self.StableDiffusionModel(
            unet=unet, clip=model.clip, vae=model.vae, clip_vision=model.clip_vision
        )

    def load_loras(self, loras):
        loaded_loras = []

        model = self.xl_base
        for name, weight in loras:
            if name == "None" or weight == 0:
                continue
            filename = str(shared.models.get_file("loras", name))
            print(f"Loading LoRA: {name}")
            try:
                lora = comfy.utils.load_torch_file(filename, safe_load=True)
                unet, clip = comfy.sd.load_lora_for_models(
                    model.unet, model.clip, lora, weight, weight
                )
                model = self.StableDiffusionModel(
                    unet=unet,
                    clip=clip,
                    vae=model.vae,
                    clip_vision=model.clip_vision,
                )
                loaded_loras += [(name, weight)]
            except Exception:
                print(f"Error loading LoRA: {filename}")

        self.xl_base_patched = model
        # Uncomment below to enable FreeU
        # self.xl_base_patched = self.freeu(model, 1.01, 1.02, 0.99, 0.95)
        # self.xl_base_patched_hash = str(loras + [1.01, 1.02, 0.99, 0.95])
        self.xl_base_patched_hash = str(loras)

        print(f"LoRAs loaded: {loaded_loras}")

        return

    def refresh_controlnet(self, name=None):
        if self.xl_controlnet_hash == name:
            return

        filename = modules.controlnet.get_model(name)
        if filename is not None and self.xl_controlnet_hash != name:
            self.xl_controlnet = comfy.controlnet.load_controlnet(str(filename))
            self.xl_controlnet_hash = name
            print(f"ControlNet model loaded: {self.xl_controlnet_hash}")
        if self.xl_controlnet_hash != name:
            self.xl_controlnet = None
            self.xl_controlnet_hash = None
            print("ControlNet model unloaded")

    conditions = None
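
    # Encode text with the patched CLIP, caching the result keyed on
    # (text, clip_skip) so unchanged prompts are not re-encoded. Returns
    # True when the cached conditioning was updated.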
    def textencode(self, id, text, clip_skip):
        update = False
        text_hash = f"{text} {clip_skip}"
        if text_hash != self.conditions[id]["text"]:
            if clip_skip > 1:
                self.xl_base_patched.clip = CLIPSetLastLayer().set_last_layer(
                    self.xl_base_patched.clip, clip_skip * -1
                )[0]
            self.conditions[id]["cache"] = CLIPTextEncode().encode(
                clip=self.xl_base_patched.clip, text=text
            )[0]
            self.conditions[id]["text"] = text_hash
            update = True
        return update

    def process(
        self,
        gen_data=None,
        callback=None,
    ):
        try:
            if self.xl_base_patched is None or not isinstance(
                self.xl_base_patched.unet.model, (BaseModel, SDXL, SD3, Flux, Lumina2)
            ):
                print("ERROR: Can only use SD1.x, SDXL, SD3, Flux or Lumina2 models")
                worker.interrupt_ruined_processing = True
                if callback is not None:
                    worker.add_result(
                        gen_data["task_id"],
                        "preview",
                        (
                            -1,
                            "Can only use SD1.x, SDXL, SD3, Flux or Lumina2 models ...",
                            "html/error.png",
                        ),
                    )
                return []
        except Exception as e:
            # Something went very wrong
            print(f"ERROR: {e}")
            worker.interrupt_ruined_processing = True
            if callback is not None:
                worker.add_result(
                    gen_data["task_id"],
                    "preview",
                    (-1, "Error when trying to use model ...", "html/error.png"),
                )
            return []

        positive_prompt = gen_data["positive_prompt"]
        negative_prompt = gen_data["negative_prompt"]
        input_image = gen_data["input_image"]
        controlnet = modules.controlnet.get_settings(gen_data)
        cfg = gen_data["cfg"]
        sampler_name = gen_data["sampler_name"]
        scheduler = gen_data["scheduler"]
        clip_skip = gen_data["clip_skip"]

        img2img_mode = False
        input_image_pil = None

        seed = (
            gen_data["seed"]
            if isinstance(gen_data["seed"], int)
            else random.randint(1, 2**32)
        )

        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "Processing text encoding ...", None),
            )
        updated_conditions = False
        if self.conditions is None:
            self.conditions = clean_prompt_cond_caches()
        if self.textencode("+", positive_prompt, clip_skip):
            updated_conditions = True
        if self.textencode("-", negative_prompt, clip_skip):
            updated_conditions = True
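        # [prompt|switching] syntax: expand the prompt to one variant per step
        # and limit each encoded variant to its slice of the timestep range.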
        switched_prompt = []
        if "[" in positive_prompt and "]" in positive_prompt:
            if controlnet is not None and input_image is not None:
                print("ControlNet and [prompt|switching] do not work well together.")
                print("ControlNet will only be applied to the first prompt.")
            prompt_per_step = pp.prompt_switch_per_step(positive_prompt, gen_data["steps"])
            perc_per_step = round(100 / gen_data["steps"], 2)
            for i in range(len(prompt_per_step)):
                if self.textencode("switch", prompt_per_step[i], clip_skip):
                    updated_conditions = True
                positive_switch = self.conditions["switch"]["cache"]
                start_perc = round((perc_per_step * i) / 100, 2)
                end_perc = round((perc_per_step * (i + 1)) / 100, 2)
                if end_perc >= 0.99:
                    end_perc = 1
                positive_switch = set_timestep_range(positive_switch, start_perc, end_perc)
                switched_prompt += positive_switch

        device = comfy.model_management.get_torch_device()
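        # ControlNet preprocessing: normalize the input image to a float
        # tensor, rescale it to roughly one megapixel, then apply the chosen
        # detector before wiring the result into the conditioning.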
if controlnet is not None and "type" in controlnet and input_image is not None: | |
if callback is not None: | |
worker.add_result( | |
gen_data["task_id"], | |
"preview", | |
(-1, f"Powering up ...", None) | |
) | |
input_image_pil = input_image.convert("RGB") | |
input_image = np.array(input_image_pil).astype(np.float32) / 255.0 | |
input_image = torch.from_numpy(input_image)[None,] | |
input_image = ImageScaleToTotalPixels().upscale( | |
image=input_image, upscale_method="bicubic", megapixels=1.0 | |
)[0] | |
self.refresh_controlnet(name=controlnet["type"]) | |
match controlnet["type"].lower(): | |
case "canny": | |
input_image = Canny().detect_edge( | |
image=input_image, | |
low_threshold=float(controlnet["edge_low"]), | |
high_threshold=float(controlnet["edge_high"]), | |
)[0] | |
updated_conditions = True | |
case "depth": | |
updated_conditions = True | |
if self.xl_controlnet: | |
( | |
self.conditions["+"]["cache"], | |
self.conditions["-"]["cache"], | |
) = ControlNetApplyAdvanced().apply_controlnet( | |
positive=self.conditions["+"]["cache"], | |
negative=self.conditions["-"]["cache"], | |
control_net=self.xl_controlnet, | |
image=input_image, | |
strength=float(controlnet["strength"]), | |
start_percent=float(controlnet["start"]), | |
end_percent=float(controlnet["stop"]), | |
) | |
self.conditions["+"]["text"] = None | |
self.conditions["-"]["text"] = None | |
if controlnet["type"].lower() == "img2img": | |
latent = VAEEncode().encode( | |
vae=self.xl_base_patched.vae, pixels=input_image | |
)[0] | |
force_full_denoise = False | |
denoise = float(controlnet.get("denoise", controlnet.get("strength"))) | |
img2img_mode = True | |
if not img2img_mode: | |
if ( | |
isinstance(self.xl_base.unet.model, SD3) or | |
isinstance(self.xl_base.unet.model, Flux) or | |
isinstance(self.xl_base.unet.model, Lumina2) | |
): | |
latent = EmptySD3LatentImage().generate( | |
width=gen_data["width"], height=gen_data["height"], batch_size=1 | |
)[0] | |
else: # SDXL and unknown | |
latent = EmptyLatentImage().generate( | |
width=gen_data["width"], height=gen_data["height"], batch_size=1 | |
)[0] | |
force_full_denoise = False | |
denoise = None | |
if "inpaint_toggle" in gen_data and gen_data["inpaint_toggle"]: | |
# This is a _very_ ugly workaround since we had to shrink the inpaint image | |
# to not break the ui. | |
main_image = Image.open(gen_data["main_view"]) | |
image = np.asarray(main_image) | |
# image = image[..., :-1] | |
image = torch.from_numpy(image)[None,] / 255.0 | |
inpaint_view = Image.fromarray(gen_data["inpaint_view"]["layers"][0]) | |
red, green, blue, mask = inpaint_view.split() | |
mask = mask.resize((main_image.width, main_image.height), Image.Resampling.LANCZOS) | |
mask = np.asarray(mask) | |
# mask = mask[:, :, 0] | |
mask = torch.from_numpy(mask)[None,] / 255.0 | |
latent = VAEEncodeForInpaint().encode( | |
vae=self.xl_base_patched.vae, | |
pixels=image, | |
mask=mask, | |
grow_mask_by=20, | |
)[0] | |
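        # Sample the initial noise deterministically from the seed so runs are
        # reproducible, and pick up any inpaint mask from the latent.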
latent_image = latent["samples"] | |
batch_inds = latent["batch_index"] if "batch_index" in latent else None | |
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) | |
noise_mask = None | |
if "noise_mask" in latent: | |
noise_mask = latent["noise_mask"] | |
previewer = get_previewer(device, self.xl_base_patched.unet.model.latent_format) | |
pbar = comfy.utils.ProgressBar(gen_data["steps"]) | |
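        # Per-step sampler callback: render a low-resolution latent preview
        # and push progress to both the caller and the progress bar.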
        def callback_function(step, x0, x, total_steps):
            y = None
            if previewer:
                y = previewer.preview(x0, step, total_steps)
            if callback is not None:
                callback(step, x0, x, total_steps, y)
            pbar.update_absolute(step + 1, total_steps, None)

        if noise_mask is not None:
            noise_mask = prepare_mask(noise_mask, noise.shape, device)

        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "Preparing models ...", None),
            )
        if updated_conditions:
            conds = {
                0: self.conditions["+"]["cache"],
                1: self.conditions["-"]["cache"],
            }
            self.models, self.inference_memory = get_additional_models(
                conds,
                self.xl_base_patched.unet.model_dtype(),
            )
        comfy.model_management.load_models_gpu([self.xl_base_patched.unet])
        comfy.model_management.load_models_gpu(self.models)

        noise = noise.to(device)
        latent_image = latent_image.to(device)
        # Flux is guidance-distilled: pass the guidance value through the
        # conditioning and neutralize classifier-free guidance.
        positive_cond = (
            switched_prompt if switched_prompt else self.conditions["+"]["cache"]
        )
        if isinstance(self.xl_base.unet.model, Flux):
            positive_cond = conditioning_set_values(positive_cond, {"guidance": cfg})
            cfg = 1.0

        kwargs = {
            "cfg": cfg,
            "latent_image": latent_image,
            "start_step": 0,
            "last_step": gen_data["steps"],
            "force_full_denoise": force_full_denoise,
            "denoise_mask": noise_mask,
            "sigmas": None,
            "disable_pbar": False,
            "seed": seed,
            "callback": callback_function,
        }
        sampler = KSampler(
            self.xl_base_patched.unet,
            steps=gen_data["steps"],
            device=device,
            sampler=sampler_name,
            scheduler=scheduler,
            denoise=denoise,
            model_options=self.xl_base_patched.unet.model_options,
        )
        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "Starting sampling ...", None),
            )

        samples = sampler.sample(
            noise,
            positive_cond,
            self.conditions["-"]["cache"],
            **kwargs,
        )

        cleanup_additional_models(self.models)

        sampled_latent = latent.copy()
        sampled_latent["samples"] = samples

        if callback is not None:
            worker.add_result(
                gen_data["task_id"],
                "preview",
                (-1, "VAE decoding ...", None),
            )
        decoded_latent = VAEDecode().decode(
            samples=sampled_latent, vae=self.xl_base_patched.vae
        )[0]

        images = [
            np.clip(255.0 * y.cpu().numpy(), 0, 255).astype(np.uint8)
            for y in decoded_latent
        ]

        if callback is not None:
            callback(gen_data["steps"], 0, 0, gen_data["steps"], images[0])

        return images
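

# Illustrative usage sketch (not part of the pipeline). The worker builds a
# gen_data dict and calls process(); the key names below are the ones read
# above, while the values and model filenames are made-up examples.
#
#   p = pipeline()
#   p.load_base_model("sd_xl_base_1.0_0.9vae.safetensors")
#   p.load_loras([("None", 0)])
#   gen_data = {
#       "task_id": 1, "positive_prompt": "a cat", "negative_prompt": "",
#       "input_image": None, "seed": 42, "steps": 20, "cfg": 7.0,
#       "width": 1024, "height": 1024, "sampler_name": "euler",
#       "scheduler": "normal", "clip_skip": 1,
#   }
#   images = p.process(gen_data=gen_data, callback=None)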