lhallee committed on
Commit
0e53d22
·
verified ·
1 Parent(s): 80d0d6e

Upload modeling_dplm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm.py +13 -4
modeling_dplm.py CHANGED
@@ -482,7 +482,17 @@ def _try_get_kernels_flash():
482
  return flash_kernel, flash_kernel_variant
483
 
484
 
485
- FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
 
 
 
 
 
 
 
 
 
 
486
 
487
 
488
  def _kernels_flash_forward(
@@ -656,6 +666,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
656
  assert requested_backend in VALID_ATTENTION_BACKENDS, (
657
  f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
658
  )
 
 
659
  if requested_backend == AttentionBackend.AUTO.value:
660
  if FLASH_KERNEL is not None:
661
  resolved = AttentionBackend.KERNELS_FLASH
@@ -945,9 +957,6 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
945
  flex_block_mask: "BlockMask | None" = None,
946
  ) -> tuple[torch.Tensor, None]:
947
  assert flex_attention is not None, "Flex attention is not available in this environment."
948
- assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
949
- f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
950
- )
951
  fn = _get_flex_attention_fn()
952
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
953
  return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
 
482
  return flash_kernel, flash_kernel_variant
483
 
484
 
485
+ _FLASH_KERNELS_LOADED = False
486
+ FLASH_KERNEL = None
487
+ FLASH_KERNEL_VARIANT = None
488
+
489
+
490
+ def _ensure_flash_kernels_loaded():
491
+ global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
492
+ if _FLASH_KERNELS_LOADED:
493
+ return
494
+ _FLASH_KERNELS_LOADED = True
495
+ FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
496
 
497
 
498
  def _kernels_flash_forward(
 
666
  assert requested_backend in VALID_ATTENTION_BACKENDS, (
667
  f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
668
  )
669
+ if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
670
+ _ensure_flash_kernels_loaded()
671
  if requested_backend == AttentionBackend.AUTO.value:
672
  if FLASH_KERNEL is not None:
673
  resolved = AttentionBackend.KERNELS_FLASH
 
957
  flex_block_mask: "BlockMask | None" = None,
958
  ) -> tuple[torch.Tensor, None]:
959
  assert flex_attention is not None, "Flex attention is not available in this environment."
 
 
 
960
  fn = _get_flex_attention_fn()
961
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
962
  return rearrange(context_BHLD, "b h s d -> b s (h d)"), None