Upload modeling_dplm2.py with huggingface_hub
Browse files — modeling_dplm2.py (+13 −4)
modeling_dplm2.py
CHANGED
|
@@ -479,7 +479,17 @@ def _try_get_kernels_flash():
|
|
| 479 |
return flash_kernel, flash_kernel_variant
|
| 480 |
|
| 481 |
|
| 482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
|
| 484 |
|
| 485 |
def _kernels_flash_forward(
|
|
@@ -653,6 +663,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
|
|
| 653 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 654 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 655 |
)
|
|
|
|
|
|
|
| 656 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 657 |
if FLASH_KERNEL is not None:
|
| 658 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
@@ -1033,9 +1045,6 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1033 |
flex_block_mask: "BlockMask | None" = None,
|
| 1034 |
) -> tuple[torch.Tensor, None]:
|
| 1035 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1036 | -        assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
| 1037 | -            f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
| 1038 | -        )
|
| 1039 |
fn = _get_flex_attention_fn()
|
| 1040 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
| 1041 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|
|
|
|
| 479 |
return flash_kernel, flash_kernel_variant
|
| 480 |
|
| 481 |
|
| 482 |
# Module-level cache for the lazily-resolved flash-attention kernel.
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Resolve the flash-attention kernel at most once.

    On the first call, populates the module-level ``FLASH_KERNEL`` and
    ``FLASH_KERNEL_VARIANT`` globals from ``_try_get_kernels_flash()``.
    Every later call returns immediately without re-probing.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if _FLASH_KERNELS_LOADED:
        return
    # The flag is flipped before the probe runs, so the probe executes at
    # most once per process — even if it raises, it is not retried.
    _FLASH_KERNELS_LOADED = True
    FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
| 493 |
|
| 494 |
|
| 495 |
def _kernels_flash_forward(
|
|
|
|
| 663 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 664 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 665 |
)
|
| 666 |
+
if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
|
| 667 |
+
_ensure_flash_kernels_loaded()
|
| 668 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 669 |
if FLASH_KERNEL is not None:
|
| 670 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
|
|
| 1045 |
flex_block_mask: "BlockMask | None" = None,
|
| 1046 |
) -> tuple[torch.Tensor, None]:
|
| 1047 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
|
|
|
|
|
|
|
|
|
| 1048 |
fn = _get_flex_attention_fn()
|
| 1049 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
| 1050 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|