Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +13 -4
modeling_fastesm.py
CHANGED
|
@@ -464,7 +464,17 @@ def _try_get_kernels_flash():
|
|
| 464 |
return flash_kernel, flash_kernel_variant
|
| 465 |
|
| 466 |
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
|
| 470 |
def _kernels_flash_forward(
|
|
@@ -638,6 +648,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
|
|
| 638 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 639 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 640 |
)
|
|
|
|
|
|
|
| 641 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 642 |
if FLASH_KERNEL is not None:
|
| 643 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
@@ -900,9 +912,6 @@ class EsmSelfAttention(nn.Module):
|
|
| 900 |
flex_block_mask: "BlockMask | None" = None,
|
| 901 |
) -> tuple[torch.Tensor, None]:
|
| 902 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 903 |
-
assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
|
| 904 |
-
f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
|
| 905 |
-
)
|
| 906 |
fn = _get_flex_attention_fn()
|
| 907 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
| 908 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|
|
|
|
| 464 |
return flash_kernel, flash_kernel_variant
|
| 465 |
|
| 466 |
|
| 467 |
+
_FLASH_KERNELS_LOADED = False
|
| 468 |
+
FLASH_KERNEL = None
|
| 469 |
+
FLASH_KERNEL_VARIANT = None
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _ensure_flash_kernels_loaded():
|
| 473 |
+
global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
|
| 474 |
+
if _FLASH_KERNELS_LOADED:
|
| 475 |
+
return
|
| 476 |
+
_FLASH_KERNELS_LOADED = True
|
| 477 |
+
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
| 478 |
|
| 479 |
|
| 480 |
def _kernels_flash_forward(
|
|
|
|
| 648 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 649 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 650 |
)
|
| 651 |
+
if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
|
| 652 |
+
_ensure_flash_kernels_loaded()
|
| 653 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 654 |
if FLASH_KERNEL is not None:
|
| 655 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
|
|
| 912 |
flex_block_mask: "BlockMask | None" = None,
|
| 913 |
) -> tuple[torch.Tensor, None]:
|
| 914 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
|
|
|
|
|
|
|
|
|
| 915 |
fn = _get_flex_attention_fn()
|
| 916 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
| 917 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|