from collections.abc import Iterable
from typing import Union

import torch
from torch import Tensor

from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size

class ScaleType:
    ABSOLUTE = "absolute"
    RELATIVE = "relative"
    LIST = [ABSOLUTE, RELATIVE]

class MultivalDynamicNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optional: Tensor=None):
        # first, normalize inputs
        # if float_val is iterable, treat as a list and assume inputs are floats
        float_is_iterable = False
        if isinstance(float_val, Iterable):
            float_is_iterable = True
            float_val = list(float_val)
            # if mask present, make sure float_val list can be applied to mask - match lengths
            if mask_optional is not None:
                if len(float_val) < mask_optional.shape[0]:
                    # repeat last entry enough times to match mask batch size
                    float_val = float_val + [float_val[-1]] * (mask_optional.shape[0] - len(float_val))
                if mask_optional.shape[0] < len(float_val):
                    mask_optional = extend_to_batch_size(mask_optional, len(float_val))
                float_val = float_val[:mask_optional.shape[0]]
            float_val: Tensor = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1)
        # now that inputs are normalized, figure out what value to actually return
        if mask_optional is not None:
            mask_optional = mask_optional.clone()
            if float_is_iterable:
                mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)
            else:
                mask_optional = mask_optional * float_val
            return (mask_optional,)
        else:
            if not float_is_iterable:
                return (float_val,)
            # create a dummy mask of b,h,w = float_len,1,1 (single pixel);
            # purpose is for float input to work with mask code, without special cases
            float_len = float_val.shape[0] if float_is_iterable else 1
            shape = (float_len, 1, 1)
            mask_optional = torch.ones(shape)
            mask_optional = mask_optional[:] * float_val.to(mask_optional.dtype).to(mask_optional.device)
            return (mask_optional,)
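
# Illustrative sketch (hypothetical helper, not part of the original node set): shows how
# MultivalDynamicNode.create_multival is expected to combine a list of floats with a mask
# batch, assuming masks are shaped (batch, height, width). It exists only to document the
# intended broadcast behavior and is never called by the nodes themselves.
def _example_multival_broadcast():
    node = MultivalDynamicNode()
    mask = torch.ones((4, 8, 8))  # batch of 4 masks, all ones
    (multival,) = node.create_multival(float_val=[0.25, 0.5], mask_optional=mask)
    # the float list is padded with its last entry to length 4, then broadcast-multiplied:
    # multival.shape == (4, 8, 8); frame 0 is scaled by 0.25, frames 1-3 by 0.5
    return multival
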

class MultivalScaledMaskNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "min_float_val": ("FLOAT", {"default": 0.0, "min": 0.0, "step": 0.001}),
                "max_float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
                "mask": ("MASK",),
            },
            "optional": {
                "scaling": (ScaleType.LIST,),
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, min_float_val: float, max_float_val: float, mask: Tensor, scaling: str=ScaleType.ABSOLUTE):
        # TODO: allow min_float_val and max_float_val to be list[float]
        if isinstance(min_float_val, Iterable):
            raise ValueError(f"min_float_val must be type float (no lists allowed here), not {type(min_float_val).__name__}.")
        if isinstance(max_float_val, Iterable):
            raise ValueError(f"max_float_val must be type float (no lists allowed here), not {type(max_float_val).__name__}.")
        if scaling == ScaleType.ABSOLUTE:
            mask = linear_conversion(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        elif scaling == ScaleType.RELATIVE:
            mask = normalize_min_max(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        else:
            raise ValueError(f"scaling '{scaling}' not recognized.")
        return MultivalDynamicNode.create_multival(self, mask_optional=mask)
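
# Illustrative sketch (hypothetical helper, not an original node): contrasts the two
# scaling modes, assuming linear_conversion remaps the nominal 0..1 mask range to
# [min, max] while normalize_min_max remaps the mask's own min/max range instead.
def _example_scaled_mask_modes():
    node = MultivalScaledMaskNode()
    mask = torch.tensor([[[0.2, 0.6]]])  # single 1x2 mask with values inside 0..1
    (absolute,) = node.create_multival(0.0, 1.0, mask, scaling=ScaleType.ABSOLUTE)
    (relative,) = node.create_multival(0.0, 1.0, mask, scaling=ScaleType.RELATIVE)
    # with min=0.0, max=1.0: ABSOLUTE leaves 0.2/0.6 unchanged, while RELATIVE
    # stretches the mask's own range so 0.2 -> 0.0 and 0.6 -> 1.0
    return absolute, relative
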

class MultivalDynamicFloatInputNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "forceInput": True},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=None, mask_optional: Tensor=None):
        return MultivalDynamicNode.create_multival(self, float_val=float_val, mask_optional=mask_optional)

class MultivalFloatNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
            },
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=None):
        return MultivalDynamicNode.create_multival(self, float_val=float_val)