lhallee commited on
Commit
3366b37
·
verified ·
1 Parent(s): 7fc9c27

Upload modeling_ankh.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_ankh.py +2 -2
modeling_ankh.py CHANGED
@@ -969,7 +969,7 @@ class AnkhSelfAttention(nn.Module):
969
 
970
  if output_attentions:
971
  attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
972
- return attn_output, attn_weights, position_bias
973
 
974
  if self.attn_backend == AttentionBackend.FLEX:
975
  attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod)
@@ -978,7 +978,7 @@ class AnkhSelfAttention(nn.Module):
978
  else:
979
  raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}")
980
 
981
- return attn_output, None, position_bias
982
 
983
  def _sdpa_attn(
984
  self,
 
969
 
970
  if output_attentions:
971
  attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
972
+ return self.o(attn_output), attn_weights, position_bias
973
 
974
  if self.attn_backend == AttentionBackend.FLEX:
975
  attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod)
 
978
  else:
979
  raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}")
980
 
981
+ return self.o(attn_output), None, position_bias
982
 
983
  def _sdpa_attn(
984
  self,