tolgacangoz committed
Upload matryoshka.py

matryoshka.py CHANGED (+22 -23)
@@ -1519,8 +1519,7 @@ class MatryoshkaTransformerBlock(nn.Module):
 
 # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
 attn_output_cond = self.proj_out(attn_output_cond)
-
-attn_output_cond = attn_output_cond.transpose(-1, -2).reshape(batch_size, channels, *spatial_dims)
+attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
 hidden_states = hidden_states + attn_output_cond
 
 if self.ff is not None:
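Note on the hunk above: for a 3-D tensor, transpose(-1, -2) and permute(0, 2, 1) swap the same pair of dimensions, so the new line is a drop-in replacement for the removed one. A minimal sketch (shapes are illustrative, not taken from matryoshka.py):

```python
import torch

batch_size, channels, spatial_dims = 2, 8, (4, 4)
x = torch.randn(batch_size, spatial_dims[0] * spatial_dims[1], channels)

# Both paths yield [batch, channels, height, width]: on a 3-D tensor,
# transpose(-1, -2) and permute(0, 2, 1) are the same dimension swap.
a = x.transpose(-1, -2).reshape(batch_size, channels, *spatial_dims)
b = x.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
assert torch.equal(a, b)
```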
@@ -1612,7 +1611,7 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
 hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2)
 
 # Reshape hidden_states to 2D tensor
-hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1)
+hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous()
 # Now hidden_states.shape is [batch_size, height * width, channels]
 
 if encoder_hidden_states is None:
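Note on the hunk above: permute() returns a non-contiguous view, and the added .contiguous() materializes the [batch_size, height * width, channels] layout in memory, presumably for downstream ops that require contiguous storage (for example, .view() cannot flatten across permuted dimensions). A small illustration with made-up shapes:

```python
import torch

batch_size, channel, height, width = 2, 8, 4, 4
x = torch.randn(batch_size, channel, height, width)

# permute() produces a non-contiguous view over the same storage.
flat = x.view(batch_size, channel, height * width).permute(0, 2, 1)
print(flat.is_contiguous())                 # False
print(flat.contiguous().is_contiguous())    # True

# Ops that need contiguous memory, e.g. flattening the permuted dims
# with .view(), fail until .contiguous() (or .reshape()) is used:
try:
    flat.view(batch_size, -1)
except RuntimeError:
    print("non-contiguous: .view() refuses to flatten across permuted dims")
```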
@@ -1636,11 +1635,30 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
 # key = key.permute(0, 2, 1)
 # value = value.permute(0, 2, 1)
 
+if attn.norm_q is not None:
+    query = attn.norm_q(query)
+if attn.norm_k is not None:
+    key = attn.norm_k(key)
+
+inner_dim = key.shape[-1]
+head_dim = inner_dim // attn.heads
+
+if self_attention_output is None:
+    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
 if attn.norm_q is not None:
     query = attn.norm_q(query)
 if attn.norm_k is not None:
     key = attn.norm_k(key)
 
+# the output of sdp = (batch, num_heads, seq_len, head_dim)
+# TODO: add support for attn.scale when we move to Torch 2.1
+hidden_states = F.scaled_dot_product_attention(
+    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+)
+
 # the output of sdp = (batch, num_heads, seq_len, head_dim)
 # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
 # hidden_states = self.attention(
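Note on the hunk above: the added lines move the head splitting and the fused F.scaled_dot_product_attention call (PyTorch 2.x) up to this point. A self-contained sketch of that pattern, with assumed batch/sequence/head sizes rather than the ones used in matryoshka.py:

```python
import torch
import torch.nn.functional as F

# Assumed sizes, not the ones in matryoshka.py.
batch_size, seq_len, heads, head_dim = 2, 16, 4, 8
inner_dim = heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# Split channels into heads: [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Fused attention; the output is [batch, heads, seq, head_dim].
out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)

# Merge heads back: [batch, seq, heads * head_dim].
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(out.shape)  # torch.Size([2, 16, 32])
```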
@@ -1650,31 +1668,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
 # mask=attention_mask,
 # num_heads=attn.heads,
 # )
-inner_dim = key.shape[-1]
-head_dim = inner_dim // attn.heads
-#query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-query = query.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-key = key.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-value = value.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-hidden_states = F.scaled_dot_product_attention(
-    query,
-    key,
-    value,
-    attn_mask=attention_mask,
-    dropout_p=attn.dropout,
-)
 
-hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 hidden_states = hidden_states.to(query.dtype)
 
 if self_attention_output is not None:
     hidden_states = hidden_states + self_attention_output
-
-if not attn.pre_only:
-    # linear proj
-    hidden_states = attn.to_out[0](hidden_states)
-    # dropout
-    hidden_states = attn.to_out[1](hidden_states)
+hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
 if attn.residual_connection:
     hidden_states = hidden_states + residual
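Note on the hunk above: besides relocating the head-merge reshape, this hunk drops the "if not attn.pre_only" output-projection epilogue. A sketch of what those removed lines did, assuming to_out is the usual [Linear, Dropout] pair as in diffusers' Attention (the actual module layout is not visible in this diff):

```python
import torch
from torch import nn

# Hypothetical stand-in for attn.to_out; the real layout in matryoshka.py is not shown here.
batch_size, seq_len, inner_dim = 2, 16, 32
to_out = nn.ModuleList([nn.Linear(inner_dim, inner_dim), nn.Dropout(0.0)])

hidden_states = torch.randn(batch_size, seq_len, inner_dim)
hidden_states = to_out[0](hidden_states)  # linear proj
hidden_states = to_out[1](hidden_states)  # dropout
print(hidden_states.shape)  # torch.Size([2, 16, 32])
```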