lhallee committed on
Commit
cfa15aa
·
verified ·
1 Parent(s): 2429419

Upload modeling_fast_esmfold.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fast_esmfold.py +50 -35
modeling_fast_esmfold.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import torch._inductor.config as inductor_config
3
  import torch._dynamo as dynamo
@@ -27,7 +29,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
27
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
28
  """
29
  from enum import Enum
30
- from typing import Optional
 
31
 
32
  import torch
33
  import torch.nn as nn
@@ -45,7 +48,12 @@ _compiled_flex_attention = None
45
 
46
 
47
  def _get_flex_attention_fn():
48
- """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set."""
 
 
 
 
 
49
  global _compiled_flex_attention
50
  if flex_attention is None:
51
  return None
@@ -53,12 +61,15 @@ def _get_flex_attention_fn():
53
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
54
  return flex_attention
55
  if _compiled_flex_attention is None:
56
- _compiled_flex_attention = torch.compile(flex_attention)
 
 
 
57
  return _compiled_flex_attention
58
 
59
 
60
  ### Kernels Flash Attention Detection
61
- def _infer_kernels_flash_variant(kernel) -> str | None:
62
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
63
  return "flash_attn2"
64
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
@@ -174,7 +185,7 @@ class IndexFirstAxis(torch.autograd.Function):
174
  ).reshape(-1, *other_shape)
175
 
176
  @staticmethod
177
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None]:
178
  (indices,) = ctx.saved_tensors
179
  assert grad_output.ndim >= 2
180
  other_shape = grad_output.shape[1:]
@@ -197,7 +208,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
197
  return output
198
 
199
  @staticmethod
200
- def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]:
201
  (indices,) = ctx.saved_tensors
202
  return grad_output[indices], None, None
203
 
@@ -216,7 +227,7 @@ def _unpad_input(
216
  key_layer: torch.Tensor,
217
  value_layer: torch.Tensor,
218
  attention_mask_2d: torch.Tensor,
219
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]:
220
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
221
  seqlens = attention_mask_2d.sum(dim=1).int()
222
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
@@ -232,7 +243,7 @@ def kernels_flash_attention_func(
232
  query_states: torch.Tensor,
233
  key_states: torch.Tensor,
234
  value_states: torch.Tensor,
235
- attention_mask_2d: torch.Tensor | None = None,
236
  causal: bool = False,
237
  ) -> torch.Tensor:
238
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
@@ -305,7 +316,7 @@ def get_attention_mask(
305
  seq_len: int,
306
  device: torch.device,
307
  attention_mask: Optional[torch.Tensor] = None,
308
- ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]:
309
  """Build padding masks once for all encoder layers.
310
 
311
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
@@ -418,11 +429,11 @@ class EsmSelfAttention(nn.Module):
418
  def forward(
419
  self,
420
  hidden_states: torch.Tensor,
421
- attention_mask_2d: torch.Tensor | None = None,
422
- attention_mask_4d: torch.Tensor | None = None,
423
- flex_block_mask: "BlockMask | None" = None,
424
  output_attentions: bool = False,
425
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
426
  batch_size, seq_length = hidden_states.shape[:-1]
427
  hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
428
  query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
@@ -448,11 +459,11 @@ class EsmSelfAttention(nn.Module):
448
  query_BHLD: torch.Tensor,
449
  key_BHLD: torch.Tensor,
450
  value_BHLD: torch.Tensor,
451
- attention_mask_2d: torch.Tensor | None = None,
452
- attention_mask_4d: torch.Tensor | None = None,
453
- flex_block_mask: "BlockMask | None" = None,
454
  output_attentions: bool = False,
455
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
456
  if output_attentions:
457
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
458
 
@@ -470,8 +481,8 @@ class EsmSelfAttention(nn.Module):
470
  query_BHLD: torch.Tensor,
471
  key_BHLD: torch.Tensor,
472
  value_BHLD: torch.Tensor,
473
- attention_mask_4d: torch.Tensor | None = None,
474
- ) -> tuple[torch.Tensor, torch.Tensor]:
475
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
476
  if attention_mask_4d is not None:
477
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
@@ -487,8 +498,8 @@ class EsmSelfAttention(nn.Module):
487
  query_BHLD: torch.Tensor,
488
  key_BHLD: torch.Tensor,
489
  value_BHLD: torch.Tensor,
