lhallee committed · verified
Commit 16a49d6 · 1 Parent(s): 8f59aa9

Upload modeling_fastesm.py with huggingface_hub

Files changed (1):
  1. modeling_fastesm.py (+13 -4)
modeling_fastesm.py CHANGED
@@ -464,7 +464,17 @@ def _try_get_kernels_flash():
     return flash_kernel, flash_kernel_variant


-FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
+_FLASH_KERNELS_LOADED = False
+FLASH_KERNEL = None
+FLASH_KERNEL_VARIANT = None
+
+
+def _ensure_flash_kernels_loaded():
+    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
+    if _FLASH_KERNELS_LOADED:
+        return
+    _FLASH_KERNELS_LOADED = True
+    FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()


 def _kernels_flash_forward(
@@ -638,6 +648,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
     assert requested_backend in VALID_ATTENTION_BACKENDS, (
         f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
     )
+    if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
+        _ensure_flash_kernels_loaded()
     if requested_backend == AttentionBackend.AUTO.value:
         if FLASH_KERNEL is not None:
             resolved = AttentionBackend.KERNELS_FLASH
@@ -900,9 +912,6 @@ class EsmSelfAttention(nn.Module):
         flex_block_mask: "BlockMask | None" = None,
     ) -> tuple[torch.Tensor, None]:
         assert flex_attention is not None, "Flex attention is not available in this environment."
-        assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
-            f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
-        )
         fn = _get_flex_attention_fn()
         context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
         return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
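
The diff replaces an import-time call to _try_get_kernels_flash() with a run-once loader that resolve_attention_backend() triggers only when the auto or kernels_flash backend is requested. Below is a minimal, self-contained sketch of that same module-level lazy-initialization pattern; the names (_load_heavy_resource, RESOURCE, ensure_loaded, use_resource) are hypothetical stand-ins for illustration and are not part of modeling_fastesm.py.

# Sketch of module-level lazy initialization (hypothetical names, not the
# actual modeling_fastesm.py API). The expensive call runs at most once,
# and only the first time the resource is actually needed.
_LOADED = False
RESOURCE = None


def _load_heavy_resource():
    # Stand-in for an expensive side effect, e.g. resolving or fetching kernels.
    return object()


def ensure_loaded():
    global _LOADED, RESOURCE
    if _LOADED:
        return
    _LOADED = True  # mark first so a failed attempt is not retried on every call
    RESOURCE = _load_heavy_resource()


def use_resource():
    ensure_loaded()  # import stays cheap; the cost is paid here, on first use
    return RESOURCE

The same trade-off applies in the diff: FLASH_KERNEL stays None until _ensure_flash_kernels_loaded() has run, so code that reads FLASH_KERNEL without first going through resolve_attention_backend() would see None.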