import copy
from typing import Union

from einops import rearrange
from torch import Tensor
import torch.nn.functional as F
import torch

import comfy.model_management
import comfy.utils
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel

from .ad_settings import AnimateDiffSettings
from .context import ContextOptions, ContextOptionsGroup
from .motion_module_ad import AnimateDiffModel, AnimateDiffFormat, has_mid_block, normalize_ad_state_dict
from .logger import logger
from .utils_motion import ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, get_combined_multival, normalize_min_max
from .motion_lora import MotionLoraInfo, MotionLoraList
from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type
from .sample_settings import SampleSettings, SeedNoiseGeneration

# some motion_model casts here might fail if model becomes metatensor or is not castable;
# should not really matter if it fails, so ignore raised Exceptions
class ModelPatcherAndInjector(ModelPatcher):
    def __init__(self, m: ModelPatcher):
        # replicate ModelPatcher.clone() to initialize ModelPatcherAndInjector
        super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
        self.patches = {}
        for k in m.patches:
            self.patches[k] = m.patches[k][:]
        self.object_patches = m.object_patches.copy()
        self.model_options = copy.deepcopy(m.model_options)
        self.model_keys = m.model_keys
        # injection stuff
        self.motion_injection_params: InjectionParams = None
        self.sample_settings: SampleSettings = SampleSettings()
        self.motion_models: MotionModelGroup = None

    def model_patches_to(self, device):
        super().model_patches_to(device)
        if self.motion_models is not None:
            for motion_model in self.motion_models.models:
                try:
                    motion_model.model.to(device)
                except Exception:
                    pass

    def patch_model(self, device_to=None):
        # first, perform normal model patching
        patched_model = super().patch_model(device_to)
        # finally, perform motion model injection
        self.inject_model(device_to=device_to)
        return patched_model

    def unpatch_model(self, device_to=None):
        # first, eject motion model from unet
        self.eject_model(device_to=device_to)
        # finally, do normal model unpatching
        return super().unpatch_model(device_to)

    def inject_model(self, device_to=None):
        if self.motion_models is not None:
            for motion_model in self.motion_models.models:
                motion_model.model.inject(self)
                try:
                    motion_model.model.to(device_to)
                except Exception:
                    pass

    def eject_model(self, device_to=None):
        if self.motion_models is not None:
            for motion_model in self.motion_models.models:
                motion_model.model.eject(self)
                try:
                    motion_model.model.to(device_to)
                except Exception:
                    pass

    def clone(self):
        cloned = ModelPatcherAndInjector(self)
        cloned.motion_models = self.motion_models.clone() if self.motion_models else self.motion_models
        cloned.sample_settings = self.sample_settings
        cloned.motion_injection_params = self.motion_injection_params.clone() if self.motion_injection_params else self.motion_injection_params
        return cloned
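
# Hedged usage sketch (not part of the node code; variable names are placeholders): a caller
# wraps an existing ModelPatcher and attaches motion-related state before sampling, e.g.:
#
#   mpi = ModelPatcherAndInjector(model)  # model: ModelPatcher from a checkpoint loader
#   mpi.motion_models = MotionModelGroup(loaded_motion_model)
#   mpi.motion_injection_params = InjectionParams(model_name=loaded_motion_model.model.mm_info.mm_name)
#
# Sampling then goes through mpi.patch_model()/unpatch_model(), which also inject/eject the
# motion modules into/from the wrapped UNet.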

class MotionModelPatcher(ModelPatcher):
    # Mostly here so that type hints work in IDEs
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model: AnimateDiffModel = self.model
        self.timestep_percent_range = (0.0, 1.0)
        self.timestep_range: tuple[float, float] = None
        self.keyframes: ADKeyframeGroup = ADKeyframeGroup()
        self.scale_multival = None
        self.effect_multival = None
        # temporary variables
        self.current_used_steps = 0
        self.current_keyframe: ADKeyframe = None
        self.current_index = -1
        self.current_scale: Union[float, Tensor] = None
        self.current_effect: Union[float, Tensor] = None
        self.combined_scale: Union[float, Tensor] = None
        self.combined_effect: Union[float, Tensor] = None
        self.was_within_range = False

    def patch_model(self, *args, **kwargs):
        # patch as normal, but prepare_weights so that lowvram meta device works properly
        patched_model = super().patch_model(*args, **kwargs)
        self.prepare_weights()
        return patched_model

    def prepare_weights(self):
        # in case lowvram is active and meta device is used, need to convert weights;
        # otherwise, exceptions related to the meta device will be thrown
        # TODO: with new comfy lowvram system, this is unnecessary
        state_dict = self.model.state_dict()
        for key in state_dict:
            weight = comfy.model_management.resolve_lowvram_weight(state_dict[key], self.model, key)
            try:
                comfy.utils.set_attr(self.model, key, weight)
            except Exception:
                pass

    def pre_run(self, model: ModelPatcherAndInjector):
        self.cleanup()
        self.model.reset()
        # just in case, prepare_weights before every run
        self.prepare_weights()
        self.model.set_scale(self.scale_multival)
        self.model.set_effect(self.effect_multival)

    def initialize_timesteps(self, model: BaseModel):
        self.timestep_range = (model.model_sampling.percent_to_sigma(self.timestep_percent_range[0]),
                               model.model_sampling.percent_to_sigma(self.timestep_percent_range[1]))
        if self.keyframes is not None:
            for keyframe in self.keyframes.keyframes:
                keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, t: Tensor):
        curr_t: float = t[0]
        prev_index = self.current_index
        # if guaranteed steps have been met, look for the next keyframe in case a switch is needed
        if self.current_keyframe is None or self.current_used_steps >= self.current_keyframe.guarantee_steps:
            # if there is a next index, loop through and see if a switch is needed
            if self.keyframes.has_index(self.current_index+1):
                for i in range(self.current_index+1, len(self.keyframes)):
                    eval_kf = self.keyframes[i]
                    # check if start_t is greater than or equal to curr_t
                    # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                    if eval_kf.start_t >= curr_t:
                        self.current_index = i
                        self.current_keyframe = eval_kf
                        self.current_used_steps = 0
                        # keep track of scale and effect multivals, accounting for inherit_missing
                        if self.current_keyframe.has_scale():
                            self.current_scale = self.current_keyframe.scale_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_scale = None
                        if self.current_keyframe.has_effect():
                            self.current_effect = self.current_keyframe.effect_multival
                        elif not self.current_keyframe.inherit_missing:
                            self.current_effect = None
                        # if guarantee_steps is greater than zero, stop searching for other keyframes
                        if self.current_keyframe.guarantee_steps > 0:
                            break
                    # if eval_kf is not scheduled to start yet (its start_t is a later step in sampling), stop looking further
                    else:
                        break
        # if the index changed, apply new combined values
        if prev_index != self.current_index:
            # combine model's scale and effect with keyframe's scale and effect
            self.combined_scale = get_combined_multival(self.scale_multival, self.current_scale)
            self.combined_effect = get_combined_multival(self.effect_multival, self.current_effect)
            # apply scale and effect
            self.model.set_scale(self.combined_scale)
            self.model.set_effect(self.combined_effect)
        # apply effect - if not within the timestep range, set effect to 0, effectively turning the model off
        if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]:
            self.model.set_effect(0.0)
            self.was_within_range = False
        else:
            # if the previous step was not in range, reapply effect to toggle AD back on
            if not self.was_within_range:
                self.model.set_effect(self.combined_effect)
                self.was_within_range = True
        # track how many steps the current keyframe has been used
        self.current_used_steps += 1

    def cleanup(self):
        if self.model is not None:
            self.model.cleanup()
        self.current_used_steps = 0
        self.current_keyframe = None
        self.current_index = -1
        self.current_scale = None
        self.current_effect = None
        self.combined_scale = None
        self.combined_effect = None
        self.was_within_range = False

    def clone(self):
        # normal ModelPatcher clone actions
        n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        # extra cloned params
        n.timestep_percent_range = self.timestep_percent_range
        n.timestep_range = self.timestep_range
        n.keyframes = self.keyframes.clone()
        n.scale_multival = self.scale_multival
        n.effect_multival = self.effect_multival
        return n
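
# Note on the timestep window (a sketch of the logic above, assuming standard ComfyUI
# model_sampling behavior): percent_to_sigma maps percent 0.0 to the largest sigma (first
# sampling step) and percent 1.0 to the smallest sigma (last step), so after
# initialize_timesteps:
#
#   timestep_range = (sigma_at_start_percent, sigma_at_end_percent)  # decreasing values
#
# and the check `curr_t > timestep_range[0] or curr_t < timestep_range[1]` in
# prepare_current_keyframe is true exactly when the current sigma falls outside that window,
# which is when the motion model effect is zeroed.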

class MotionModelGroup:
    def __init__(self, init_motion_model: MotionModelPatcher=None):
        self.models: list[MotionModelPatcher] = []
        if init_motion_model is not None:
            self.add(init_motion_model)

    def add(self, mm: MotionModelPatcher):
        # add to end of list
        self.models.append(mm)

    def add_to_start(self, mm: MotionModelPatcher):
        self.models.insert(0, mm)

    def __getitem__(self, index) -> MotionModelPatcher:
        return self.models[index]

    def is_empty(self) -> bool:
        return len(self.models) == 0

    def clone(self) -> 'MotionModelGroup':
        cloned = MotionModelGroup()
        for mm in self.models:
            cloned.add(mm)
        return cloned

    def set_sub_idxs(self, sub_idxs: list[int]):
        for motion_model in self.models:
            motion_model.model.set_sub_idxs(sub_idxs=sub_idxs)

    def set_view_options(self, view_options: ContextOptions):
        for motion_model in self.models:
            motion_model.model.set_view_options(view_options)

    def set_video_length(self, video_length: int, full_length: int):
        for motion_model in self.models:
            motion_model.model.set_video_length(video_length=video_length, full_length=full_length)

    def initialize_timesteps(self, model: BaseModel):
        for motion_model in self.models:
            motion_model.initialize_timesteps(model)

    def pre_run(self, model: ModelPatcherAndInjector):
        for motion_model in self.models:
            motion_model.pre_run(model)

    def prepare_current_keyframe(self, t: Tensor):
        for motion_model in self.models:
            motion_model.prepare_current_keyframe(t=t)

    def get_name_string(self, show_version=False):
        identifiers = []
        for motion_model in self.models:
            id = motion_model.model.mm_info.mm_name
            if show_version:
                id += f":{motion_model.model.mm_info.mm_version}"
            identifiers.append(id)
        return ", ".join(identifiers)

def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
    model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
    model.patches = {}
    for k in m.patches:
        model.patches[k] = m.patches[k][:]
    model.object_patches = m.object_patches.copy()
    model.model_options = copy.deepcopy(m.model_options)
    model.model_keys = m.model_keys
    return model

# adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py
# Example LoRA keys:
#   down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.down.weight
#   down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.processor.to_q_lora.up.weight
#
# Example model keys:
#   down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight
#
def load_motion_lora_as_patches(motion_model: MotionModelPatcher, lora: MotionLoraInfo) -> None:
    def get_version(has_midblock: bool):
        return "v2" if has_midblock else "v1"

    lora_path = get_motion_lora_path(lora.name)
    logger.info(f"Loading motion LoRA {lora.name}")
    state_dict = comfy.utils.load_torch_file(lora_path)

    # remove all non-temporal keys (in case model has extra stuff in it)
    for key in list(state_dict.keys()):
        if "temporal" not in key:
            del state_dict[key]
    if len(state_dict) == 0:
        raise ValueError(f"'{lora.name}' contains no temporal keys; it is not a valid motion LoRA!")

    model_has_midblock = motion_model.model.mid_block is not None
    lora_has_midblock = has_mid_block(state_dict)
    logger.info(f"Applying a {get_version(lora_has_midblock)} LoRA ({lora.name}) to a {motion_model.model.mm_info.mm_version} motion model.")

    patches = {}
    # convert lora state dict to one that matches motion_module keys and tensors
    for key in state_dict:
        # if motion_module doesn't have a mid_block, skip mid_block entries
        if not model_has_midblock:
            if "mid_block" in key:
                continue
        # only process lora down keys (up keys are processed at the same time as their down keys)
        if "up." in key:
            continue
        # get up key version of down key
        up_key = key.replace(".down.", ".up.")
        # adapt key to match motion_module key format - remove 'processor.', '_lora', 'down.', and 'up.'
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        # motion_module keys have a '0.' after all 'to_out.' weight keys
        model_key = model_key.replace("to_out.", "to_out.0.")

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        # actual weights obtained by matrix multiplication of up and down weights
        # save as a tuple, so that (Motion)ModelPatcher's calculate_weight function detects len==1, applying it correctly
        patches[model_key] = (torch.mm(weight_up, weight_down),)
    del state_dict

    # add patches to motion ModelPatcher
    motion_model.add_patches(patches=patches, strength_patch=lora.strength)
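
# Worked example of the key mapping above, using the example keys from the comment block
# (illustrative only):
#   down key:  ...attention_blocks.0.processor.to_q_lora.down.weight
#   up key:    ...attention_blocks.0.processor.to_q_lora.up.weight
#   model key: ...attention_blocks.0.to_q.weight
# The patch tensor is the low-rank product: with weight_up of shape (out_dim, rank) and
# weight_down of shape (rank, in_dim), torch.mm(weight_up, weight_down) has shape
# (out_dim, in_dim), matching the target attention weight it is added to (scaled by
# lora.strength) during ModelPatcher weight calculation.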

def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: MotionLoraList = None, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name}")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    # TODO: check for empty state dict?
    # get normalized state_dict and motion model info
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)
    # check that motion model is compatible with sd model
    model_sd_type = get_sd_model_type(model)
    if model_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, "
                                       f"but the provided model is type {model_sd_type}.")
    # apply motion model settings
    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
    # initialize AnimateDiffModelWrapper
    ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    ad_wrapper.to(model.model_dtype())
    ad_wrapper.to(model.offload_device)
    is_animatelcm = mm_info.mm_format == AnimateDiffFormat.ANIMATELCM
    load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=not is_animatelcm)
    # TODO: report load_result of motion_module loading?
    # wrap motion_module into a ModelPatcher, to allow motion lora patches
    motion_model = MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device)
    # load motion_lora, if present
    if motion_lora is not None:
        for lora in motion_lora.loras:
            load_motion_lora_as_patches(motion_model, lora)
    return motion_model
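
# Hedged sketch of the Gen1 path (placeholder names): the SD model is known at load time, so
# compatibility is checked here and any motion LoRAs are folded in immediately, e.g.:
#
#   motion_model = load_motion_module_gen1("mm_name.ckpt", sd_model_patcher, motion_lora=my_loras)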

def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffSettings = None) -> MotionModelPatcher:
    model_path = get_motion_model_path(model_name)
    logger.info(f"Loading motion module {model_name} via Gen2")
    mm_state_dict = comfy.utils.load_torch_file(model_path, safe_load=True)
    # TODO: check for empty state dict?
    # get normalized state_dict and motion model info (converts alternate AD models like HotshotXL into AD keys)
    mm_state_dict, mm_info = normalize_ad_state_dict(mm_state_dict=mm_state_dict, mm_name=model_name)
    # apply motion model settings
    mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
    # initialize AnimateDiffModelWrapper
    ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
    ad_wrapper.to(comfy.model_management.unet_dtype())
    ad_wrapper.to(comfy.model_management.unet_offload_device())
    is_animatelcm = mm_info.mm_format == AnimateDiffFormat.ANIMATELCM
    load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=not is_animatelcm)
    # TODO: manually check load_result for AnimateLCM models
    if is_animatelcm:
        pass
    # TODO: report load_result of motion_module loading?
    # wrap motion_module into a ModelPatcher, to allow motion lora patches
    motion_model = MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
                                      offload_device=comfy.model_management.unet_offload_device())
    return motion_model

def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelPatcher:
    ad_wrapper = AnimateDiffModel(mm_state_dict=motion_model.model.state_dict(), mm_info=motion_model.model.mm_info)
    ad_wrapper.to(comfy.model_management.unet_dtype())
    ad_wrapper.to(comfy.model_management.unet_offload_device())
    ad_wrapper.load_state_dict(motion_model.model.state_dict())
    return MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
                              offload_device=comfy.model_management.unet_offload_device())

def validate_model_compatibility_gen2(model: ModelPatcher, motion_model: MotionModelPatcher):
    # check that motion model is compatible with sd model
    model_sd_type = get_sd_model_type(model)
    mm_info = motion_model.model.mm_info
    if model_sd_type != mm_info.sd_type:
        raise MotionCompatibilityError(f"Motion module '{mm_info.mm_name}' is intended for {mm_info.sd_type} models, "
                                       f"but the provided model is type {model_sd_type}.")

def interpolate_pe_to_length(model_dict: dict[str, Tensor], key: str, new_length: int):
    pe_shape = model_dict[key].shape
    temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
    temp_pe = F.interpolate(temp_pe, size=(new_length, pe_shape[-1]), mode="bilinear")
    temp_pe = rearrange(temp_pe, "t b f d -> (t b) f d", t=1)
    model_dict[key] = temp_pe
    del temp_pe
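
# Shape sketch for interpolate_pe_to_length (assuming the usual (1, length, dim) positional
# encoding layout): rearrange lifts the PE to (1, 1, length, dim) so bilinear F.interpolate can
# resize the trailing (length, dim) plane to (new_length, dim); the second rearrange drops the
# extra axis, leaving (1, new_length, dim).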

def interpolate_pe_to_length_diffs(model_dict: dict[str, Tensor], key: str, new_length: int):
    # TODO: fill out and try out
    pe_shape = model_dict[key].shape
    temp_pe = rearrange(model_dict[key], "(t b) f d -> t b f d", t=1)
    temp_pe = F.interpolate(temp_pe, size=(new_length, pe_shape[-1]), mode="bilinear")
    temp_pe = rearrange(temp_pe, "t b f d -> (t b) f d", t=1)
    model_dict[key] = temp_pe
    del temp_pe

def interpolate_pe_to_length_pingpong(model_dict: dict[str, Tensor], key: str, new_length: int):
    if model_dict[key].shape[1] < new_length:
        temp_pe = model_dict[key]
        flipped_temp_pe = torch.flip(temp_pe[:, 1:-1, :], [1])
        use_flipped = True
        preview_pe = None
        while model_dict[key].shape[1] < new_length:
            preview_pe = model_dict[key]
            model_dict[key] = torch.cat([model_dict[key], flipped_temp_pe if use_flipped else temp_pe], dim=1)
            use_flipped = not use_flipped
        del temp_pe
        del flipped_temp_pe
        del preview_pe
    model_dict[key] = model_dict[key][:, :new_length]
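
# Example of the ping-pong extension above, shown with frame indices only: a PE of length 5
# [0, 1, 2, 3, 4] first gets its interior reversed ([3, 2, 1]) so the endpoints are not
# duplicated, then the reversed and original chunks are appended alternately,
#   [0, 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 3, 4, ...]
# until new_length is reached, and the result is truncated to exactly new_length.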

def freeze_mask_of_pe(model_dict: dict[str, Tensor], key: str):
    pe_portion = model_dict[key].shape[2] // 64
    first_pe = model_dict[key][:, :1, :]
    model_dict[key][:, :, pe_portion:] = first_pe[:, :, pe_portion:]
    del first_pe


def freeze_mask_of_attn(model_dict: dict[str, Tensor], key: str):
    attn_portion = model_dict[key].shape[0] // 2
    model_dict[key][:attn_portion, :attn_portion] *= 1.5

def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: AnimateDiffSettings) -> dict[str, Tensor]:
    if mm_settings is None:
        return model_dict
    if not mm_settings.has_anything_to_apply():
        return model_dict
    # first, handle PE Adjustments
    for adjust in mm_settings.adjust_pe.adjusts:
        if adjust.has_anything_to_apply():
            already_printed = False
            for key in model_dict:
                if "attention_blocks" in key and "pos_encoder" in key:
                    # apply simple motion pe stretch, if needed
                    if adjust.has_motion_pe_stretch():
                        original_length = model_dict[key].shape[1]
                        new_pe_length = original_length + adjust.motion_pe_stretch
                        interpolate_pe_to_length(model_dict, key, new_length=new_pe_length)
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: PE Stretch from {original_length} to {new_pe_length}.")
                    # apply initial_pe_idx_offset, if needed
                    if adjust.has_initial_pe_idx_offset():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, adjust.initial_pe_idx_offset:]
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Offsetting PEs by {adjust.initial_pe_idx_offset}; PE length shortens from {original_length} to {model_dict[key].shape[1]}.")
                    # apply cap_initial_pe_length, if needed
                    if adjust.has_cap_initial_pe_length():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, :adjust.cap_initial_pe_length]
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Capping PEs (initial) from {original_length} to {model_dict[key].shape[1]}.")
                    # apply interpolate_pe_to_length, if needed
                    if adjust.has_interpolate_pe_to_length():
                        original_length = model_dict[key].shape[1]
                        interpolate_pe_to_length(model_dict, key, new_length=adjust.interpolate_pe_to_length)
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Interpolating PE length from {original_length} to {model_dict[key].shape[1]}.")
                    # apply final_pe_idx_offset, if needed
                    if adjust.has_final_pe_idx_offset():
                        original_length = model_dict[key].shape[1]
                        model_dict[key] = model_dict[key][:, adjust.final_pe_idx_offset:]
                        if adjust.print_adjustment and not already_printed:
                            logger.info(f"[Adjust PE]: Offsetting PEs (final) by {adjust.final_pe_idx_offset}; PE length shortens from {original_length} to {model_dict[key].shape[1]}.")
                    already_printed = True
    # finally, apply any weight changes
    for key in model_dict:
        if "attention_blocks" in key:
            if "pos_encoder" in key and mm_settings.adjust_pe.has_anything_to_apply():
                # apply pe_strength, if needed
                if mm_settings.has_pe_strength():
                    model_dict[key] *= mm_settings.pe_strength
            else:
                # apply attn_strength, if needed
                if mm_settings.has_attn_strength():
                    model_dict[key] *= mm_settings.attn_strength
                # apply specific attn_strengths, if needed
                if mm_settings.has_any_attn_sub_strength():
                    if "to_q" in key and mm_settings.has_attn_q_strength():
                        model_dict[key] *= mm_settings.attn_q_strength
                    elif "to_k" in key and mm_settings.has_attn_k_strength():
                        model_dict[key] *= mm_settings.attn_k_strength
                    elif "to_v" in key and mm_settings.has_attn_v_strength():
                        model_dict[key] *= mm_settings.attn_v_strength
                    elif "to_out" in key:
                        if key.strip().endswith("weight") and mm_settings.has_attn_out_weight_strength():
                            model_dict[key] *= mm_settings.attn_out_weight_strength
                        elif key.strip().endswith("bias") and mm_settings.has_attn_out_bias_strength():
                            model_dict[key] *= mm_settings.attn_out_bias_strength
        # apply other strength, if needed
        elif mm_settings.has_other_strength():
            model_dict[key] *= mm_settings.other_strength
    return model_dict
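
# Hedged usage sketch (placeholder names): apply_mm_settings operates on the raw state dict
# before AnimateDiffModel is constructed, exactly as the Gen1/Gen2 loaders above do:
#
#   mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=my_ad_settings)
#
# PE adjustments (stretch/offset/cap/interpolate) run first on pos_encoder keys, then scalar
# strength multipliers are applied to PE, attention, and remaining weights.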

class InjectionParams:
    def __init__(self, unlimited_area_hack: bool=False, apply_mm_groupnorm_hack: bool=True, model_name: str="",
                 apply_v2_properly: bool=True) -> None:
        self.full_length = None
        self.unlimited_area_hack = unlimited_area_hack
        self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack
        self.model_name = model_name
        self.apply_v2_properly = apply_v2_properly
        self.context_options: ContextOptionsGroup = ContextOptionsGroup.default()
        self.motion_model_settings = AnimateDiffSettings()  # Gen1
        self.sub_idxs = None  # value should NOT be included in clone, so it will auto reset

    def set_noise_extra_args(self, noise_extra_args: dict):
        noise_extra_args["context_options"] = self.context_options.clone()

    def set_context(self, context_options: ContextOptionsGroup):
        self.context_options = context_options.clone() if context_options else ContextOptionsGroup.default()

    def is_using_sliding_context(self) -> bool:
        return self.context_options.context_length is not None

    def set_motion_model_settings(self, motion_model_settings: AnimateDiffSettings):  # Gen1
        if motion_model_settings is None:
            self.motion_model_settings = AnimateDiffSettings()
        else:
            self.motion_model_settings = motion_model_settings

    def reset_context(self):
        self.context_options = ContextOptionsGroup.default()

    def clone(self) -> 'InjectionParams':
        new_params = InjectionParams(
            self.unlimited_area_hack, self.apply_mm_groupnorm_hack,
            self.model_name, apply_v2_properly=self.apply_v2_properly,
        )
        new_params.full_length = self.full_length
        new_params.set_context(self.context_options)
        new_params.set_motion_model_settings(self.motion_model_settings)  # Gen1
        return new_params
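
# Hedged usage sketch (placeholder names): InjectionParams carries per-sampling configuration
# and is attached to the ModelPatcherAndInjector, e.g.:
#
#   params = InjectionParams(model_name=motion_model.model.mm_info.mm_name)
#   params.set_context(my_context_options_group)
#   mpi.motion_injection_params = params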