import torch
import torch.nn as nn
from ldm.modules.attention import default, zero_module, checkpoint
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from ldm.modules.diffusionmodules.util import timestep_embedding
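
# DepthAttention: per-pixel cross-attention along the depth axis of a 3D
# context volume. Queries come from a 2D feature map (1x1 Conv2d), keys and
# values from the 3D volume (1x1x1 Conv3d); at each spatial location the
# softmax runs over the d depth samples only.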
class DepthAttention(nn.Module):
def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Conv2d(query_dim, inner_dim, 1, 1, bias=False)
self.to_k = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
self.to_v = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
if output_bias:
self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1)
else:
self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1, bias=False)
def forward(self, x, context):
"""
@param x: b,f0,h,w
@param context: b,f1,d,h,w
@return:
"""
hn, hd = self.heads, self.dim_head
b, _, h, w = x.shape
b, _, d, h, w = context.shape
        q = self.to_q(x).reshape(b,hn,hd,h,w)          # b,hn,hd,h,w
        k = self.to_k(context).reshape(b,hn,hd,d,h,w)  # b,hn,hd,d,h,w
        v = self.to_v(context).reshape(b,hn,hd,d,h,w)  # b,hn,hd,d,h,w
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
attn = sim.softmax(dim=2)
# b,hn,hd,d,h,w * b,hn,1,d,h,w
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
out = out.reshape(b,hn*hd,h,w)
return self.to_out(out)
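
# DepthTransformer: wraps DepthAttention in a residual block. The feature map
# and the 3D context volume are first projected (the context branch is
# bias-free, and the final output conv is zero-initialized via zero_module),
# so the block starts out as an identity mapping and the conditioning signal
# is blended in gradually during training.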
class DepthTransformer(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
super().__init__()
inner_dim = n_heads * d_head
self.proj_in = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1),
nn.GroupNorm(8, inner_dim),
nn.SiLU(True),
)
self.proj_context = nn.Sequential(
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
nn.GroupNorm(8, context_dim),
            nn.ReLU(True), # ReLU only, so that a zero context input still maps to a zero output
)
        self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # cross-attention from the 2D features to the 3D context volume
self.proj_out = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context):
x_in = x
x = self.proj_in(x)
context = self.proj_context(context)
x = self.depth_attn(x, context)
x = self.proj_out(x) + x_in
return x
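
# DepthWiseAttention: a UNetModel whose middle block and selected decoder
# (output) blocks are followed by DepthTransformer conditioning layers.
# At each of these points the matching 3D source volume is looked up in
# source_dict by the current spatial resolution (h.shape[-1]); output_b2c
# maps a decoder block index to the corresponding DepthTransformer in
# self.output_conditions.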
class DepthWiseAttention(UNetModel):
def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs):
super().__init__(*args, **kwargs)
# num_heads = 4
model_channels = kwargs['model_channels']
channel_mult = kwargs['channel_mult']
d0,d1,d2,d3 = volume_dims
        # spatial resolution 4 (middle block)
ch = model_channels*channel_mult[2]
self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3)
self.output_conditions=nn.ModuleList()
        self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8} # output (decoder) block index -> index into self.output_conditions
        # spatial resolution 8
ch = model_channels*channel_mult[2]
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
        # spatial resolution 16
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
ch = model_channels*channel_mult[1]
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
        # spatial resolution 32
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
ch = model_channels*channel_mult[0]
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for index, module in enumerate(self.input_blocks):
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
h = self.middle_conditions(h, context=source_dict[h.shape[-1]])
for index, module in enumerate(self.output_blocks):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
if index in self.output_b2c:
layer = self.output_conditions[self.output_b2c[index]]
h = layer(h, context=source_dict[h.shape[-1]])
h = h.type(x.dtype)
return self.out(h)
def get_trainable_parameters(self):
        paras = list(self.middle_conditions.parameters()) + list(self.output_conditions.parameters())
return paras
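
# Minimal shape sanity check for DepthTransformer (illustrative sketch only:
# the batch size, channel counts, depth and resolution below are arbitrary
# assumptions, not values used by the actual model).
if __name__ == "__main__":
    b, c, ctx_c, d, h, w = 2, 64, 16, 12, 8, 8
    block = DepthTransformer(dim=c, n_heads=4, d_head=8, context_dim=ctx_c, checkpoint=False)
    x = torch.randn(b, c, h, w)              # 2D query features
    volume = torch.randn(b, ctx_c, d, h, w)  # 3D context volume
    out = block(x, context=volume)
    print(out.shape)  # torch.Size([2, 64, 8, 8]) -- same shape as x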