Upload modeling_ankh.py with huggingface_hub
Browse files — modeling_ankh.py (+2 −2)
modeling_ankh.py
CHANGED
|
@@ -969,7 +969,7 @@ class AnkhSelfAttention(nn.Module):
|
|
| 969 |
|
| 970 |
if output_attentions:
|
| 971 |
attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
|
| 972 |
-
return attn_output, attn_weights, position_bias
|
| 973 |
|
| 974 |
if self.attn_backend == AttentionBackend.FLEX:
|
| 975 |
attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod)
|
|
@@ -978,7 +978,7 @@ class AnkhSelfAttention(nn.Module):
|
|
| 978 |
else:
|
| 979 |
raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}")
|
| 980 |
|
| 981 |
-
return attn_output, None, position_bias
|
| 982 |
|
| 983 |
def _sdpa_attn(
|
| 984 |
self,
|
|
|
|
| 969 |
|
| 970 |
if output_attentions:
|
| 971 |
attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
|
| 972 |
+
return self.o(attn_output), attn_weights, position_bias
|
| 973 |
|
| 974 |
if self.attn_backend == AttentionBackend.FLEX:
|
| 975 |
attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod)
|
|
|
|
| 978 |
else:
|
| 979 |
raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}")
|
| 980 |
|
| 981 |
+
return self.o(attn_output), None, position_bias
|
| 982 |
|
| 983 |
def _sdpa_attn(
|
| 984 |
self,
|