daquanzhou
merge github repos and lfs track ckpt/path/safetensors/pt
613c9ab
raw
history blame
28.7 kB
import copy
from typing import Union
from einops import rearrange
from torch import Tensor
import torch.nn.functional as F
import torch
import comfy.model_management
import comfy.utils
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from .ad_settings import AnimateDiffSettings
from .context import ContextOptions, ContextOptions, ContextOptionsGroup
from .motion_module_ad import AnimateDiffModel, AnimateDiffFormat, has_mid_block, normalize_ad_state_dict
from .logger import logger
from .utils_motion import ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, get_combined_multival, normalize_min_max
from .motion_lora import MotionLoraInfo, MotionLoraList
from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type
from .sample_settings import SampleSettings, SeedNoiseGeneration
# some motion_model casts here might fail if model becomes metatensor or is not castable;
# should not really matter if it fails, so ignore raised Exceptions
class ModelPatcherAndInjector(ModelPatcher):
    """ModelPatcher that additionally injects a group of motion models into the wrapped model.

    Motion models are injected after regular weight patching and ejected before regular
    unpatching. Moving motion models between devices is best-effort: any exception raised
    by ``.to()`` (e.g. for meta tensors) is ignored.
    """

    def __init__(self, m: ModelPatcher):
        # same copying scheme as ModelPatcher.clone(), but producing this subclass
        super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device,
                         weight_inplace_update=m.weight_inplace_update)
        self.patches = {key: patch_list[:] for key, patch_list in m.patches.items()}
        self.object_patches = m.object_patches.copy()
        self.model_options = copy.deepcopy(m.model_options)
        self.model_keys = m.model_keys
        # injection-specific state
        self.motion_injection_params: InjectionParams = None
        self.sample_settings: SampleSettings = SampleSettings()
        self.motion_models: MotionModelGroup = None

    @staticmethod
    def _move_motion_model_quietly(motion_model, device):
        # failures are deliberately ignored; the cast may not be possible (e.g. meta tensors)
        try:
            motion_model.model.to(device)
        except Exception:
            pass

    def model_patches_to(self, device):
        super().model_patches_to(device)
        if self.motion_models is None:
            return
        for motion_model in self.motion_models.models:
            self._move_motion_model_quietly(motion_model, device)

    def patch_model(self, device_to=None):
        # regular weight patching happens first; motion model injection comes last
        patched_model = super().patch_model(device_to)
        self.inject_model(device_to=device_to)
        return patched_model

    def unpatch_model(self, device_to=None):
        # reverse order of patch_model: eject motion models, then regular unpatching
        self.eject_model(device_to=device_to)
        return super().unpatch_model(device_to)

    def inject_model(self, device_to=None):
        """Inject every motion model into this patcher, then try to move it to ``device_to``."""
        if self.motion_models is None:
            return
        for motion_model in self.motion_models.models:
            motion_model.model.inject(self)
            self._move_motion_model_quietly(motion_model, device_to)

    def eject_model(self, device_to=None):
        """Eject every motion model from this patcher, then try to move it to ``device_to``."""
        if self.motion_models is None:
            return
        for motion_model in self.motion_models.models:
            motion_model.model.eject(self)
            self._move_motion_model_quietly(motion_model, device_to)

    def clone(self):
        """Clone this patcher; motion models and injection params are cloned, sample settings shared."""
        cloned = ModelPatcherAndInjector(self)
        if self.motion_models:
            cloned.motion_models = self.motion_models.clone()
        else:
            cloned.motion_models = self.motion_models
        cloned.sample_settings = self.sample_settings
        if self.motion_injection_params:
            cloned.motion_injection_params = self.motion_injection_params.clone()
        else:
            cloned.motion_injection_params = self.motion_injection_params
        return cloned
