# Attention sharing for layerdiffuse. Currently only SD1.5 (sd15) is supported.
import functools
import torch
import einops
from comfy import model_management, utils
from comfy.ldm.modules.attention import optimized_attention
module_mapping_sd15 = {
    0: "input_blocks.1.1.transformer_blocks.0.attn1",
    1: "input_blocks.1.1.transformer_blocks.0.attn2",
    2: "input_blocks.2.1.transformer_blocks.0.attn1",
    3: "input_blocks.2.1.transformer_blocks.0.attn2",
    4: "input_blocks.4.1.transformer_blocks.0.attn1",
    5: "input_blocks.4.1.transformer_blocks.0.attn2",
    6: "input_blocks.5.1.transformer_blocks.0.attn1",
    7: "input_blocks.5.1.transformer_blocks.0.attn2",
    8: "input_blocks.7.1.transformer_blocks.0.attn1",
    9: "input_blocks.7.1.transformer_blocks.0.attn2",
    10: "input_blocks.8.1.transformer_blocks.0.attn1",
    11: "input_blocks.8.1.transformer_blocks.0.attn2",
    12: "output_blocks.3.1.transformer_blocks.0.attn1",
    13: "output_blocks.3.1.transformer_blocks.0.attn2",
    14: "output_blocks.4.1.transformer_blocks.0.attn1",
    15: "output_blocks.4.1.transformer_blocks.0.attn2",
    16: "output_blocks.5.1.transformer_blocks.0.attn1",
    17: "output_blocks.5.1.transformer_blocks.0.attn2",
    18: "output_blocks.6.1.transformer_blocks.0.attn1",
    19: "output_blocks.6.1.transformer_blocks.0.attn2",
    20: "output_blocks.7.1.transformer_blocks.0.attn1",
    21: "output_blocks.7.1.transformer_blocks.0.attn2",
    22: "output_blocks.8.1.transformer_blocks.0.attn1",
    23: "output_blocks.8.1.transformer_blocks.0.attn2",
    24: "output_blocks.9.1.transformer_blocks.0.attn1",
    25: "output_blocks.9.1.transformer_blocks.0.attn2",
    26: "output_blocks.10.1.transformer_blocks.0.attn1",
    27: "output_blocks.10.1.transformer_blocks.0.attn2",
    28: "output_blocks.11.1.transformer_blocks.0.attn1",
    29: "output_blocks.11.1.transformer_blocks.0.attn2",
    30: "middle_block.1.transformer_blocks.0.attn1",
    31: "middle_block.1.transformer_blocks.0.attn2",
}
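
# Note: the 32 entries above name every attn1 (self-attention) and attn2
# (cross-attention) module inside the SD1.5 UNet's transformer blocks;
# AttentionSharingPatcher below swaps each of them for an AttentionSharingUnit.
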
def compute_cond_mark(cond_or_uncond, sigmas):
    # Expand the cond/uncond flags from transformer_options into one mark per
    # sample of the flattened model batch, on the same device/dtype as sigmas.
    cond_or_uncond_size = int(sigmas.shape[0])
    cond_mark = []
    for cx in cond_or_uncond:
        cond_mark += [cx] * cond_or_uncond_size
    cond_mark = torch.Tensor(cond_mark).to(sigmas)
    return cond_mark
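
# Illustration (mechanics only, not part of the original file): with
# cond_or_uncond=[0, 1] and sigmas of shape (2,), compute_cond_mark returns
# tensor([0., 0., 1., 1.]) -- the first two samples carry flag 0 and the last
# two carry flag 1.
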
class LoRALinearLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        # Keep the wrapped layer in a plain list so it is not registered as a
        # submodule of this LoRA layer.
        self.org = [org]

    def forward(self, h):
        org_weight = self.org[0].weight.to(h)
        org_bias = self.org[0].bias.to(h) if self.org[0].bias is not None else None
        down_weight = self.down.weight
        up_weight = self.up.weight
        # Apply the merged weight W + up @ down in a single linear call.
        final_weight = org_weight + torch.mm(up_weight, down_weight)
        return torch.nn.functional.linear(h, final_weight, org_bias)
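
# Hedged usage sketch (illustrative only; the 320-wide projection below is an
# assumption, not taken from this file): the layer adds a trainable rank-256
# update on top of the frozen original weight rather than replacing it.
#
#     base = torch.nn.Linear(320, 320)
#     lora = LoRALinearLayer(320, 320, rank=256, org=base)
#     y = lora(torch.randn(2, 4096, 320))
#     # equivalent to F.linear(x, base.weight + lora.up.weight @ lora.down.weight, base.bias)
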
class AttentionSharingUnit(torch.nn.Module):
    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call (captured by the hijack installed below).
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        super().__init__()
        self.heads = module.heads
        self.frames = frames
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )
        hidden_size = k_out_channels

        # One LoRA-wrapped projection per frame for q/k/v/out.
        self.to_q_lora = [
            LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
            for _ in range(self.frames)
        ]
        self.to_k_lora = [
            LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
            for _ in range(self.frames)
        ]
        self.to_v_lora = [
            LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
            for _ in range(self.frames)
        ]
        self.to_out_lora = [
            LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
            for _ in range(self.frames)
        ]
        self.to_q_lora = torch.nn.ModuleList(self.to_q_lora)
        self.to_k_lora = torch.nn.ModuleList(self.to_k_lora)
        self.to_v_lora = torch.nn.ModuleList(self.to_v_lora)
        self.to_out_lora = torch.nn.ModuleList(self.to_out_lora)

        # Shared temporal attention that mixes information across frames.
        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        # Optional per-frame encoders that project 256-channel control features
        # to this layer's hidden size.
        self.control_convs = None
        if use_control:
            self.control_convs = [
                torch.nn.Sequential(
                    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                )
                for _ in range(self.frames)
            ]
            self.control_convs = torch.nn.ModuleList(self.control_convs)
        self.control_signals = None
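
    # forward() below runs, per frame, the LoRA-adapted attention (with optional
    # control injection and cond_overwrite blending), then shared temporal
    # attention across frames, and returns only the delta relative to the input
    # so the surrounding BasicTransformerBlock can add it to its residual stream.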
    def forward(self, h, context=None, value=None):
        transformer_options = self.transformer_options

        # Split the flattened batch back into frames: (b*f, d, c) -> (f, b, d, c).
        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        # Inject the encoded control signal that matches this layer's token count.
        if self.control_convs is not None:
            context_dim = int(modified_hidden_states.shape[2])
            control_outs = []
            for f in range(self.frames):
                control_signal = self.control_signals[context_dim].to(
                    modified_hidden_states
                )
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        if context is None:
            framed_context = modified_hidden_states
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )

        framed_cond_mark = einops.rearrange(
            compute_cond_mark(
                transformer_options["cond_or_uncond"],
                transformer_options["sigmas"],
            ),
            "(b f) -> f b",
            f=self.frames,
        ).to(modified_hidden_states)

        # Per-frame LoRA attention, optionally overwriting the cross-attention
        # context with cond_overwrite where the cond mark is zero.
        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            if context is not None:
                cond_overwrite = transformer_options.get("cond_overwrite", [])
                if len(cond_overwrite) > f:
                    cond_overwrite = cond_overwrite[f]
                else:
                    cond_overwrite = None
                if cond_overwrite is not None:
                    cond_mark = framed_cond_mark[f][:, None, None]
                    fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark

            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Temporal attention: attend across frames at every spatial location.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)
        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)
        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        # Return only the delta; the caller adds it back to its hidden states.
        return modified_hidden_states - h

    @classmethod
    def hijack_transformer_block(cls):
        # Wrap BasicTransformerBlock.forward so that the transformer_options of
        # the most recent call are visible to every AttentionSharingUnit.
        def register_get_transformer_options(func):
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )

# Install the hijack once at import time so every AttentionSharingUnit can read
# the latest transformer_options.
AttentionSharingUnit.hijack_transformer_block()
class AdditionalAttentionCondsEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.blocks_0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 64*64*256
        self.blocks_1 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 32*32*256
        self.blocks_2 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 16*16*256
        self.blocks_3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 8*8*256
        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def __call__(self, h):
        # Run the blocks in sequence and key each intermediate feature map by
        # its flattened spatial size (h * w), which is how AttentionSharingUnit
        # looks up the signal matching its own token count.
        results = {}
        for b in self.blks:
            h = b(h)
            results[int(h.shape[2]) * int(h.shape[3])] = h
        return results
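
# Illustration (assumed input size, not part of the original file): for a
# 512x512 control image the encoder yields feature maps at 64x64, 32x32, 16x16
# and 8x8, so the dict returned above is keyed by 4096, 1024, 256 and 64 --
# matching the token counts of the SD1.5 attention layers at each resolution.
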
class HookerLayers(torch.nn.Module):
    def __init__(self, layer_list):
        super().__init__()
        self.layers = torch.nn.ModuleList(layer_list)
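
# Note: HookerLayers only exists to register the AttentionSharingUnit instances
# as submodules, so that calls such as .half() below reach all of them at once.
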
class AttentionSharingPatcher(torch.nn.Module):
    def __init__(self, unet, frames=2, use_control=True, rank=256):
        super().__init__()
        model_management.unload_model_clones(unet)

        # Replace each of the 32 SD1.5 attention modules with a sharing unit.
        units = []
        for i in range(32):
            real_key = module_mapping_sd15[i]
            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            unet.add_object_patch("diffusion_model." + real_key, u)

        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()
        return

    def set_control(self, img):
        # Expects an image batch in [0, 1]; rescale to [-1, 1], encode it, and
        # hand the multi-resolution signals to every patched attention layer.
        img = img.cpu().float() * 2.0 - 1.0
        signals = self.kwargs_encoder(img)

        for m in self.hookers.layers:
            m.control_signals = signals

        return
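

# Hedged usage sketch (not part of the original file). The patcher is intended
# to wrap a ComfyUI ModelPatcher that holds an SD1.5 UNet; `load_sd15_model`
# below is a hypothetical placeholder for however that object is obtained
# (e.g. from a checkpoint loader).
if __name__ == "__main__":
    def load_sd15_model():
        # Placeholder: return a comfy ModelPatcher wrapping an SD1.5 model.
        raise NotImplementedError

    unet = load_sd15_model()
    patcher = AttentionSharingPatcher(unet, frames=2, use_control=True, rank=256)

    # Optional control image: a (B, 3, H, W) tensor in [0, 1]; set_control
    # rescales it to [-1, 1] and distributes the encoded multi-resolution
    # features to every patched attention layer.
    # patcher.set_control(torch.zeros(1, 3, 512, 512))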