Yash Nagraj commited on
Commit
aee1300
·
1 Parent(s): 1c92417

Add forward function with attention

Browse files
Files changed (1) hide show
  1. models/blocks.py +21 -0
models/blocks.py CHANGED
@@ -32,6 +32,7 @@ class DownBlock(nn.Module):
32
  self.context_dim = context_dim
33
  self.cross_attn = cross_attn
34
  self.t_emb_dim = t_emd_dim
 
35
  self.attn = attn
36
  self.resnet_conv_first = nn.ModuleList([
37
  nn.Sequential(
@@ -95,3 +96,23 @@ class DownBlock(nn.Module):
95
  )
96
 
97
  self.resnet_down_conv = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  self.context_dim = context_dim
33
  self.cross_attn = cross_attn
34
  self.t_emb_dim = t_emd_dim
35
+ self.num_layers = num_layers
36
  self.attn = attn
37
  self.resnet_conv_first = nn.ModuleList([
38
  nn.Sequential(
 
96
  )
97
 
98
  self.resnet_down_conv = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
99
+
100
+ def forward(self, x, t_emb=None, context=None):
101
+ out = x
102
+ for i in range(self.num_layers):
103
+ # Resnet Block
104
+ resnet_input = out
105
+ out = self.resnet_conv_first[i](out)
106
+ if self.t_emb is not None:
107
+ out = out + self.time_embd_layers[i](t_emb)[:, :, None, None]
108
+ out = self.resnet_conv_second[i](out)
109
+ out = out + self.residual_input_conv[i](resnet_input)
110
+
111
+ # Self Attention
112
+ batch_size, channels, h, w = out.shape
113
+ in_attn = out.reshape(batch_size, channels, h*w)
114
+ in_attn = self.attention_norms[i](in_attn)
115
+ in_attn = in_attn.transpose(1, 2)
116
+ out_attn, _ = self.attention[i](in_attn, in_attn, in_attn)
117
+ out_attn = out.transpose(1, 2).reshape(batch_size, channels, h, w)
118
+ out = out + out_attn