JasonSmithSO's picture
Upload 578 files
8866644 verified
from torch import Tensor
import folder_paths
from nodes import VAEEncode
import comfy.utils
from comfy.sd import VAE
from .utils import TimestepKeyframeGroup
from .control_sparsectrl import SparseMethod, SparseIndexMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
from .control import load_sparsectrl, load_controlnet, ControlNetAdvanced, SparseCtrlAdvanced
# node for SparseCtrl loading
class SparseCtrlLoaderAdvanced:
    """Node that loads a SparseCtrl model from the "controlnet" model folder.

    Required inputs select the checkpoint and the motion options; the sparse
    method and timestep keyframes are optional inputs.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required and optional inputs) for this node."""
        return {
            "required": {
                "sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
                "use_motion": ("BOOLEAN", {"default": True}, ),
                "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
            },
            "optional": {
                "sparse_method": ("SPARSE_METHOD", ),
                "tk_optional": ("TIMESTEP_KEYFRAME", ),
            }
        }

    RETURN_TYPES = ("CONTROL_NET", )
    FUNCTION = "load_controlnet"

    # NOTE(review): the non-ASCII part of this label looks mis-encoded; it is a
    # runtime string, so it is kept as-is here - confirm the intended label.
    CATEGORY = "Adv-ControlNet πŸ›‚πŸ…πŸ…’πŸ…/SparseCtrl"

    def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float,
                        sparse_method: SparseMethod=None, tk_optional: TimestepKeyframeGroup=None):
        """Load the selected SparseCtrl checkpoint with the given settings.

        Args:
            sparsectrl_name: File name resolved inside the "controlnet" folder.
            use_motion: Forwarded to SparseSettings.
            motion_strength: Forwarded to SparseSettings.
            motion_scale: Forwarded to SparseSettings.
            sparse_method: Sparse method to use; when omitted (None), a new
                SparseSpreadMethod() is created for this call.
            tk_optional: Optional timestep keyframes passed to the loader.

        Returns:
            A one-element tuple holding the loaded SparseCtrl object.
        """
        # Build the default per call rather than in the signature, so calls
        # never share a single default instance created at definition time.
        if sparse_method is None:
            sparse_method = SparseSpreadMethod()
        sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
        sparse_settings = SparseSettings(
            sparse_method=sparse_method,
            use_motion=use_motion,
            motion_strength=motion_strength,
            motion_scale=motion_scale,
        )
        sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
        return (sparsectrl,)
