import os
from enum import Enum
import torch
import functools
import copy
from typing import Optional, List
from dataclasses import dataclass

import folder_paths
import comfy.model_management
import comfy.model_base
import comfy.supported_models
import comfy.supported_models_base
from comfy.model_patcher import ModelPatcher
from folder_paths import get_folder_paths
from comfy.utils import load_torch_file
from comfy_extras.nodes_compositing import JoinImageWithAlpha
from comfy.conds import CONDRegular

from .lib_layerdiffusion.utils import (
    load_file_from_url,
    to_lora_patch_dict,
)
from .lib_layerdiffusion.models import TransparentVAEDecoder
from .lib_layerdiffusion.attention_sharing import AttentionSharingPatcher
from .lib_layerdiffusion.enums import StableDiffusionVersion

if "layer_model" in folder_paths.folder_names_and_paths:
    layer_model_root = get_folder_paths("layer_model")[0]
else:
    layer_model_root = os.path.join(folder_paths.models_dir, "layer_model")
load_layer_model_state_dict = load_torch_file


# ------------ Start patching ComfyUI ------------
def calculate_weight_adjust_channel(func):
    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs."""

    @functools.wraps(func)  # keep the wrapped method's name and docstring intact
    def calculate_weight(
        self: ModelPatcher, patches, weight: torch.Tensor, key: str
    ) -> torch.Tensor:
        weight = func(self, patches, weight, key)

        for p in patches:
            alpha = p[0]
            v = p[1]

            # The recursion call should be handled in the main func call.
            if isinstance(v, list):
                continue

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                if all(
                    (
                        alpha != 0.0,
                        w1.shape != weight.shape,
                        w1.ndim == weight.ndim == 4,
                    )
                ):
                    new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                    print(
                        f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
                    )
                    new_diff = alpha * comfy.model_management.cast_to_device(
                        w1, weight.device, weight.dtype
                    )
                    new_weight = torch.zeros(size=new_shape).to(weight)
                    new_weight[
                        : weight.shape[0],
                        : weight.shape[1],
                        : weight.shape[2],
                        : weight.shape[3],
                    ] = weight
                    new_weight[
                        : new_diff.shape[0],
                        : new_diff.shape[1],
                        : new_diff.shape[2],
                        : new_diff.shape[3],
                    ] += new_diff
                    new_weight = new_weight.contiguous().clone()
                    weight = new_weight
        return weight

    return calculate_weight


ModelPatcher.calculate_weight = calculate_weight_adjust_channel(
    ModelPatcher.calculate_weight
)
# ------------ End patching ComfyUI ------------
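
# Illustrative note (an assumption about why the hook is needed, not upstream text):
# the patch above matters when a layer-diffusion "diff" entry carries more input
# channels than the base conv weight. For example, if a UNet input conv weight is
# [320, 4, 3, 3] and a cond model ships a [320, 8, 3, 3] diff (extra channels for the
# concatenated latent), the base weight is zero-padded to the larger shape before the
# diff is added, instead of failing on the shape mismatch.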


class LayeredDiffusionDecode:
    """
    Decode alpha channel value from pixel value.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs RGB image + Alpha mask.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT",),
                "images": ("IMAGE",),
                "sd_version": (
                    [
                        StableDiffusionVersion.SD1x.value,
                        StableDiffusionVersion.SDXL.value,
                    ],
                    {
                        "default": StableDiffusionVersion.SDXL.value,
                    },
                ),
                "sub_batch_size": (
                    "INT",
                    {"default": 16, "min": 1, "max": 4096, "step": 1},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "decode"
    CATEGORY = "layer_diffuse"

    def __init__(self) -> None:
        self.vae_transparent_decoder = {}

    def decode(self, samples, images, sd_version: str, sub_batch_size: int):
        """
        sub_batch_size: How many images to decode in a single pass.
        See https://github.com/huchenlei/ComfyUI-layerdiffuse/pull/4 for more
        context.
        """
        sd_version = StableDiffusionVersion(sd_version)
        if sd_version == StableDiffusionVersion.SD1x:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
            file_name = "layer_sd15_vae_transparent_decoder.safetensors"
        elif sd_version == StableDiffusionVersion.SDXL:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
            file_name = "vae_transparent_decoder.safetensors"

        if not self.vae_transparent_decoder.get(sd_version):
            model_path = load_file_from_url(
                url=url, model_dir=layer_model_root, file_name=file_name
            )
            self.vae_transparent_decoder[sd_version] = TransparentVAEDecoder(
                load_torch_file(model_path),
                device=comfy.model_management.get_torch_device(),
                dtype=(
                    torch.float16
                    if comfy.model_management.should_use_fp16()
                    else torch.float32
                ),
            )

        pixel = images.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]

        # Decoder requires dimensions to be 64-aligned.
        B, C, H, W = pixel.shape
        assert H % 64 == 0, f"Height({H}) is not multiple of 64."
        assert W % 64 == 0, f"Width({W}) is not multiple of 64."

        decoded = []
        for start_idx in range(0, samples["samples"].shape[0], sub_batch_size):
            decoded.append(
                self.vae_transparent_decoder[sd_version].decode_pixel(
                    pixel[start_idx : start_idx + sub_batch_size],
                    samples["samples"][start_idx : start_idx + sub_batch_size],
                )
            )
        pixel_with_alpha = torch.cat(decoded, dim=0)

        # [B, C, H, W] => [B, H, W, C]
        pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
        image = pixel_with_alpha[..., 1:]
        alpha = pixel_with_alpha[..., 0]
        return (image, alpha)
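
# Typical wiring for the decode node above (a sketch inferred from its inputs, not a
# prescribed workflow): route the sampler's LATENT into `samples` and the ordinary
# VAE-decoded IMAGE into `images`; the transparent decoder reconstructs the alpha
# channel from that pair, and `sub_batch_size` only bounds how many frames are decoded
# per forward pass.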


class LayeredDiffusionDecodeRGBA(LayeredDiffusionDecode):
    """
    Decode alpha channel value from pixel value.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs RGBA image.
    """

    RETURN_TYPES = ("IMAGE",)

    def decode(self, samples, images, sd_version: str, sub_batch_size: int):
        image, mask = super().decode(samples, images, sd_version, sub_batch_size)
        alpha = 1.0 - mask
        return JoinImageWithAlpha().join_image_with_alpha(image, alpha)
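
# Note on the `1.0 - mask` flip above (an interpretation, not upstream documentation):
# the base node exposes the decoder's channel-0 map as a ComfyUI MASK, so it is
# inverted here before JoinImageWithAlpha attaches it as the image's alpha channel in
# the orientation that node expects.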


class LayeredDiffusionDecodeSplit(LayeredDiffusionDecodeRGBA):
    """Decode RGBA every N images."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT",),
                "images": ("IMAGE",),
                # Do RGBA decode every N output images.
                "frames": (
                    "INT",
                    {"default": 2, "min": 2, "max": s.MAX_FRAMES, "step": 1},
                ),
                "sd_version": (
                    [
                        StableDiffusionVersion.SD1x.value,
                        StableDiffusionVersion.SDXL.value,
                    ],
                    {
                        "default": StableDiffusionVersion.SDXL.value,
                    },
                ),
                "sub_batch_size": (
                    "INT",
                    {"default": 16, "min": 1, "max": 4096, "step": 1},
                ),
            },
        }

    MAX_FRAMES = 3
    RETURN_TYPES = ("IMAGE",) * MAX_FRAMES

    def decode(
        self,
        samples,
        images: torch.Tensor,
        frames: int,
        sd_version: str,
        sub_batch_size: int,
    ):
        sliced_samples = copy.copy(samples)
        sliced_samples["samples"] = sliced_samples["samples"][::frames]
        return tuple(
            (
                (
                    super(LayeredDiffusionDecodeSplit, self).decode(
                        sliced_samples, imgs, sd_version, sub_batch_size
                    )[0]
                    if i == 0
                    else imgs
                )
                for i in range(frames)
                for imgs in (images[i::frames],)
            )
        ) + (None,) * (self.MAX_FRAMES - frames)
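
# Example of the split decode above (illustrative): with frames=2 and a batch of 2N
# images, images[0::2] are decoded to RGBA against the matching latents
# samples["samples"][::2], images[1::2] are passed through untouched, and any output
# slots beyond `frames` (up to MAX_FRAMES) are filled with None.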


class LayerMethod(Enum):
    ATTN = "Attention Injection"
    CONV = "Conv Injection"


class LayerType(Enum):
    FG = "Foreground"
    BG = "Background"


@dataclass
class LayeredDiffusionBase:
    model_file_name: str
    model_url: str
    sd_version: StableDiffusionVersion
    attn_sharing: bool = False
    injection_method: Optional[LayerMethod] = None
    cond_type: Optional[LayerType] = None
    # Number of output images per run.
    frames: int = 1

    @functools.cached_property
    def config_string(self) -> str:
        injection_method = self.injection_method.value if self.injection_method else ""
        cond_type = self.cond_type.value if self.cond_type else ""
        attn_sharing = "attn_sharing" if self.attn_sharing else ""
        frames = f"Batch size ({self.frames}N)" if self.frames != 1 else ""
        return ", ".join(
            x
            for x in (
                self.sd_version.value,
                injection_method,
                cond_type,
                attn_sharing,
                frames,
            )
            if x
        )

    def apply_c_concat(self, cond, uncond, c_concat):
        """Set foreground/background concat condition."""

        def write_c_concat(cond):
            new_cond = []
            for t in cond:
                n = [t[0], t[1].copy()]
                if "model_conds" not in n[1]:
                    n[1]["model_conds"] = {}
                n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat)
                new_cond.append(n)
            return new_cond

        return (write_c_concat(cond), write_c_concat(uncond))
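
    # Background on c_concat (an assumption about how the cond models consume it): the
    # latent written here is concatenated onto the sampler's noisy latent as extra UNet
    # input channels, which is exactly the case the channel-expanding weight patch at
    # the top of this module exists to support.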

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        weight: float,
    ):
        """Patch model"""
        model_path = load_file_from_url(
            url=self.model_url,
            model_dir=layer_model_root,
            file_name=self.model_file_name,
        )
        layer_lora_state_dict = load_layer_model_state_dict(model_path)
        layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
        work_model = model.clone()
        work_model.add_patches(layer_lora_patch_dict, weight)
        return (work_model,)

    def apply_layered_diffusion_attn_sharing(
        self,
        model: ModelPatcher,
        control_img: Optional[torch.TensorType] = None,
    ):
        """Patch model with attn sharing"""
        model_path = load_file_from_url(
            url=self.model_url,
            model_dir=layer_model_root,
            file_name=self.model_file_name,
        )
        layer_lora_state_dict = load_layer_model_state_dict(model_path)
        work_model = model.clone()
        patcher = AttentionSharingPatcher(
            work_model, self.frames, use_control=control_img is not None
        )
        patcher.load_state_dict(layer_lora_state_dict, strict=True)
        if control_img is not None:
            patcher.set_control(control_img)
        return (work_model,)
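
    # Rough summary of attention sharing (hedged; see lib_layerdiffusion/attention_sharing
    # for the real implementation): the patcher groups the batch into runs of `frames`
    # images whose attention layers exchange information, so related outputs such as
    # FG/BG/Blended are denoised coherently in one pass, optionally steered by a
    # control image.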


def get_model_sd_version(model: ModelPatcher) -> StableDiffusionVersion:
    """Get model's StableDiffusionVersion."""
    base: comfy.model_base.BaseModel = model.model
    model_config: comfy.supported_models.supported_models_base.BASE = base.model_config
    if isinstance(model_config, comfy.supported_models.SDXL):
        return StableDiffusionVersion.SDXL
    elif isinstance(
        model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)
    ):
        # SD15 and SD20 are compatible with each other.
        return StableDiffusionVersion.SD1x
    else:
        raise Exception(f"Unsupported SD Version: {type(model_config)}.")


class LayeredDiffusionFG:
    """Generate foreground with transparent background."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_transparent_attn.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_attn.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            injection_method=LayerMethod.ATTN,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_transparent_conv.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_conv.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            injection_method=LayerMethod.CONV,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_sd15_transparent_attn.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_transparent_attn.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            injection_method=LayerMethod.ATTN,
            attn_sharing=True,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        config: str,
        weight: float,
    ):
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        if ld_model.attn_sharing:
            return ld_model.apply_layered_diffusion_attn_sharing(model)
        else:
            return ld_model.apply_layered_diffusion(model, weight)


class LayeredDiffusionJoint:
    """Generate FG + BG + Blended in one inference batch. Batch size = 3N."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "config": ([c.config_string for c in s.MODELS],),
            },
            "optional": {
                "fg_cond": ("CONDITIONING",),
                "bg_cond": ("CONDITIONING",),
                "blended_cond": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_sd15_joint.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_joint.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=3,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        config: str,
        fg_cond: Optional[List[List[torch.TensorType]]] = None,
        bg_cond: Optional[List[List[torch.TensorType]]] = None,
        blended_cond: Optional[List[List[torch.TensorType]]] = None,
    ):
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        assert ld_model.attn_sharing
        work_model = ld_model.apply_layered_diffusion_attn_sharing(model)[0]
        work_model.model_options.setdefault("transformer_options", {})
        work_model.model_options["transformer_options"]["cond_overwrite"] = [
            cond[0][0] if cond is not None else None
            for cond in (
                fg_cond,
                bg_cond,
                blended_cond,
            )
        ]
        return (work_model,)
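
# Usage note for the joint node above (inferred from frames=3 and the cond_overwrite
# order, not separately documented): sample with a batch size that is a multiple of 3;
# the three optional conditionings then map onto the foreground, background and blended
# frames of each triplet in that order.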


class LayeredDiffusionCond:
    """Generate the blended image given a foreground / background condition.
    - FG => Blended
    - BG => Blended
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "cond": ("CONDITIONING",),
                "uncond": ("CONDITIONING",),
                "latent": ("LATENT",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_fg2ble.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fg2ble.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_bg2ble.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bg2ble.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.BG,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        cond,
        uncond,
        latent,
        config: str,
        weight: float,
    ):
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        c_concat = model.model.latent_format.process_in(latent["samples"])
        return ld_model.apply_layered_diffusion(
            model, weight
        ) + ld_model.apply_c_concat(cond, uncond, c_concat)


class LayeredDiffusionCondJoint:
    """Generate fg/bg + blended given fg/bg.
    - FG => Blended + BG
    - BG => Blended + FG
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "image": ("IMAGE",),
                "config": ([c.config_string for c in s.MODELS],),
            },
            "optional": {
                "cond": ("CONDITIONING",),
                "blended_cond": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_sd15_fg2bg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=2,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_sd15_bg2fg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=2,
            cond_type=LayerType.BG,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        image,
        config: str,
        cond: Optional[List[List[torch.TensorType]]] = None,
        blended_cond: Optional[List[List[torch.TensorType]]] = None,
    ):
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        assert ld_model.attn_sharing
        work_model = ld_model.apply_layered_diffusion_attn_sharing(
            model, control_img=image.movedim(-1, 1)
        )[0]
        work_model.model_options.setdefault("transformer_options", {})
        # Avoid shadowing the `cond` parameter inside the comprehension.
        work_model.model_options["transformer_options"]["cond_overwrite"] = [
            c[0][0] if c is not None else None
            for c in (
                cond,
                blended_cond,
            )
        ]
        return (work_model,)
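
# The IMAGE input above is handed to the attention-sharing patcher channels-first
# ([B, H, W, C] -> [B, C, H, W] via movedim), and the two optional conditionings are
# written into cond_overwrite in the order the paired frames are generated (an
# inference from the frames=2 configs above).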


class LayeredDiffusionDiff:
    """Extract FG/BG from blended image.
    - Blended + FG => BG
    - Blended + BG => FG
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "cond": ("CONDITIONING",),
                "uncond": ("CONDITIONING",),
                "blended_latent": ("LATENT",),
                "latent": ("LATENT",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_fgble2bg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fgble2bg.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_bgble2fg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bgble2fg.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.BG,
        ),
    )

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        cond,
        uncond,
        blended_latent,
        latent,
        config: str,
        weight: float,
    ):
        ld_model = [m for m in self.MODELS if m.config_string == config][0]
        assert get_model_sd_version(model) == ld_model.sd_version
        c_concat = model.model.latent_format.process_in(
            torch.cat([latent["samples"], blended_latent["samples"]], dim=1)
        )
        return ld_model.apply_layered_diffusion(
            model, weight
        ) + ld_model.apply_c_concat(cond, uncond, c_concat)
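
# Note (an inference from the concatenation above): for the diff models the concat
# condition is the clean latent and the blended latent stacked along the channel
# dimension, i.e. 8 extra latent channels on top of the 4 noise channels, which again
# relies on the channel-expanding weight patch at the top of this module.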


NODE_CLASS_MAPPINGS = {
    "LayeredDiffusionApply": LayeredDiffusionFG,
    "LayeredDiffusionJointApply": LayeredDiffusionJoint,
    "LayeredDiffusionCondApply": LayeredDiffusionCond,
    "LayeredDiffusionCondJointApply": LayeredDiffusionCondJoint,
    "LayeredDiffusionDiffApply": LayeredDiffusionDiff,
    "LayeredDiffusionDecode": LayeredDiffusionDecode,
    "LayeredDiffusionDecodeRGBA": LayeredDiffusionDecodeRGBA,
    "LayeredDiffusionDecodeSplit": LayeredDiffusionDecodeSplit,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LayeredDiffusionApply": "Layer Diffuse Apply",
    "LayeredDiffusionJointApply": "Layer Diffuse Joint Apply",
    "LayeredDiffusionCondApply": "Layer Diffuse Cond Apply",
    "LayeredDiffusionCondJointApply": "Layer Diffuse Cond Joint Apply",
    "LayeredDiffusionDiffApply": "Layer Diffuse Diff Apply",
    "LayeredDiffusionDecode": "Layer Diffuse Decode",
    "LayeredDiffusionDecodeRGBA": "Layer Diffuse Decode (RGBA)",
    "LayeredDiffusionDecodeSplit": "Layer Diffuse Decode (Split)",
}
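
# Example workflow (a rough sketch, not a prescribed graph): "Layer Diffuse Apply"
# patches the checkpoint's MODEL, a normal sampler plus VAE Decode produce the latent
# and preview image, and "Layer Diffuse Decode (RGBA)" turns that pair into a
# transparent image.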