Spaces:
Running
Running
VideoModelStudio
/
docs
/finetrainers-src-codebase
/finetrainers
/patches
/dependencies
/diffusers
/control.py
from contextlib import contextmanager | |
from typing import List, Union | |
import torch | |
from diffusers.hooks import HookRegistry, ModelHook | |
_CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK" | |
class ControlChannelConcatenateHook(ModelHook): | |
def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]): | |
self.input_names = input_names | |
self.inputs = inputs | |
self.dims = dims | |
def pre_forward(self, module: torch.nn.Module, *args, **kwargs): | |
for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims): | |
original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name] | |
control_tensor = torch.cat([original_tensor, input_tensor], dim=dim) | |
if isinstance(input_name, int): | |
args[input_name] = control_tensor | |
else: | |
kwargs[input_name] = control_tensor | |
return args, kwargs | |
def control_channel_concat( | |
module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int] | |
): | |
registry = HookRegistry.check_if_exists_or_initialize(module) | |
hook = ControlChannelConcatenateHook(input_names, inputs, dims) | |
registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK) | |
yield | |
registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False) | |