Yash Nagraj committed on
Commit ·
aee1300
1
Parent(s): 1c92417
Add forward function with attention
Browse files- models/blocks.py +21 -0
models/blocks.py
CHANGED
|
@@ -32,6 +32,7 @@ class DownBlock(nn.Module):
|
|
| 32 |
self.context_dim = context_dim
|
| 33 |
self.cross_attn = cross_attn
|
| 34 |
self.t_emb_dim = t_emd_dim
|
|
|
|
| 35 |
self.attn = attn
|
| 36 |
self.resnet_conv_first = nn.ModuleList([
|
| 37 |
nn.Sequential(
|
|
@@ -95,3 +96,23 @@ class DownBlock(nn.Module):
|
|
| 95 |
)
|
| 96 |
|
| 97 |
self.resnet_down_conv = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
self.context_dim = context_dim
|
| 33 |
self.cross_attn = cross_attn
|
| 34 |
self.t_emb_dim = t_emd_dim
|
| 35 |
+
self.num_layers = num_layers
|
| 36 |
self.attn = attn
|
| 37 |
self.resnet_conv_first = nn.ModuleList([
|
| 38 |
nn.Sequential(
|
|
|
|
| 96 |
)
|
| 97 |
|
| 98 |
self.resnet_down_conv = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
|
| 99 |
+
|
| 100 |
+
def forward(self, x, t_emb=None, context=None):
    """Run the DownBlock: per-layer ResNet conv + self-attention, then downsample.

    Args:
        x: input feature map of shape (batch, channels, h, w) — established by
            the 4-way unpack of ``out.shape`` and the spatial reshape below.
        t_emb: optional time-step embedding; projected per layer and added
            channel-wise (broadcast over h and w).
        context: conditioning input — accepted but unused in the committed code
            (presumably intended for the ``cross_attn`` path; TODO confirm).

    Returns:
        The processed feature map, downsampled by ``resnet_down_conv`` when the
        block was built with ``down_sample`` enabled.
    """
    out = x
    for layer in range(self.num_layers):
        # --- ResNet block ---
        residual = out
        out = self.resnet_conv_first[layer](out)
        # BUG FIX: original tested `self.t_emb`, an attribute that is never
        # assigned (__init__ stores `t_emb_dim`), so this raised
        # AttributeError — the function argument `t_emb` is what's optional.
        if t_emb is not None:
            out = out + self.time_embd_layers[layer](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[layer](out)
        out = out + self.residual_input_conv[layer](residual)

        # --- Self-attention over spatial positions ---
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norms[layer](in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attention[layer](in_attn, in_attn, in_attn)
        # BUG FIX: original wrote `out.transpose(1, 2).reshape(...)`, which
        # discards the attention result (and `out` is 4-D here, so the
        # reshape was shape-incoherent). Transpose/reshape `out_attn` back
        # to (batch, channels, h, w) instead.
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

    # NOTE(review): the committed diff is truncated here — it neither applied
    # `self.resnet_down_conv` (defined in __init__ for exactly this purpose)
    # nor returned, so forward() yielded None. Reconstructed below; confirm
    # against the follow-up commit.
    out = self.resnet_down_conv(out)
    return out