import os
from enum import Enum
import torch
import functools
import copy
from typing import Optional, List
from dataclasses import dataclass
import folder_paths
import comfy.model_management
import comfy.model_base
import comfy.supported_models
import comfy.supported_models_base
from comfy.model_patcher import ModelPatcher
from folder_paths import get_folder_paths
from comfy.utils import load_torch_file
from comfy_extras.nodes_compositing import JoinImageWithAlpha
from comfy.conds import CONDRegular
from .lib_layerdiffusion.utils import (
load_file_from_url,
to_lora_patch_dict,
)
from .lib_layerdiffusion.models import TransparentVAEDecoder
from .lib_layerdiffusion.attention_sharing import AttentionSharingPatcher
from .lib_layerdiffusion.enums import StableDiffusionVersion
if "layer_model" in folder_paths.folder_names_and_paths:
layer_model_root = get_folder_paths("layer_model")[0]
else:
layer_model_root = os.path.join(folder_paths.models_dir, "layer_model")
load_layer_model_state_dict = load_torch_file
# ------------ Start patching ComfyUI ------------
def calculate_weight_adjust_channel(func):
"""Patches ComfyUI's LoRA weight application to accept multi-channel inputs."""
@functools.wraps(func)
def calculate_weight(
self: ModelPatcher, patches, weight: torch.Tensor, key: str
) -> torch.Tensor:
weight = func(self, patches, weight, key)
for p in patches:
alpha = p[0]
v = p[1]
            # Nested patch lists are handled recursively by the wrapped func call above.
if isinstance(v, list):
continue
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if all(
(
alpha != 0.0,
w1.shape != weight.shape,
w1.ndim == weight.ndim == 4,
)
):
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
print(
f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
)
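                    # Zero-pad the base weight to the combined shape, keep the
                    # original values in the top-left slice, then add the scaled
                    # diff over its own slice.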
new_diff = alpha * comfy.model_management.cast_to_device(
w1, weight.device, weight.dtype
)
new_weight = torch.zeros(size=new_shape).to(weight)
new_weight[
: weight.shape[0],
: weight.shape[1],
: weight.shape[2],
: weight.shape[3],
] = weight
new_weight[
: new_diff.shape[0],
: new_diff.shape[1],
: new_diff.shape[2],
: new_diff.shape[3],
] += new_diff
new_weight = new_weight.contiguous().clone()
weight = new_weight
return weight
return calculate_weight
ModelPatcher.calculate_weight = calculate_weight_adjust_channel(
ModelPatcher.calculate_weight
)
# ------------ End patching ComfyUI ------------
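# Illustrative (hedged) example of the adjustment above: if a "diff" patch
# carries a conv weight of shape [320, 8, 3, 3] while the base UNet input conv
# is [320, 4, 3, 3], the wrapper grows the base weight to [320, 8, 3, 3]
# (zero-filled) before adding the diff, so a LoRA-style patch can introduce
# extra input channels (e.g. for c_concat conditioning).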
class LayeredDiffusionDecode:
"""
    Decode the alpha channel from pixel values.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs an RGB image + an alpha mask.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"images": ("IMAGE",),
"sd_version": (
[
StableDiffusionVersion.SD1x.value,
StableDiffusionVersion.SDXL.value,
],
{
"default": StableDiffusionVersion.SDXL.value,
},
),
"sub_batch_size": (
"INT",
{"default": 16, "min": 1, "max": 4096, "step": 1},
),
},
}
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "decode"
CATEGORY = "layer_diffuse"
def __init__(self) -> None:
self.vae_transparent_decoder = {}
def decode(self, samples, images, sd_version: str, sub_batch_size: int):
"""
sub_batch_size: How many images to decode in a single pass.
See https://github.com/huchenlei/ComfyUI-layerdiffuse/pull/4 for more
context.
"""
sd_version = StableDiffusionVersion(sd_version)
if sd_version == StableDiffusionVersion.SD1x:
url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
file_name = "layer_sd15_vae_transparent_decoder.safetensors"
elif sd_version == StableDiffusionVersion.SDXL:
url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
file_name = "vae_transparent_decoder.safetensors"
if not self.vae_transparent_decoder.get(sd_version):
model_path = load_file_from_url(
url=url, model_dir=layer_model_root, file_name=file_name
)
self.vae_transparent_decoder[sd_version] = TransparentVAEDecoder(
load_torch_file(model_path),
device=comfy.model_management.get_torch_device(),
dtype=(
torch.float16
if comfy.model_management.should_use_fp16()
else torch.float32
),
)
pixel = images.movedim(-1, 1) # [B, H, W, C] => [B, C, H, W]
# Decoder requires dimension to be 64-aligned.
B, C, H, W = pixel.shape
        assert H % 64 == 0, f"Height ({H}) is not a multiple of 64."
        assert W % 64 == 0, f"Width ({W}) is not a multiple of 64."
decoded = []
for start_idx in range(0, samples["samples"].shape[0], sub_batch_size):
decoded.append(
self.vae_transparent_decoder[sd_version].decode_pixel(
pixel[start_idx : start_idx + sub_batch_size],
samples["samples"][start_idx : start_idx + sub_batch_size],
)
)
pixel_with_alpha = torch.cat(decoded, dim=0)
# [B, C, H, W] => [B, H, W, C]
pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
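        # Decoder output channel layout: index 0 is the alpha matte, 1..3 are RGB.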
image = pixel_with_alpha[..., 1:]
alpha = pixel_with_alpha[..., 0]
return (image, alpha)
class LayeredDiffusionDecodeRGBA(LayeredDiffusionDecode):
"""
    Decode the alpha channel from pixel values.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs an RGBA image.
"""
RETURN_TYPES = ("IMAGE",)
def decode(self, samples, images, sd_version: str, sub_batch_size: int):
image, mask = super().decode(samples, images, sd_version, sub_batch_size)
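        # JoinImageWithAlpha treats the incoming mask as transparency and inverts
        # it internally, so pre-invert here to keep the decoded alpha in the
        # final RGBA output.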
alpha = 1.0 - mask
return JoinImageWithAlpha().join_image_with_alpha(image, alpha)
class LayeredDiffusionDecodeSplit(LayeredDiffusionDecodeRGBA):
"""Decode RGBA every N images."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"images": ("IMAGE",),
# Do RGBA decode every N output images.
"frames": (
"INT",
{"default": 2, "min": 2, "max": s.MAX_FRAMES, "step": 1},
),
"sd_version": (
[
StableDiffusionVersion.SD1x.value,
StableDiffusionVersion.SDXL.value,
],
{
"default": StableDiffusionVersion.SDXL.value,
},
),
"sub_batch_size": (
"INT",
{"default": 16, "min": 1, "max": 4096, "step": 1},
),
},
}
MAX_FRAMES = 3
RETURN_TYPES = ("IMAGE",) * MAX_FRAMES
def decode(
self,
samples,
images: torch.Tensor,
frames: int,
sd_version: str,
sub_batch_size: int,
):
sliced_samples = copy.copy(samples)
sliced_samples["samples"] = sliced_samples["samples"][::frames]
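        # Only the first image of every group of `frames` (typically the
        # transparent foreground) goes through the RGBA decode; the remaining
        # frames are passed through unchanged, and the outputs are padded with
        # None up to MAX_FRAMES.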
return tuple(
(
(
super(LayeredDiffusionDecodeSplit, self).decode(
sliced_samples, imgs, sd_version, sub_batch_size
)[0]
if i == 0
else imgs
)
for i in range(frames)
for imgs in (images[i::frames],)
)
) + (None,) * (self.MAX_FRAMES - frames)
class LayerMethod(Enum):
ATTN = "Attention Injection"
CONV = "Conv Injection"
class LayerType(Enum):
FG = "Foreground"
BG = "Background"
@dataclass
class LayeredDiffusionBase:
model_file_name: str
model_url: str
sd_version: StableDiffusionVersion
attn_sharing: bool = False
injection_method: Optional[LayerMethod] = None
cond_type: Optional[LayerType] = None
# Number of output images per run.
frames: int = 1
@property
def config_string(self) -> str:
injection_method = self.injection_method.value if self.injection_method else ""
cond_type = self.cond_type.value if self.cond_type else ""
attn_sharing = "attn_sharing" if self.attn_sharing else ""
frames = f"Batch size ({self.frames}N)" if self.frames != 1 else ""
return ", ".join(
x
for x in (
self.sd_version.value,
injection_method,
cond_type,
attn_sharing,
frames,
)
if x
)
def apply_c_concat(self, cond, uncond, c_concat):
"""Set foreground/background concat condition."""
def write_c_concat(cond):
new_cond = []
for t in cond:
n = [t[0], t[1].copy()]
if "model_conds" not in n[1]:
n[1]["model_conds"] = {}
n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat)
new_cond.append(n)
return new_cond
return (write_c_concat(cond), write_c_concat(uncond))
def apply_layered_diffusion(
self,
model: ModelPatcher,
weight: float,
):
"""Patch model"""
model_path = load_file_from_url(
url=self.model_url,
model_dir=layer_model_root,
file_name=self.model_file_name,
)
layer_lora_state_dict = load_layer_model_state_dict(model_path)
layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
work_model = model.clone()
work_model.add_patches(layer_lora_patch_dict, weight)
return (work_model,)
def apply_layered_diffusion_attn_sharing(
self,
model: ModelPatcher,
        control_img: Optional[torch.Tensor] = None,
):
"""Patch model with attn sharing"""
model_path = load_file_from_url(
url=self.model_url,
model_dir=layer_model_root,
file_name=self.model_file_name,
)
layer_lora_state_dict = load_layer_model_state_dict(model_path)
work_model = model.clone()
patcher = AttentionSharingPatcher(
work_model, self.frames, use_control=control_img is not None
)
patcher.load_state_dict(layer_lora_state_dict, strict=True)
if control_img is not None:
patcher.set_control(control_img)
return (work_model,)
def get_model_sd_version(model: ModelPatcher) -> StableDiffusionVersion:
"""Get model's StableDiffusionVersion."""
base: comfy.model_base.BaseModel = model.model
    model_config: comfy.supported_models_base.BASE = base.model_config
if isinstance(model_config, comfy.supported_models.SDXL):
return StableDiffusionVersion.SDXL
elif isinstance(
model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)
):
# SD15 and SD20 are compatible with each other.
return StableDiffusionVersion.SD1x
else:
raise Exception(f"Unsupported SD Version: {type(model_config)}.")
class LayeredDiffusionFG:
"""Generate foreground with transparent background."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"config": ([c.config_string for c in s.MODELS],),
"weight": (
"FLOAT",
{"default": 1.0, "min": -1, "max": 3, "step": 0.05},
),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_layered_diffusion"
CATEGORY = "layer_diffuse"
MODELS = (
LayeredDiffusionBase(
model_file_name="layer_xl_transparent_attn.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_attn.safetensors",
sd_version=StableDiffusionVersion.SDXL,
injection_method=LayerMethod.ATTN,
),
LayeredDiffusionBase(
model_file_name="layer_xl_transparent_conv.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_conv.safetensors",
sd_version=StableDiffusionVersion.SDXL,
injection_method=LayerMethod.CONV,
),
LayeredDiffusionBase(
model_file_name="layer_sd15_transparent_attn.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_transparent_attn.safetensors",
sd_version=StableDiffusionVersion.SD1x,
injection_method=LayerMethod.ATTN,
attn_sharing=True,
),
)
def apply_layered_diffusion(
self,
model: ModelPatcher,
config: str,
weight: float,
):
ld_model = [m for m in self.MODELS if m.config_string == config][0]
assert get_model_sd_version(model) == ld_model.sd_version
if ld_model.attn_sharing:
return ld_model.apply_layered_diffusion_attn_sharing(model)
else:
return ld_model.apply_layered_diffusion(model, weight)
class LayeredDiffusionJoint:
"""Generate FG + BG + Blended in one inference batch. Batch size = 3N."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"config": ([c.config_string for c in s.MODELS],),
},
"optional": {
"fg_cond": ("CONDITIONING",),
"bg_cond": ("CONDITIONING",),
"blended_cond": ("CONDITIONING",),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_layered_diffusion"
CATEGORY = "layer_diffuse"
MODELS = (
LayeredDiffusionBase(
model_file_name="layer_sd15_joint.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_joint.safetensors",
sd_version=StableDiffusionVersion.SD1x,
attn_sharing=True,
frames=3,
),
)
def apply_layered_diffusion(
self,
model: ModelPatcher,
config: str,
        fg_cond: Optional[List[List[torch.Tensor]]] = None,
        bg_cond: Optional[List[List[torch.Tensor]]] = None,
        blended_cond: Optional[List[List[torch.Tensor]]] = None,
):
ld_model = [m for m in self.MODELS if m.config_string == config][0]
assert get_model_sd_version(model) == ld_model.sd_version
assert ld_model.attn_sharing
work_model = ld_model.apply_layered_diffusion_attn_sharing(model)[0]
work_model.model_options.setdefault("transformer_options", {})
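        # The attention-sharing patch consumes "cond_overwrite" from
        # transformer_options so that each generated frame (FG, BG, Blended
        # here) can receive its own prompt; None entries fall back to the
        # sampler's conditioning.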
work_model.model_options["transformer_options"]["cond_overwrite"] = [
cond[0][0] if cond is not None else None
for cond in (
fg_cond,
bg_cond,
blended_cond,
)
]
return (work_model,)
class LayeredDiffusionCond:
"""Generate foreground + background given background / foreground.
- FG => Blended
- BG => Blended
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"cond": ("CONDITIONING",),
"uncond": ("CONDITIONING",),
"latent": ("LATENT",),
"config": ([c.config_string for c in s.MODELS],),
"weight": (
"FLOAT",
{"default": 1.0, "min": -1, "max": 3, "step": 0.05},
),
},
}
RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
FUNCTION = "apply_layered_diffusion"
CATEGORY = "layer_diffuse"
MODELS = (
LayeredDiffusionBase(
model_file_name="layer_xl_fg2ble.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fg2ble.safetensors",
sd_version=StableDiffusionVersion.SDXL,
cond_type=LayerType.FG,
),
LayeredDiffusionBase(
model_file_name="layer_xl_bg2ble.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bg2ble.safetensors",
sd_version=StableDiffusionVersion.SDXL,
cond_type=LayerType.BG,
),
)
def apply_layered_diffusion(
self,
model: ModelPatcher,
cond,
uncond,
latent,
config: str,
weight: float,
):
ld_model = [m for m in self.MODELS if m.config_string == config][0]
assert get_model_sd_version(model) == ld_model.sd_version
c_concat = model.model.latent_format.process_in(latent["samples"])
return ld_model.apply_layered_diffusion(
model, weight
) + ld_model.apply_c_concat(cond, uncond, c_concat)
class LayeredDiffusionCondJoint:
"""Generate fg/bg + blended given fg/bg.
- FG => Blended + BG
- BG => Blended + FG
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"image": ("IMAGE",),
"config": ([c.config_string for c in s.MODELS],),
},
"optional": {
"cond": ("CONDITIONING",),
"blended_cond": ("CONDITIONING",),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_layered_diffusion"
CATEGORY = "layer_diffuse"
MODELS = (
LayeredDiffusionBase(
model_file_name="layer_sd15_fg2bg.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors",
sd_version=StableDiffusionVersion.SD1x,
attn_sharing=True,
frames=2,
cond_type=LayerType.FG,
),
LayeredDiffusionBase(
model_file_name="layer_sd15_bg2fg.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors",
sd_version=StableDiffusionVersion.SD1x,
attn_sharing=True,
frames=2,
cond_type=LayerType.BG,
),
)
def apply_layered_diffusion(
self,
model: ModelPatcher,
image,
config: str,
        cond: Optional[List[List[torch.Tensor]]] = None,
        blended_cond: Optional[List[List[torch.Tensor]]] = None,
):
ld_model = [m for m in self.MODELS if m.config_string == config][0]
assert get_model_sd_version(model) == ld_model.sd_version
assert ld_model.attn_sharing
work_model = ld_model.apply_layered_diffusion_attn_sharing(
model, control_img=image.movedim(-1, 1)
)[0]
work_model.model_options.setdefault("transformer_options", {})
work_model.model_options["transformer_options"]["cond_overwrite"] = [
            c[0][0] if c is not None else None
            for c in (
cond,
blended_cond,
)
]
return (work_model,)
class LayeredDiffusionDiff:
"""Extract FG/BG from blended image.
- Blended + FG => BG
- Blended + BG => FG
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"cond": ("CONDITIONING",),
"uncond": ("CONDITIONING",),
"blended_latent": ("LATENT",),
"latent": ("LATENT",),
"config": ([c.config_string for c in s.MODELS],),
"weight": (
"FLOAT",
{"default": 1.0, "min": -1, "max": 3, "step": 0.05},
),
},
}
RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
FUNCTION = "apply_layered_diffusion"
CATEGORY = "layer_diffuse"
MODELS = (
LayeredDiffusionBase(
model_file_name="layer_xl_fgble2bg.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fgble2bg.safetensors",
sd_version=StableDiffusionVersion.SDXL,
cond_type=LayerType.FG,
),
LayeredDiffusionBase(
model_file_name="layer_xl_bgble2fg.safetensors",
model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bgble2fg.safetensors",
sd_version=StableDiffusionVersion.SDXL,
cond_type=LayerType.BG,
),
)
def apply_layered_diffusion(
self,
model: ModelPatcher,
cond,
uncond,
blended_latent,
latent,
config: str,
weight: float,
):
ld_model = [m for m in self.MODELS if m.config_string == config][0]
assert get_model_sd_version(model) == ld_model.sd_version
c_concat = model.model.latent_format.process_in(
torch.cat([latent["samples"], blended_latent["samples"]], dim=1)
)
return ld_model.apply_layered_diffusion(
model, weight
) + ld_model.apply_c_concat(cond, uncond, c_concat)
NODE_CLASS_MAPPINGS = {
"LayeredDiffusionApply": LayeredDiffusionFG,
"LayeredDiffusionJointApply": LayeredDiffusionJoint,
"LayeredDiffusionCondApply": LayeredDiffusionCond,
"LayeredDiffusionCondJointApply": LayeredDiffusionCondJoint,
"LayeredDiffusionDiffApply": LayeredDiffusionDiff,
"LayeredDiffusionDecode": LayeredDiffusionDecode,
"LayeredDiffusionDecodeRGBA": LayeredDiffusionDecodeRGBA,
"LayeredDiffusionDecodeSplit": LayeredDiffusionDecodeSplit,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LayeredDiffusionApply": "Layer Diffuse Apply",
"LayeredDiffusionJointApply": "Layer Diffuse Joint Apply",
"LayeredDiffusionCondApply": "Layer Diffuse Cond Apply",
"LayeredDiffusionCondJointApply": "Layer Diffuse Cond Joint Apply",
"LayeredDiffusionDiffApply": "Layer Diffuse Diff Apply",
"LayeredDiffusionDecode": "Layer Diffuse Decode",
"LayeredDiffusionDecodeRGBA": "Layer Diffuse Decode (RGBA)",
"LayeredDiffusionDecodeSplit": "Layer Diffuse Decode (Split)",
}