Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +1 -2
modeling_esm_plusplus.py
CHANGED
|
@@ -409,9 +409,8 @@ def get_attention_mask(
|
|
| 409 |
if attention_mask is None:
|
| 410 |
flex_block_mask = None
|
| 411 |
else:
|
| 412 |
-
sequence_ids = torch.where(token_attention_mask, 1, -1)
|
| 413 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 414 |
-
return (
|
| 415 |
|
| 416 |
flex_block_mask = create_block_mask(
|
| 417 |
mask_mod,
|
|
|
|
| 409 |
if attention_mask is None:
|
| 410 |
flex_block_mask = None
|
| 411 |
else:
|
|
|
|
| 412 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 413 |
+
return (token_attention_mask[batch_idx, q_idx] == token_attention_mask[batch_idx, kv_idx]) & (token_attention_mask[batch_idx, q_idx] != 0)
|
| 414 |
|
| 415 |
flex_block_mask = create_block_mask(
|
| 416 |
mask_mod,
|