class MotionModelPatcher(ModelPatcher):
    """ModelPatcher around an AnimateDiffModel that also tracks keyframe state during sampling.

    Mostly here so that type hints work in IDEs. Additionally holds:
    - the active timestep range (as percents, converted to sigmas by initialize_timesteps),
    - the ADKeyframeGroup and the model-level scale/effect multivals,
    - per-run state used by prepare_current_keyframe(), reset by cleanup().
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model: AnimateDiffModel = self.model
        self.timestep_percent_range = (0.0, 1.0)
        self.timestep_range: tuple[float, float] = None
        self.keyframes: ADKeyframeGroup = ADKeyframeGroup()
        self.scale_multival = None
        self.effect_multival = None
        # temporary variables (reset by cleanup())
        self.current_used_steps = 0
        self.current_keyframe: ADKeyframe = None
        self.current_index = -1
        self.current_scale: Union[float, Tensor] = None
        self.current_effect: Union[float, Tensor] = None
        self.combined_scale: Union[float, Tensor] = None
        self.combined_effect: Union[float, Tensor] = None
        self.was_within_range = False

    def patch_model(self, *args, **kwargs):
        # patch as normal, but prepare_weights so that lowvram meta device works properly
        patched_model = super().patch_model(*args, **kwargs)
        self.prepare_weights()
        return patched_model

    def prepare_weights(self):
        """Resolve lowvram weights and write them back onto the model (best-effort)."""
        # in case lowvram is active and meta device is used, need to convert weights
        # otherwise, will get exceptions thrown related to meta device
        # TODO: with new comfy lowvram system, this is unnecessary
        state_dict = self.model.state_dict()
        for key in state_dict:
            weight = comfy.model_management.resolve_lowvram_weight(state_dict[key], self.model, key)
            try:
                comfy.utils.set_attr(self.model, key, weight)
            except Exception:
                pass

    def pre_run(self, model: ModelPatcherAndInjector):
        """Reset per-run state and apply the model-level scale/effect before sampling starts."""
        self.cleanup()
        self.model.reset()
        # just in case, prepare_weights before every run
        self.prepare_weights()
        self.model.set_scale(self.scale_multival)
        self.model.set_effect(self.effect_multival)
        # Seed the combined values with the model-level multivals. cleanup() sets them to
        # None, and prepare_current_keyframe() only recomputes them when the keyframe index
        # changes; without this, if no keyframe is ever selected, entering the timestep
        # range would call set_effect(None) and discard effect_multival applied above.
        self.combined_scale = self.scale_multival
        self.combined_effect = self.effect_multival

    def initialize_timesteps(self, model: BaseModel):
        """Convert the percent timestep range and keyframe start percents to sigmas via *model*."""
        self.timestep_range = (model.model_sampling.percent_to_sigma(self.timestep_percent_range[0]),
                               model.model_sampling.percent_to_sigma(self.timestep_percent_range[1]))
        if self.keyframes is not None:
            for keyframe in self.keyframes.keyframes:
                keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, t: Tensor):
        """Select the active keyframe for this step and apply the resulting scale/effect.

        ``t`` holds sigmas (bigger number = earlier step in sampling); only ``t[0]`` is used.
        When ``t[0]`` lies outside ``timestep_range`` the model's effect is set to 0.0,
        effectively turning the motion model off for that step.
        """
        curr_t: float = t[0]
        prev_index = self.current_index
        # if met guaranteed steps, look for next keyframe in case need to switch
        if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.keyframes is not None and self.keyframes.has_index(self.current_index+1):
                for i in range(self.current_index+1, len(self.keyframes)):
                    eval_kf = self.keyframes[i]
                    # check if start_t is greater or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_kf.start_t >= curr_t:
                        self.current_index = i
                        self.current_keyframe = eval_kf
                        self.current_used_steps = 0
                        # keep track of scale and effect multivals, accounting for inherit_missing
                        if self.current_keyframe.has_scale():
                            self.current_scale = self.current_keyframe.scale_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_scale = None
                        if self.current_keyframe.has_effect():
                            self.current_effect = self.current_keyframe.effect_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_effect = None
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self.current_keyframe.guarantee_steps > 0:
                            break
                    # if eval_kf is outside the percent range, stop looking further
                    else:
                        break
        # if index changed, apply new combined values
        if prev_index != self.current_index:
            # combine model's scale and effect with keyframe's scale and effect
            self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale)
            self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect)
            # apply scale and effect
            self.model.set_scale(self.combined_scale)
            self.model.set_effect(self.combined_effect)
        # apply effect - if not within range, set effect to 0, effectively turning model off
        if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]:
            self.model.set_effect(0.0)
            self.was_within_range = False
        else:
            # if was not in range last step, apply effect to toggle AD status
            if not self.was_within_range:
                self.model.set_effect(self.combined_effect)
                self.was_within_range = True
        # update steps current keyframe is used
        self.current_used_steps += 1

    def cleanup(self):
        """Clean up the wrapped model and reset all per-run keyframe state."""
        if self.model is not None:
            self.model.cleanup()
        self.current_used_steps = 0
        self.current_keyframe = None
        self.current_index = -1
        self.current_scale = None
        self.current_effect = None
        self.combined_scale = None
        self.combined_effect = None
        self.was_within_range = False

    def clone(self):
        """Clone the patcher (sharing the underlying model) along with its motion settings."""
        # normal ModelPatcher clone actions
        n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
                               weight_inplace_update=self.weight_inplace_update)
        n.patches = {k: patch_list[:] for k, patch_list in self.patches.items()}
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        # extra cloned params
        n.timestep_percent_range = self.timestep_percent_range
        n.timestep_range = self.timestep_range
        # keyframes may be None (initialize_timesteps already allows for it)
        n.keyframes = self.keyframes.clone() if self.keyframes is not None else None
        n.scale_multival = self.scale_multival
        n.effect_multival = self.effect_multival
        return n
class MotionModelGroup:
    """Ordered collection of MotionModelPatchers; per-run calls are forwarded to each model in order."""
    def __init__(self, init_motion_model: MotionModelPatcher=None):
        self.models: list[MotionModelPatcher] = []
        if init_motion_model is not None:
            self.add(init_motion_model)

    def add(self, mm: MotionModelPatcher):
        # add to end of list
        self.models.append(mm)

    def add_to_start(self, mm: MotionModelPatcher):
        self.models.insert(0, mm)

    def __getitem__(self, index) -> MotionModelPatcher:
        return self.models[index]

    def is_empty(self) -> bool:
        return len(self.models) == 0

    def clone(self) -> 'MotionModelGroup':
        """Return a new group holding the same MotionModelPatcher objects (shallow copy)."""
        cloned = MotionModelGroup()
        for mm in self.models:
            cloned.add(mm)
        return cloned

    def set_sub_idxs(self, sub_idxs: list[int]):
        for motion_model in self.models:
            motion_model.model.set_sub_idxs(sub_idxs=sub_idxs)

    def set_view_options(self, view_options: ContextOptions):
        for motion_model in self.models:
            motion_model.model.set_view_options(view_options)

    def set_video_length(self, video_length: int, full_length: int):
        for motion_model in self.models:
            motion_model.model.set_video_length(video_length=video_length, full_length=full_length)

    def initialize_timesteps(self, model: BaseModel):
        for motion_model in self.models:
            motion_model.initialize_timesteps(model)

    def pre_run(self, model: ModelPatcherAndInjector):
        for motion_model in self.models:
            motion_model.pre_run(model)

    def prepare_current_keyframe(self, t: Tensor):
        for motion_model in self.models:
            motion_model.prepare_current_keyframe(t=t)

    def get_name_string(self, show_version=False):
        """Return the motion model names joined by ", ", each optionally suffixed with ":<version>"."""
        identifiers = []
        for motion_model in self.models:
            mm_info = motion_model.model.mm_info
            # named 'identifier' (not 'id') to avoid shadowing the builtin
            identifier = mm_info.mm_name
            if show_version:
                identifier += f":{mm_info.mm_version}"
            identifiers.append(identifier)
        return ", ".join(identifiers)
def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
    """Return a plain ModelPatcher copy of *m*, without any subclass-specific attributes.

    The underlying model is shared. Each patch list is shallow-copied, object_patches is
    shallow-copied, and model_options is deep-copied.
    """
    vanilla = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, m.current_device,
                           weight_inplace_update=m.weight_inplace_update)
    vanilla.patches = {key: patch_list[:] for key, patch_list in m.patches.items()}
    vanilla.object_patches = m.object_patches.copy()
    vanilla.model_options = copy.deepcopy(m.model_options)
    vanilla.model_keys = m.model_keys
    return vanilla
# adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py
# Example LoRA keys:
# down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.down.weight
# down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.up.weight
#
# Example model keys:
# down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight
#
def load_motion_lora_as_patches(motion_model: MotionModelPatcher, lora: MotionLoraInfo) -> None:
    """Load a motion LoRA file and add it to *motion_model* as weight patches.

    Keys without "temporal" are dropped. For every LoRA down weight, the matching up
    weight is looked up and ``up @ down`` is stored under the corresponding motion-module
    key. The patches are added with ``lora.strength``. Mid-block entries are skipped when
    the motion model has no mid block.

    Raises:
        ValueError: if the file has no temporal keys, or a down weight has no matching
            up weight.
    """
    def get_version(has_midblock: bool):
        return "v2" if has_midblock else "v1"

    lora_path = get_motion_lora_path(lora.name)
    logger.info(f"Loading motion LoRA {lora.name}")
    state_dict = comfy.utils.load_torch_file(lora_path)
    # remove all non-temporal keys (in case model has extra stuff in it)
    for key in list(state_dict.keys()):
        if "temporal" not in key:
            del state_dict[key]
    if len(state_dict) == 0:
        raise ValueError(f"'{lora.name}' contains no temporal keys; it is not a valid motion LoRA!")

    model_has_midblock = motion_model.model.mid_block is not None
    lora_has_midblock = has_mid_block(state_dict)
    logger.info(f"Applying a {get_version(lora_has_midblock)} LoRA ({lora.name}) to a { motion_model.model.mm_info.mm_version} motion model.")

    patches = {}
    # convert lora state dict to one that matches motion_module keys and tensors
    for key in state_dict:
        # if motion_module doesn't have a midblock, skip mid_block entries
        if not model_has_midblock and "mid_block" in key:
            continue
        # only process lora down key (we will process up at the same time as down)
        if "up." in key:
            continue
        # get up key version of down key
        up_key = key.replace(".down.", ".up.")
        if up_key not in state_dict:
            raise ValueError(f"'{lora.name}' is missing LoRA up weight '{up_key}' for '{key}'; it is not a valid motion LoRA!")
        # adapt key to match motion_module key format - remove 'processor.', '_lora', 'down.', and 'up.'
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        # motion_module keys have a '0.' after all 'to_out.' weight keys
        model_key = model_key.replace("to_out.", "to_out.0.")
        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        # actual weights obtained by matrix multiplication of up and down weights
        # save as a tuple, so that (Motion)ModelPatcher's calculate_weight function detects len==1, applying it correctly
        patches[model_key] = (torch.mm(weight_up, weight_down),)
    del state_dict
    # add patches to motion ModelPatcher
    motion_model.add_patches(patches=patches, strength_patch=lora.strength)
def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: MotionLoraList = None, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    """Load a motion module the Gen1 way, tied to the SD *model* it will be used with.

    The function runs these steps in order:
    1. Load the checkpoint and normalize it.
    2. Verify it targets the same SD type as *model*.
    3. Apply *motion_model_settings*.
    4. Build the AnimateDiffModel in *model*'s dtype on *model*'s offload device.
    5. Wrap it in a MotionModelPatcher.
    6. Add each LoRA in *motion_lora* as patches.

    Raises:
        MotionCompatibilityError: if the motion module's SD type differs from *model*'s.
    """
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name}")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    # TODO: check for empty state dict?
    # normalize keys and extract motion model info
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)

    # the motion module must target the same SD type as the model it will be injected into
    model_sd_type = get_sd_model_type(model)
    if model_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(
            f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, "
            f"but the provided model is type {model_sd_type}."
        )

    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)

    ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    ad_wrapper.to(model.model_dtype())
    ad_wrapper.to(model.offload_device)
    # AnimateLCM checkpoints are loaded non-strictly
    # TODO: report load_result of motion_module loading?
    strict_load = mm_info.mm_format != AnimateDiffFormat.ANIMATELCM
    load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=strict_load)

    # wrap in a ModelPatcher so motion LoRAs can be applied as weight patches
    motion_model = MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device)
    loras = motion_lora.loras if motion_lora is not None else []
    for lora in loras:
        load_motion_lora_as_patches(motion_model, lora)
    return motion_model
def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    """Load a motion module the Gen2 way, independent of any particular SD model.

    No SD-type compatibility check and no motion LoRA loading happen here. The module is
    cast to ComfyUI's unet dtype, placed on the unet offload device, and wrapped in a
    MotionModelPatcher whose load device is the current torch device.
    """
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name} via Gen2")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    # TODO: check for empty state dict?
    # normalize keys (converts alternate AD models like HotshotXL into AD keys) and extract model info
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)
    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)

    ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    ad_wrapper.to(comfy.model_management.unet_dtype())
    ad_wrapper.to(comfy.model_management.unet_offload_device())
    # AnimateLCM checkpoints are loaded non-strictly
    # TODO: manually check load_results for AnimateLCM models
    # TODO: report load_result of motion_module loading?
    strict_load = mm_info.mm_format != AnimateDiffFormat.ANIMATELCM
    load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=strict_load)

    # wrap in a ModelPatcher so motion LoRAs can be applied as weight patches later
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()
    return MotionModelPatcher(model=ad_wrapper, load_device=load_device, offload_device=offload_device)
def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelPatcher:
    """Return a new MotionModelPatcher around a fresh AnimateDiffModel holding *motion_model*'s weights.

    Only the wrapped model's state dict and mm_info are reused. Patches, keyframes and
    multivals set on *motion_model* are not carried over. The new model is cast to
    ComfyUI's unet dtype and placed on the unet offload device.
    """
    source_model = motion_model.model
    fresh_model = AnimateDiffModel(mm_state_dict=source_model.state_dict(), mm_info=source_model.mm_info)
    fresh_model.to(comfy.model_management.unet_dtype())
    fresh_model.to(comfy.model_management.unet_offload_device())
    fresh_model.load_state_dict(source_model.state_dict())
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()
    return MotionModelPatcher(model=fresh_model, load_device=load_device, offload_device=offload_device)
def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionModelPatcher):
    """Raise MotionCompatibilityError if *motion_model* targets a different SD type than *model*."""
    model_sd_type = get_sd_model_type(model)
    mm_info = motion_model.model.mm_info
    if model_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(
            f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, "
            f"but the provided model is type {model_sd_type}."
        )
def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Resize the 3D PE at ``model_dict[key]`` to ``new_length`` along dim 1 (bilinear).

    The last dimension keeps its size. ``model_dict[key]`` is replaced by the result.
    """
    pe = model_dict[key]
    # add a leading size-1 axis so F.interpolate receives the 4D input bilinear mode needs;
    # the resize then acts on (dim 1, dim 2) of the original tensor
    resized = F.interpolate(pe.unsqueeze(0), size=(new_length, pe.shape[-1]), mode="bilinear")
    model_dict[key] = resized.squeeze(0)
