Transformers
Safetensors
dplm2
custom_code
lhallee committed
Commit 41fdcb2 · verified · 1 Parent(s): 21ba4c8

Upload modeling_dplm2.py with huggingface_hub

Files changed (1)
  1. modeling_dplm2.py +13 -4
modeling_dplm2.py CHANGED
@@ -479,7 +479,17 @@ def _try_get_kernels_flash():
     return flash_kernel, flash_kernel_variant
 
 
-FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
+_FLASH_KERNELS_LOADED = False
+FLASH_KERNEL = None
+FLASH_KERNEL_VARIANT = None
+
+
+def _ensure_flash_kernels_loaded():
+    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
+    if _FLASH_KERNELS_LOADED:
+        return
+    _FLASH_KERNELS_LOADED = True
+    FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
 
 
 def _kernels_flash_forward(
@@ -653,6 +663,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
     assert requested_backend in VALID_ATTENTION_BACKENDS, (
         f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
     )
+    if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
+        _ensure_flash_kernels_loaded()
     if requested_backend == AttentionBackend.AUTO.value:
         if FLASH_KERNEL is not None:
             resolved = AttentionBackend.KERNELS_FLASH
@@ -1033,9 +1045,6 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
         flex_block_mask: "BlockMask | None" = None,
     ) -> tuple[torch.Tensor, None]:
         assert flex_attention is not None, "Flex attention is not available in this environment."
-        assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
-            f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
-        )
        fn = _get_flex_attention_fn()
         context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
         return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
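
The functional change in this diff is moving the flash-kernel probe from import time to first use: _try_get_kernels_flash() is now reached only through _ensure_flash_kernels_loaded(), and only when resolve_attention_backend is asked for the auto or kernels_flash backend. Below is a minimal, self-contained sketch of that lazy-initialization pattern, not the model code itself; the names load_kernel, ensure_loaded, resolve, and the "sdpa" fallback are illustrative stand-ins for _try_get_kernels_flash, _ensure_flash_kernels_loaded, resolve_attention_backend, and the non-flash backends.

_LOADED = False
KERNEL = None


def load_kernel():
    # Hypothetical stand-in for the expensive kernel probe/import.
    return "dummy-flash-kernel"


def ensure_loaded():
    global _LOADED, KERNEL
    if _LOADED:  # only ever attempt the load once
        return
    _LOADED = True  # mark before loading so a failing probe is not retried
    KERNEL = load_kernel()


def resolve(backend: str) -> str:
    # Pay the load cost only when a flash-capable backend is requested.
    if backend in ("auto", "kernels_flash"):
        ensure_loaded()
    if backend == "auto":
        return "kernels_flash" if KERNEL is not None else "sdpa"
    return backend

In this sketch, resolve("eager") never touches the kernel loader, while resolve("auto") loads the kernel once and reuses the cached module-level result on every later call.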