490
- attention_mask_2d: torch.Tensor | None = None,
491
- ) -> tuple[torch.Tensor, None]:
492
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
493
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
494
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
@@ -503,8 +514,8 @@ class EsmSelfAttention(nn.Module):
503
  query_BHLD: torch.Tensor,
504
  key_BHLD: torch.Tensor,
505
  value_BHLD: torch.Tensor,
506
- flex_block_mask: "BlockMask | None" = None,
507
- ) -> tuple[torch.Tensor, None]:
508
  assert flex_attention is not None, "Flex attention is not available in this environment."
509
  fn = _get_flex_attention_fn()
510
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
@@ -515,8 +526,8 @@ class EsmSelfAttention(nn.Module):
515
  query_BHLD: torch.Tensor,
516
  key_BHLD: torch.Tensor,
517
  value_BHLD: torch.Tensor,
518
- attention_mask_4d: torch.Tensor | None = None,
519
- ) -> tuple[torch.Tensor, None]:
520
  context_BHLD = F.scaled_dot_product_attention(
521
  query_BHLD, key_BHLD, value_BHLD,
522
  attn_mask=attention_mask_4d,
@@ -536,11 +547,11 @@ class EsmAttention(nn.Module):
536
  def forward(
537
  self,
538
  hidden_states: torch.Tensor,
539
- attention_mask_2d: torch.Tensor | None = None,
540
- attention_mask_4d: torch.Tensor | None = None,
541
- flex_block_mask: "BlockMask | None" = None,
542
  output_attentions: bool = False,
543
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
544
  hidden_states_ln = self.LayerNorm(hidden_states)
545
  attn_output, attn_weights = self.self(
546
  hidden_states_ln,
@@ -564,11 +575,11 @@ class EsmLayer(nn.Module):
564
  def forward(
565
  self,
566
  hidden_states: torch.Tensor,
567
- attention_mask_2d: torch.Tensor | None = None,
568
- attention_mask_4d: torch.Tensor | None = None,
569
- flex_block_mask: "BlockMask | None" = None,
570
  output_attentions: bool = False,
571
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
572
  attention_output, attn_weights = self.attention(
573
  hidden_states,
574
  attention_mask_2d=attention_mask_2d,
@@ -1203,8 +1214,12 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
1203
  with torch.no_grad():
1204
  output = self.infer(sequence)
1205
  plddt = output["plddt"]
1206
- if plddt.dim() >= 2:
1207
- mean_plddt = float(plddt.mean(dim=-1).mean().item())
 
 
 
 
1208
  else:
1209
  mean_plddt = float(plddt.mean().item())
1210
  result = {
 
1
+ from __future__ import annotations
2
+
3
  import torch
4
  import torch._inductor.config as inductor_config
5
  import torch._dynamo as dynamo
 
29
  flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
30
  """
31
  from enum import Enum
32
+ from functools import partial
33
+ from typing import Dict, List, Optional, Tuple
34
 
35
  import torch
36
  import torch.nn as nn
 
48
 
49
 
50
  def _get_flex_attention_fn():
51
+ """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
52
+
53
+ Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
54
+ on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
55
+ on older hardware.
56
+ """
57
  global _compiled_flex_attention
58
  if flex_attention is None:
59
  return None
 
61
  if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
62
  return flex_attention
63
  if _compiled_flex_attention is None:
64
+ _compiled_flex_attention = torch.compile(
65
+ partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
66
+ dynamic=False,
67
+ )
68
  return _compiled_flex_attention
69
 
70
 
71
  ### Kernels Flash Attention Detection
72
+ def _infer_kernels_flash_variant(kernel) -> Optional[str]:
73
  if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
74
  return "flash_attn2"
75
  if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
 
185
  ).reshape(-1, *other_shape)
186
 
187
  @staticmethod
188
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
189
  (indices,) = ctx.saved_tensors
190
  assert grad_output.ndim >= 2
191
  other_shape = grad_output.shape[1:]
 
208
  return output
209
 
210
  @staticmethod
211
+ def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
212
  (indices,) = ctx.saved_tensors
213
  return grad_output[indices], None, None
214
 
 
227
  key_layer: torch.Tensor,
228
  value_layer: torch.Tensor,
229
  attention_mask_2d: torch.Tensor,
230
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
231
  batch_size, seq_len, num_heads, head_dim = query_layer.shape
232
  seqlens = attention_mask_2d.sum(dim=1).int()
233
  cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
 
243
  query_states: torch.Tensor,
244
  key_states: torch.Tensor,
245
  value_states: torch.Tensor,
246
+ attention_mask_2d: Optional[torch.Tensor] = None,
247
  causal: bool = False,
248
  ) -> torch.Tensor:
249
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
 
316
  seq_len: int,
317
  device: torch.device,
318
  attention_mask: Optional[torch.Tensor] = None,
319
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
320
  """Build padding masks once for all encoder layers.
321
 
322
  Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
 
429
  def forward(
430
  self,
431
  hidden_states: torch.Tensor,
432
+ attention_mask_2d: Optional[torch.Tensor] = None,
433
+ attention_mask_4d: Optional[torch.Tensor] = None,
434
+ flex_block_mask: Optional[BlockMask] = None,
435
  output_attentions: bool = False,
436
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
437
  batch_size, seq_length = hidden_states.shape[:-1]
438
  hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
439
  query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
 
459
  query_BHLD: torch.Tensor,
460
  key_BHLD: torch.Tensor,
461
  value_BHLD: torch.Tensor,
462
+ attention_mask_2d: Optional[torch.Tensor] = None,
463
+ attention_mask_4d: Optional[torch.Tensor] = None,
464
+ flex_block_mask: Optional[BlockMask] = None,
465
  output_attentions: bool = False,
466
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
467
  if output_attentions:
468
  return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
469
 
 
481
  query_BHLD: torch.Tensor,
482
  key_BHLD: torch.Tensor,
483
  value_BHLD: torch.Tensor,
484
+ attention_mask_4d: Optional[torch.Tensor] = None,
485
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
486
  attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
487
  if attention_mask_4d is not None:
488
  attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
 
498
  query_BHLD: torch.Tensor,
499
  key_BHLD: torch.Tensor,
500
  value_BHLD: torch.Tensor,
501
+ attention_mask_2d: Optional[torch.Tensor] = None,
502
+ ) -> Tuple[torch.Tensor, None]:
503
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
504
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
505
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
 
514
  query_BHLD: torch.Tensor,
515
  key_BHLD: torch.Tensor,
516
  value_BHLD: torch.Tensor,
517
+ flex_block_mask: Optional[BlockMask] = None,
518
+ ) -> Tuple[torch.Tensor, None]:
519
  assert flex_attention is not None, "Flex attention is not available in this environment."
520
  fn = _get_flex_attention_fn()
521
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
 
526
  query_BHLD: torch.Tensor,
527
  key_BHLD: torch.Tensor,
528
  value_BHLD: torch.Tensor,
529
+ attention_mask_4d: Optional[torch.Tensor] = None,
530
+ ) -> Tuple[torch.Tensor, None]:
531
  context_BHLD = F.scaled_dot_product_attention(
532
  query_BHLD, key_BHLD, value_BHLD,
533
  attn_mask=attention_mask_4d,
 
547
  def forward(
548
  self,
549
  hidden_states: torch.Tensor,
550
+ attention_mask_2d: Optional[torch.Tensor] = None,
551
+ attention_mask_4d: Optional[torch.Tensor] = None,
552
+ flex_block_mask: Optional[BlockMask] = None,
553
  output_attentions: bool = False,
554
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
555
  hidden_states_ln = self.LayerNorm(hidden_states)
556
  attn_output, attn_weights = self.self(
557
  hidden_states_ln,
 
575
  def forward(
576
  self,
577
  hidden_states: torch.Tensor,
578
+ attention_mask_2d: Optional[torch.Tensor] = None,
579
+ attention_mask_4d: Optional[torch.Tensor] = None,
580
+ flex_block_mask: Optional[BlockMask] = None,
581
  output_attentions: bool = False,
582
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
583
  attention_output, attn_weights = self.attention(
584
  hidden_states,
585
  attention_mask_2d=attention_mask_2d,
 
1214
  with torch.no_grad():
1215
  output = self.infer(sequence)
1216
  plddt = output["plddt"]
1217
+ # plddt shape is (batch, L, 37) - per-atom across atom37 types.
1218
+ # Use CA atom (index 1) only, matching PDB B-factor output.
1219
+ if plddt.dim() == 3:
1220
+ mean_plddt = float(plddt[:, :, 1].mean().item())
1221
+ elif plddt.dim() == 2:
1222
+ mean_plddt = float(plddt[:, 1].mean().item())
1223
  else:
1224
  mean_plddt = float(plddt.mean().item())
1225
  result = {