def interpolate_pe_to_length_diffs(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Placeholder for a difference-based PE resize.

    Currently it behaves exactly like interpolate_pe_to_length.
    """
    # TODO: fill out and try out
    interpolate_pe_to_length(model_dict, key, new_length=new_length)
def interpolate_pe_to_length_pingpong(model_dict: dict[str, Tensor], key: str, new_length: int):
    """Extend (or trim) the PE at ``model_dict[key]`` to ``new_length`` along dim 1 by ping-ponging.

    The PE is extended by alternately appending two pieces:
    - its interior positions in reverse order (first and last excluded, so endpoints
      are not repeated),
    - the original PE.

    The result is then truncated to ``new_length`` along dim 1. A PE already at least
    ``new_length`` long is only truncated. ``model_dict[key]`` is replaced by the result.

    Raises:
        ValueError: if the PE has length 0 along dim 1 but must be extended; appending
            empty pieces could never reach ``new_length``.
    """
    pe = model_dict[key]
    current_length = pe.shape[1]
    if current_length < new_length:
        if current_length == 0:
            raise ValueError(f"Cannot ping-pong extend empty PE '{key}' to length {new_length}.")
        # interior positions reversed, e.g. [0,1,2,3,4] -> [3,2,1]
        flipped_pe = torch.flip(pe[:, 1:-1, :], [1])
        pieces = [pe]
        use_flipped = True
        # collect pieces and concatenate once instead of re-concatenating a growing tensor
        while current_length < new_length:
            piece = flipped_pe if use_flipped else pe
            pieces.append(piece)
            current_length += piece.shape[1]
            use_flipped = not use_flipped
        pe = torch.cat(pieces, dim=1)
    model_dict[key] = pe[:, :new_length]
def freeze_mask_of_pe(model_dict: dict[str, Tensor], key: str):
    """Overwrite trailing dim-2 channels at every dim-1 index with those at index 0, in place.

    Channels from ``shape[2] // 64`` onward are replaced; the leading channels are left
    untouched. ``model_dict[key]`` is modified in place (not replaced).
    """
    pe = model_dict[key]
    pe_portion = pe.shape[2] // 64
    # clone the source row: a plain slice would be a view aliasing the region being written
    first_pe = pe[:, :1, :].clone()
    pe[:, :, pe_portion:] = first_pe[:, :, pe_portion:]
def freeze_mask_of_attn(model_dict: dict[str, Tensor], key: str):
    """Multiply the top-left square of ``model_dict[key]`` by 1.5, in place.

    The square spans the first ``shape[0] // 2`` rows and columns.
    """
    weight = model_dict[key]
    half = weight.shape[0] // 2
    weight[:half, :half].mul_(1.5)
def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: AnimateDiffSettings) -> dict[str, Tensor]:
    """Apply AnimateDiffSettings adjustments to a motion-module state dict.

    The function works in two passes.

    1. Each entry of ``mm_settings.adjust_pe.adjusts`` that has something to apply is
       processed in order, for every key containing both "attention_blocks" and
       "pos_encoder". Within each adjust, the steps run in this order:
       motion PE stretch, initial index offset, initial cap, interpolation to a fixed
       length, final index offset.
    2. Strength multipliers (pe/attn/q/k/v/out/other) are applied in place.

    Args:
        model_dict: motion-module state dict. Entries are replaced (slicing /
            interpolation) or multiplied in place.
        mm_settings: settings to apply. If ``None``, or if there is nothing to apply,
            ``model_dict`` is returned untouched.

    Returns:
        The same ``model_dict`` object.
    """
    if mm_settings is None:
        return model_dict
    if not mm_settings.has_anything_to_apply():
        return model_dict
    # first, handle PE Adjustments
    for adjust in mm_settings.adjust_pe.adjusts:
        if adjust.has_anything_to_apply():
            # each adjustment is logged only for the first PE key, not once per PE key
            already_printed = False
            for key in model_dict:
                if "attention_blocks" in key and "pos_encoder" in key:
                    # apply simple motion pe stretch, if needed
                    if adjust.has_motion_pe_stretch():
                        original_length = model_dict[key].shape[1]
                        new_pe_length = original_length + adjust.motion_pe_stretch
                        interpolate_pe_to_length(model_dict, key, new_length=new_pe_length)
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: PE Stretch from {original_length} to {new_pe_length}.")
                    # apply pe_idx_offset, if needed
                    if adjust.has_initial_pe_idx_offset():
                        original_length = model_dict[key].shape[1]
                        # drop the first initial_pe_idx_offset positions along dim 1
                        model_dict[key] = model_dict[key][:, adjust.initial_pe_idx_offset:]
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Offsetting PEs by {adjust.initial_pe_idx_offset}; PE length to shortens from {original_length} to {model_dict[key].shape[1]}.")
                    # apply has_cap_initial_pe_length, if needed
                    if adjust.has_cap_initial_pe_length():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, :adjust.cap_initial_pe_length]
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Capping PEs (initial) from {original_length} to {model_dict[key].shape[1]}.")
                    # apply interpolate_pe_to_length, if needed
                    if adjust.has_interpolate_pe_to_length():
                        original_length = model_dict[key].shape[1]
                        interpolate_pe_to_length(model_dict, key, new_length=adjust.interpolate_pe_to_length)
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Interpolating PE length from {original_length} to {model_dict[key].shape[1]}.")
                    # apply final_pe_idx_offset, if needed
                    if adjust.has_final_pe_idx_offset():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, adjust.final_pe_idx_offset:]
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Capping PEs (final) from {original_length} to {model_dict[key].shape[1]}.")
                    already_printed = True
    # finally, apply any weight changes
    for key in model_dict:
        if "attention_blocks" in key:
            # NOTE(review): pe_strength is only applied when adjust_pe also has something to
            # apply. Otherwise pos_encoder keys fall into the else branch below, where
            # attn_strength (if set) scales them and pe_strength is never applied.
            # Confirm this gating is intended.
            if "pos_encoder" in key and mm_settings.adjust_pe.has_anything_to_apply():
                # apply pe_strength, if needed
                if mm_settings.has_pe_strength():
                    model_dict[key] *= mm_settings.pe_strength
            else:
                # apply attn_strenth, if needed
                if mm_settings.has_attn_strength():
                    model_dict[key] *= mm_settings.attn_strength
                # apply specific attn_strengths, if needed
                if mm_settings.has_any_attn_sub_strength():
                    if "to_q" in key and mm_settings.has_attn_q_strength():
                        model_dict[key] *= mm_settings.attn_q_strength
                    elif "to_k" in key and mm_settings.has_attn_k_strength():
                        model_dict[key] *= mm_settings.attn_k_strength
                    elif "to_v" in key and mm_settings.has_attn_v_strength():
                        model_dict[key] *= mm_settings.attn_v_strength
                    elif "to_out" in key:
                        if key.strip().endswith("weight") and mm_settings.has_attn_out_weight_strength():
                            model_dict[key] *= mm_settings.attn_out_weight_strength
                        elif key.strip().endswith("bias") and mm_settings.has_attn_out_bias_strength():
                            model_dict[key] *= mm_settings.attn_out_bias_strength
        # apply other strength, if needed (every key outside attention_blocks)
        elif mm_settings.has_other_strength():
            model_dict[key] *= mm_settings.other_strength
    return model_dict
class InjectionParams:
    """Settings controlling how motion models are injected for a sampling run.

    ``sub_idxs`` is deliberately not carried over by ``clone()``, so a clone starts with
    it reset to None.
    """
    def __init__(self, unlimited_area_hack: bool=False, apply_mm_groupnorm_hack: bool=True, model_name: str="",
                 apply_v2_properly: bool=True) -> None:
        self.full_length = None
        self.unlimited_area_hack = unlimited_area_hack
        self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack
        self.model_name = model_name
        self.apply_v2_properly = apply_v2_properly
        self.context_options: ContextOptionsGroup = ContextOptionsGroup.default()
        self.motion_model_settings = AnimateDiffSettings()  # Gen1
        # excluded from clone() on purpose, so it resets automatically
        self.sub_idxs = None

    def set_noise_extra_args(self, noise_extra_args: dict):
        """Put a clone of the current context options into *noise_extra_args* under "context_options"."""
        noise_extra_args["context_options"] = self.context_options.clone()

    def set_context(self, context_options: ContextOptionsGroup):
        """Store a clone of *context_options*, or the default group if it is falsy."""
        if context_options:
            self.context_options = context_options.clone()
        else:
            self.context_options = ContextOptionsGroup.default()

    def is_using_sliding_context(self) -> bool:
        return self.context_options.context_length is not None

    def set_motion_model_settings(self, motion_model_settings: AnimateDiffSettings):  # Gen1
        """Store *motion_model_settings*, falling back to a default AnimateDiffSettings when None."""
        self.motion_model_settings = AnimateDiffSettings() if motion_model_settings is None else motion_model_settings

    def reset_context(self):
        self.context_options = ContextOptionsGroup.default()

    def clone(self) -> 'InjectionParams':
        """Copy these params; context options are cloned, motion model settings are shared."""
        new_params = InjectionParams(
            unlimited_area_hack=self.unlimited_area_hack,
            apply_mm_groupnorm_hack=self.apply_mm_groupnorm_hack,
            model_name=self.model_name,
            apply_v2_properly=self.apply_v2_properly,
        )
        new_params.full_length = self.full_length
        new_params.set_context(self.context_options)
        new_params.set_motion_model_settings(self.motion_model_settings)  # Gen1
        return new_params