class SparseCtrlMergedLoaderAdvanced:
    """Experimental node that merges a normal ControlNet with SparseCtrl motion weights.

    The state dict of a regular ControlNet is combined with the motion-wrapper
    state dict of a SparseCtrl checkpoint, and SparseCtrl is then loaded from
    the combined weights.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required and optional inputs) for this node."""
        return {
            "required": {
                "sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
                "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
                "use_motion": ("BOOLEAN", {"default": True}, ),
                "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
            },
            "optional": {
                "sparse_method": ("SPARSE_METHOD", ),
                "tk_optional": ("TIMESTEP_KEYFRAME", ),
            }
        }

    RETURN_TYPES = ("CONTROL_NET", )
    FUNCTION = "load_controlnet"

    # NOTE(review): the non-ASCII part of this label looks mis-encoded; it is a
    # runtime string, so it is kept as-is here - confirm the intended label.
    CATEGORY = "Adv-ControlNet πŸ›‚πŸ…πŸ…’πŸ…/SparseCtrl/experimental"

    def load_controlnet(self, sparsectrl_name: str, control_net_name: str, use_motion: bool, motion_strength: float,
                        motion_scale: float, sparse_method: SparseMethod=None,
                        tk_optional: TimestepKeyframeGroup=None):
        """Load a ControlNet and a SparseCtrl checkpoint and merge them into one SparseCtrl.

        Args:
            sparsectrl_name: SparseCtrl file name inside the "controlnet" folder.
            control_net_name: Regular ControlNet file name inside the same folder.
            use_motion: Forwarded to SparseSettings.
            motion_strength: Forwarded to SparseSettings.
            motion_scale: Forwarded to SparseSettings.
            sparse_method: Sparse method to use; when omitted (None), a new
                SparseSpreadMethod() is created for this call.
            tk_optional: Optional timestep keyframes passed to every load call.

        Returns:
            A one-element tuple holding the merged SparseCtrl object.

        Raises:
            ValueError: If the loaded ControlNet is None or its type is not
                exactly ControlNetAdvanced.
        """
        # Build the default per call rather than in the signature, so calls
        # never share a single default instance created at definition time.
        if sparse_method is None:
            sparse_method = SparseSpreadMethod()
        sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
        sparse_settings = SparseSettings(
            sparse_method=sparse_method,
            use_motion=use_motion,
            motion_strength=motion_strength,
            motion_scale=motion_scale,
            merged=True,
        )
        # first, load normal controlnet
        controlnet = load_controlnet(controlnet_path, timestep_keyframe=tk_optional)
        # Exact type comparison (not isinstance) is intentional: subclasses of
        # ControlNetAdvanced are rejected as well.
        if controlnet is None or type(controlnet) is not ControlNetAdvanced:
            raise ValueError(f"controlnet_path must point to a normal ControlNet, but instead: {type(controlnet).__name__}")
        # next, load sparsectrl with default settings to obtain its motion portion
        sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=SparseSettings.default())
        # now, combine state dicts; motion-wrapper entries overwrite any
        # ControlNet entries that share a key
        new_state_dict = controlnet.control_model.state_dict()
        new_state_dict.update(sparsectrl.control_model.motion_holder.motion_wrapper.state_dict())
        # now, reload sparsectrl from the merged weights with the real settings
        sparsectrl = load_sparsectrl(sparsectrl_path, controlnet_data=new_state_dict, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
        return (sparsectrl,)
class SparseIndexMethodNode:
    """Node that turns a comma-separated string of integers into a SparseIndexMethod."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: a single required "indexes" string."""
        return {
            "required": {
                "indexes": ("STRING", {"default": "0"}),
            }
        }

    RETURN_TYPES = ("SPARSE_METHOD",)
    FUNCTION = "get_method"

    # NOTE(review): the non-ASCII part of this label looks mis-encoded; it is a
    # runtime string, so it is kept as-is here - confirm the intended label.
    CATEGORY = "Adv-ControlNet πŸ›‚πŸ…πŸ…’πŸ…/SparseCtrl"

    def get_method(self, indexes: str):
        """Parse ``indexes`` and wrap the resulting list in a SparseIndexMethod.

        Args:
            indexes: Comma-separated integers, e.g. "0, 4, -1". Whitespace
                around each entry is ignored.

        Returns:
            A one-element tuple holding SparseIndexMethod built from the
            parsed indexes, in the order they were listed.

        Raises:
            ValueError: If the string is blank, an entry is not an integer,
                or an index is listed more than once.
        """
        # Blank input is reported explicitly; otherwise it would surface as
        # "'' is not a valid integer index."
        if not indexes.strip():
            raise ValueError("No indexes were listed in Sparse Index Method.")
        idxs = []
        unique_idxs = set()
        # get indices from string
        for str_idx in (part.strip() for part in indexes.strip().split(",")):
            # Only the int() conversion is guarded, so the duplicate error
            # raised below is not rewritten as an "invalid integer" error.
            try:
                idx = int(str_idx)
            except ValueError as err:
                raise ValueError(f"'{str_idx}' is not a valid integer index.") from err
            if idx in unique_idxs:
                raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
            unique_idxs.add(idx)
            idxs.append(idx)
        return (SparseIndexMethod(idxs),)
class SparseSpreadMethodNode:
    """Node that wraps the chosen spread option in a SparseSpreadMethod."""

    RETURN_TYPES = ("SPARSE_METHOD",)
    FUNCTION = "get_method"

    # NOTE(review): the non-ASCII part of this label looks mis-encoded; it is a
    # runtime string, so it is kept as-is here - confirm the intended label.
    CATEGORY = "Adv-ControlNet πŸ›‚πŸ…πŸ…’πŸ…/SparseCtrl"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required "spread" choice from SparseSpreadMethod.LIST."""
        required_inputs = {"spread": (SparseSpreadMethod.LIST,)}
        return {"required": required_inputs}

    def get_method(self, spread: str):
        """Return a one-element tuple holding SparseSpreadMethod(spread=spread)."""
        method = SparseSpreadMethod(spread=spread)
        return (method,)
class RgbSparseCtrlPreprocessor:
    """Node that resizes an image to match a latent, VAE-encodes it, and wraps the result.

    Although the declared return type is "IMAGE", the returned object is a
    PreprocSparseRGBWrapper holding the encoded result as ``condhint``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: required image, vae and latent_size inputs."""
        return {
            "required": {
                "image": ("IMAGE", ),
                "vae": ("VAE", ),
                "latent_size": ("LATENT", ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("proc_IMAGE",)
    FUNCTION = "preprocess_images"

    # NOTE(review): the non-ASCII part of this label looks mis-encoded; it is a
    # runtime string, so it is kept as-is here - confirm the intended label.
    CATEGORY = "Adv-ControlNet πŸ›‚πŸ…πŸ…’πŸ…/SparseCtrl/preprocess"

    def preprocess_images(self, vae: VAE, image: Tensor, latent_size: dict):
        """Resize ``image`` to the latent's pixel size and encode it with ``vae``.

        Args:
            vae: VAE used to crop and encode the resized image.
            image: Image tensor; its last axis is moved to position 1 for
                resizing and moved back afterwards.
            latent_size: Mapping with a "samples" tensor whose axes 2 and 3
                determine the target size.

        Returns:
            A one-element tuple holding PreprocSparseRGBWrapper(condhint=encoded).
        """
        latent_samples = latent_size["samples"]
        # assumes samples is (batch, channels, height, width) and that 8 is the
        # VAE's latent-to-pixel scale factor - TODO confirm
        target_width = latent_samples.shape[3] * 8
        target_height = latent_samples.shape[2] * 8
        # first, resize image to match latents
        image = image.movedim(-1, 1)
        image = comfy.utils.common_upscale(image, target_width, target_height, 'nearest-exact', "center")
        image = image.movedim(1, -1)
        # then, vae encode; if the VAE object's crop helper fails for any
        # reason, fall back to the one on VAEEncode
        # NOTE(review): presumably a compatibility fallback for versions where
        # the helper is only available on VAEEncode - confirm
        try:
            image = vae.vae_encode_crop_pixels(image)
        except Exception:
            image = VAEEncode.vae_encode_crop_pixels(image)
        # only the first three entries of the last axis are encoded
        encoded = vae.encode(image[:, :, :, :3])
        return (PreprocSparseRGBWrapper(condhint=encoded),)