Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +4 -3
modeling_quiet.py
CHANGED
@@ -449,9 +449,10 @@ class QuietFlashAttention2(QuietAttention):
|
|
449 |
key_states = key_states.to(target_dtype)
|
450 |
value_states = value_states.to(target_dtype)
|
451 |
# Reshape to the expected shape for Flash Attention
|
452 |
-
query_states = query_states.
|
453 |
-
key_states = key_states.
|
454 |
-
value_states = value_states.
|
|
|
455 |
|
456 |
attn_output = self._flash_attention_forward(
|
457 |
query_states,
|
|
|
449 |
key_states = key_states.to(target_dtype)
|
450 |
value_states = value_states.to(target_dtype)
|
451 |
# Reshape to the expected shape for Flash Attention
|
452 |
+
query_states = query_states.reshape(bsz, -1, self.num_heads, self.head_dim)
|
453 |
+
key_states = key_states.reshape(bsz, -1, self.num_key_value_heads, self.head_dim)
|
454 |
+
value_states = value_states.reshape(bsz, -1, self.num_key_value_heads, self.head_dim)
|
455 |
+
|
456 |
|
457 |
attn_output = self._flash_attention_forward(
|
458 |
query_states,
|