tolgacangoz committed · verified
Commit e61dfa2 · 1 Parent(s): 889cc98

Upload matryoshka.py

Files changed (1): matryoshka.py  +22 -23
matryoshka.py CHANGED
@@ -1519,8 +1519,7 @@ class MatryoshkaTransformerBlock(nn.Module):
 
             # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
             attn_output_cond = self.proj_out(attn_output_cond)
-            # attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
-            attn_output_cond = attn_output_cond.transpose(-1, -2).reshape(batch_size, channels, *spatial_dims)
+            attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
             hidden_states = hidden_states + attn_output_cond
 
         if self.ff is not None:
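Note: this hunk swaps `transpose(-1, -2)` for `permute(0, 2, 1)`. For the 3-D layout the surrounding code implies (`[batch, height * width, channels]`), the two calls swap the same pair of axes and give identical results. A minimal sketch, not taken from matryoshka.py and with invented sizes:

```python
import torch

# Invented sizes; attn_output_cond is assumed to be [batch, height * width, channels]
# before being reshaped back to [batch, channels, height, width].
batch_size, channels, height, width = 2, 8, 4, 4
spatial_dims = (height, width)
attn_output_cond = torch.randn(batch_size, height * width, channels)

via_transpose = attn_output_cond.transpose(-1, -2).reshape(batch_size, channels, *spatial_dims)
via_permute = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)

# For a 3-D tensor, transpose(-1, -2) and permute(0, 2, 1) are the same axis swap.
assert torch.equal(via_transpose, via_permute)
```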
@@ -1612,7 +1611,7 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             hidden_states = attn.group_norm(hidden_states)  # .transpose(1, 2)).transpose(1, 2)
 
         # Reshape hidden_states to 2D tensor
-        hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1)#.contiguous()
+        hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous()
         # Now hidden_states.shape is [batch_size, height * width, channels]
 
         if encoder_hidden_states is None:
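The only change in this hunk is appending `.contiguous()` after the `permute`. A minimal sketch, not taken from matryoshka.py, of what that call does; the variable names mirror the diff but the sizes are invented:

```python
import torch

# Invented sizes; names mirror the diff (batch_size, channel, height, width).
batch_size, channel, height, width = 2, 8, 4, 4
hidden_states = torch.randn(batch_size, channel, height, width)

flat = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1)
print(flat.shape)            # torch.Size([2, 16, 8]) -> [batch_size, height * width, channel]
print(flat.is_contiguous())  # False: permute only rewrites strides, it does not move data

flat = flat.contiguous()     # the added call: copy into a row-major buffer
print(flat.is_contiguous())  # True; later .view(...) calls on this tensor will not error
```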
@@ -1636,11 +1635,30 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
         # key = key.permute(0, 2, 1)
         # value = value.permute(0, 2, 1)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        if self_attention_output is None:
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
         # hidden_states = self.attention(
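The added lines follow the usual split-heads, `F.scaled_dot_product_attention`, merge-heads pattern (the merge itself is added in the next hunk). A minimal, self-contained sketch of that pattern with invented sizes, standing in for the `attn` object and the mask from the file:

```python
import torch
import torch.nn.functional as F

# Invented sizes; num_heads and head_dim stand in for attn.heads and the computed head_dim.
batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 8
inner_dim = num_heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# [batch, seq_len, inner_dim] -> [batch, heads, seq_len, head_dim]
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim); requires PyTorch >= 2.0
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

# merge heads back: [batch, heads, seq_len, head_dim] -> [batch, seq_len, inner_dim]
out = out.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
print(out.shape)  # torch.Size([2, 16, 64])
```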
@@ -1650,31 +1668,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
         # mask=attention_mask,
         # num_heads=attn.heads,
         # )
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        #query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        query = query.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        hidden_states = F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=attn.dropout,
-        )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
         if self_attention_output is not None:
             hidden_states = hidden_states + self_attention_output
-
-        if not attn.pre_only:
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
         if attn.residual_connection:
             hidden_states = hidden_states + residual
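For reference, the removed `if not attn.pre_only:` block applied the attention module's output projection and dropout, as its own comments state. A minimal sketch, not taken from matryoshka.py, of that `to_out` pattern, assuming the common diffusers-style `ModuleList([Linear, Dropout])` layout (an assumption here):

```python
import torch
from torch import nn

# Invented sizes; to_out mimics the ModuleList that attn.to_out[0] / attn.to_out[1]
# in the removed lines index into (assumed layout, not read from the file).
inner_dim, out_dim, dropout_p = 64, 64, 0.0
to_out = nn.ModuleList([nn.Linear(inner_dim, out_dim), nn.Dropout(dropout_p)])

hidden_states = torch.randn(2, 16, inner_dim)
hidden_states = to_out[0](hidden_states)  # linear proj
hidden_states = to_out[1](hidden_states)  # dropout
print(hidden_states.shape)  # torch.Size([2, 16, 64])
```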