Upload modeling_dplm.py with huggingface_hub
Browse files- modeling_dplm.py +13 -4
modeling_dplm.py
CHANGED
|
@@ -482,7 +482,17 @@ def _try_get_kernels_flash():
|
|
| 482 |
return flash_kernel, flash_kernel_variant
|
| 483 |
|
| 484 |
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
|
| 488 |
def _kernels_flash_forward(
|
|
@@ -656,6 +666,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
|
|
| 656 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 657 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 658 |
)
|
|
|
|
|
|
|
| 659 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 660 |
if FLASH_KERNEL is not None:
|
| 661 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
@@ -945,9 +957,6 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 945 |
flex_block_mask: "BlockMask | None" = None,
|
| 946 |
) -> tuple[torch.Tensor, None]:
|
| 947 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 948 |
-
assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
|
| 949 |
-
f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
|
| 950 |
-
)
|
| 951 |
fn = _get_flex_attention_fn()
|
| 952 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
| 953 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|
|
|
|
| 482 |
return flash_kernel, flash_kernel_variant
|
| 483 |
|
| 484 |
|
| 485 |
+
_FLASH_KERNELS_LOADED = False
|
| 486 |
+
FLASH_KERNEL = None
|
| 487 |
+
FLASH_KERNEL_VARIANT = None
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _ensure_flash_kernels_loaded():
|
| 491 |
+
global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
|
| 492 |
+
if _FLASH_KERNELS_LOADED:
|
| 493 |
+
return
|
| 494 |
+
_FLASH_KERNELS_LOADED = True
|
| 495 |
+
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
| 496 |
|
| 497 |
|
| 498 |
def _kernels_flash_forward(
|
|
|
|
| 666 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 667 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 668 |
)
|
| 669 |
+
if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
|
| 670 |
+
_ensure_flash_kernels_loaded()
|
| 671 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 672 |
if FLASH_KERNEL is not None:
|
| 673 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
|
|
| 957 |
flex_block_mask: "BlockMask | None" = None,
|
| 958 |
) -> tuple[torch.Tensor, None]:
|
| 959 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
|
|
|
|
|
|
|
|
|
| 960 |
fn = _get_flex_attention_fn()
|
| 961 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
| 962 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|