Upload modeling_fast_esmfold.py with huggingface_hub
Browse files- modeling_fast_esmfold.py +50 -35
modeling_fast_esmfold.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch._inductor.config as inductor_config
|
| 3 |
import torch._dynamo as dynamo
|
|
@@ -27,7 +29,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
|
|
| 27 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 28 |
"""
|
| 29 |
from enum import Enum
|
| 30 |
-
from
|
|
|
|
| 31 |
|
| 32 |
import torch
|
| 33 |
import torch.nn as nn
|
|
@@ -45,7 +48,12 @@ _compiled_flex_attention = None
|
|
| 45 |
|
| 46 |
|
| 47 |
def _get_flex_attention_fn():
|
| 48 |
-
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
global _compiled_flex_attention
|
| 50 |
if flex_attention is None:
|
| 51 |
return None
|
|
@@ -53,12 +61,15 @@ def _get_flex_attention_fn():
|
|
| 53 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 54 |
return flex_attention
|
| 55 |
if _compiled_flex_attention is None:
|
| 56 |
-
_compiled_flex_attention = torch.compile(
|
|
|
|
|
|
|
|
|
|
| 57 |
return _compiled_flex_attention
|
| 58 |
|
| 59 |
|
| 60 |
### Kernels Flash Attention Detection
|
| 61 |
-
def _infer_kernels_flash_variant(kernel) -> str
|
| 62 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 63 |
return "flash_attn2"
|
| 64 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
@@ -174,7 +185,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
| 174 |
).reshape(-1, *other_shape)
|
| 175 |
|
| 176 |
@staticmethod
|
| 177 |
-
def backward(ctx, grad_output) ->
|
| 178 |
(indices,) = ctx.saved_tensors
|
| 179 |
assert grad_output.ndim >= 2
|
| 180 |
other_shape = grad_output.shape[1:]
|
|
@@ -197,7 +208,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
|
| 197 |
return output
|
| 198 |
|
| 199 |
@staticmethod
|
| 200 |
-
def backward(ctx, grad_output) ->
|
| 201 |
(indices,) = ctx.saved_tensors
|
| 202 |
return grad_output[indices], None, None
|
| 203 |
|
|
@@ -216,7 +227,7 @@ def _unpad_input(
|
|
| 216 |
key_layer: torch.Tensor,
|
| 217 |
value_layer: torch.Tensor,
|
| 218 |
attention_mask_2d: torch.Tensor,
|
| 219 |
-
) ->
|
| 220 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 221 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 222 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
@@ -232,7 +243,7 @@ def kernels_flash_attention_func(
|
|
| 232 |
query_states: torch.Tensor,
|
| 233 |
key_states: torch.Tensor,
|
| 234 |
value_states: torch.Tensor,
|
| 235 |
-
attention_mask_2d: torch.Tensor
|
| 236 |
causal: bool = False,
|
| 237 |
) -> torch.Tensor:
|
| 238 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
@@ -305,7 +316,7 @@ def get_attention_mask(
|
|
| 305 |
seq_len: int,
|
| 306 |
device: torch.device,
|
| 307 |
attention_mask: Optional[torch.Tensor] = None,
|
| 308 |
-
) ->
|
| 309 |
"""Build padding masks once for all encoder layers.
|
| 310 |
|
| 311 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
@@ -418,11 +429,11 @@ class EsmSelfAttention(nn.Module):
|
|
| 418 |
def forward(
|
| 419 |
self,
|
| 420 |
hidden_states: torch.Tensor,
|
| 421 |
-
attention_mask_2d: torch.Tensor
|
| 422 |
-
attention_mask_4d: torch.Tensor
|
| 423 |
-
flex_block_mask:
|
| 424 |
output_attentions: bool = False,
|
| 425 |
-
) ->
|
| 426 |
batch_size, seq_length = hidden_states.shape[:-1]
|
| 427 |
hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
|
| 428 |
query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
@@ -448,11 +459,11 @@ class EsmSelfAttention(nn.Module):
|
|
| 448 |
query_BHLD: torch.Tensor,
|
| 449 |
key_BHLD: torch.Tensor,
|
| 450 |
value_BHLD: torch.Tensor,
|
| 451 |
-
attention_mask_2d: torch.Tensor
|
| 452 |
-
attention_mask_4d: torch.Tensor
|
| 453 |
-
flex_block_mask:
|
| 454 |
output_attentions: bool = False,
|
| 455 |
-
) ->
|
| 456 |
if output_attentions:
|
| 457 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
|
| 458 |
|
|
@@ -470,8 +481,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 470 |
query_BHLD: torch.Tensor,
|
| 471 |
key_BHLD: torch.Tensor,
|
| 472 |
value_BHLD: torch.Tensor,
|
| 473 |
-
attention_mask_4d: torch.Tensor
|
| 474 |
-
) ->
|
| 475 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 476 |
if attention_mask_4d is not None:
|
| 477 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
@@ -487,8 +498,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 487 |
query_BHLD: torch.Tensor,
|
| 488 |
key_BHLD: torch.Tensor,
|
| 489 |
value_BHLD: torch.Tensor,
|
| 490 |
-
attention_mask_2d: torch.Tensor
|
| 491 |
-
) ->
|
| 492 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 493 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 494 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
@@ -503,8 +514,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 503 |
query_BHLD: torch.Tensor,
|
| 504 |
key_BHLD: torch.Tensor,
|
| 505 |
value_BHLD: torch.Tensor,
|
| 506 |
-
flex_block_mask:
|
| 507 |
-
) ->
|
| 508 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 509 |
fn = _get_flex_attention_fn()
|
| 510 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
@@ -515,8 +526,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 515 |
query_BHLD: torch.Tensor,
|
| 516 |
key_BHLD: torch.Tensor,
|
| 517 |
value_BHLD: torch.Tensor,
|
| 518 |
-
attention_mask_4d: torch.Tensor
|
| 519 |
-
) ->
|
| 520 |
context_BHLD = F.scaled_dot_product_attention(
|
| 521 |
query_BHLD, key_BHLD, value_BHLD,
|
| 522 |
attn_mask=attention_mask_4d,
|
|
@@ -536,11 +547,11 @@ class EsmAttention(nn.Module):
|
|
| 536 |
def forward(
|
| 537 |
self,
|
| 538 |
hidden_states: torch.Tensor,
|
| 539 |
-
attention_mask_2d: torch.Tensor
|
| 540 |
-
attention_mask_4d: torch.Tensor
|
| 541 |
-
flex_block_mask:
|
| 542 |
output_attentions: bool = False,
|
| 543 |
-
) ->
|
| 544 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 545 |
attn_output, attn_weights = self.self(
|
| 546 |
hidden_states_ln,
|
|
@@ -564,11 +575,11 @@ class EsmLayer(nn.Module):
|
|
| 564 |
def forward(
|
| 565 |
self,
|
| 566 |
hidden_states: torch.Tensor,
|
| 567 |
-
attention_mask_2d: torch.Tensor
|
| 568 |
-
attention_mask_4d: torch.Tensor
|
| 569 |
-
flex_block_mask:
|
| 570 |
output_attentions: bool = False,
|
| 571 |
-
) ->
|
| 572 |
attention_output, attn_weights = self.attention(
|
| 573 |
hidden_states,
|
| 574 |
attention_mask_2d=attention_mask_2d,
|
|
@@ -1203,8 +1214,12 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 1203 |
with torch.no_grad():
|
| 1204 |
output = self.infer(sequence)
|
| 1205 |
plddt = output["plddt"]
|
| 1206 |
-
|
| 1207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1208 |
else:
|
| 1209 |
mean_plddt = float(plddt.mean().item())
|
| 1210 |
result = {
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch._inductor.config as inductor_config
|
| 5 |
import torch._dynamo as dynamo
|
|
|
|
| 29 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 30 |
"""
|
| 31 |
from enum import Enum
|
| 32 |
+
from functools import partial
|
| 33 |
+
from typing import Dict, List, Optional, Tuple
|
| 34 |
|
| 35 |
import torch
|
| 36 |
import torch.nn as nn
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def _get_flex_attention_fn():
|
| 51 |
+
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
| 52 |
+
|
| 53 |
+
Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
|
| 54 |
+
on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
|
| 55 |
+
on older hardware.
|
| 56 |
+
"""
|
| 57 |
global _compiled_flex_attention
|
| 58 |
if flex_attention is None:
|
| 59 |
return None
|
|
|
|
| 61 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 62 |
return flex_attention
|
| 63 |
if _compiled_flex_attention is None:
|
| 64 |
+
_compiled_flex_attention = torch.compile(
|
| 65 |
+
partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
|
| 66 |
+
dynamic=False,
|
| 67 |
+
)
|
| 68 |
return _compiled_flex_attention
|
| 69 |
|
| 70 |
|
| 71 |
### Kernels Flash Attention Detection
|
| 72 |
+
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
|
| 73 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 74 |
return "flash_attn2"
|
| 75 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
|
|
| 185 |
).reshape(-1, *other_shape)
|
| 186 |
|
| 187 |
@staticmethod
|
| 188 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
|
| 189 |
(indices,) = ctx.saved_tensors
|
| 190 |
assert grad_output.ndim >= 2
|
| 191 |
other_shape = grad_output.shape[1:]
|
|
|
|
| 208 |
return output
|
| 209 |
|
| 210 |
@staticmethod
|
| 211 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
|
| 212 |
(indices,) = ctx.saved_tensors
|
| 213 |
return grad_output[indices], None, None
|
| 214 |
|
|
|
|
| 227 |
key_layer: torch.Tensor,
|
| 228 |
value_layer: torch.Tensor,
|
| 229 |
attention_mask_2d: torch.Tensor,
|
| 230 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 231 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 232 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 233 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
|
|
| 243 |
query_states: torch.Tensor,
|
| 244 |
key_states: torch.Tensor,
|
| 245 |
value_states: torch.Tensor,
|
| 246 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 247 |
causal: bool = False,
|
| 248 |
) -> torch.Tensor:
|
| 249 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
|
|
| 316 |
seq_len: int,
|
| 317 |
device: torch.device,
|
| 318 |
attention_mask: Optional[torch.Tensor] = None,
|
| 319 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
|
| 320 |
"""Build padding masks once for all encoder layers.
|
| 321 |
|
| 322 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
|
|
| 429 |
def forward(
|
| 430 |
self,
|
| 431 |
hidden_states: torch.Tensor,
|
| 432 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 433 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 434 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 435 |
output_attentions: bool = False,
|
| 436 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 437 |
batch_size, seq_length = hidden_states.shape[:-1]
|
| 438 |
hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
|
| 439 |
query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
|
|
| 459 |
query_BHLD: torch.Tensor,
|
| 460 |
key_BHLD: torch.Tensor,
|
| 461 |
value_BHLD: torch.Tensor,
|
| 462 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 463 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 464 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 465 |
output_attentions: bool = False,
|
| 466 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 467 |
if output_attentions:
|
| 468 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
|
| 469 |
|
|
|
|
| 481 |
query_BHLD: torch.Tensor,
|
| 482 |
key_BHLD: torch.Tensor,
|
| 483 |
value_BHLD: torch.Tensor,
|
| 484 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 485 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 486 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 487 |
if attention_mask_4d is not None:
|
| 488 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
|
|
| 498 |
query_BHLD: torch.Tensor,
|
| 499 |
key_BHLD: torch.Tensor,
|
| 500 |
value_BHLD: torch.Tensor,
|
| 501 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 502 |
+
) -> Tuple[torch.Tensor, None]:
|
| 503 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 504 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 505 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
|
|
| 514 |
query_BHLD: torch.Tensor,
|
| 515 |
key_BHLD: torch.Tensor,
|
| 516 |
value_BHLD: torch.Tensor,
|
| 517 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 518 |
+
) -> Tuple[torch.Tensor, None]:
|
| 519 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 520 |
fn = _get_flex_attention_fn()
|
| 521 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
|
|
| 526 |
query_BHLD: torch.Tensor,
|
| 527 |
key_BHLD: torch.Tensor,
|
| 528 |
value_BHLD: torch.Tensor,
|
| 529 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 530 |
+
) -> Tuple[torch.Tensor, None]:
|
| 531 |
context_BHLD = F.scaled_dot_product_attention(
|
| 532 |
query_BHLD, key_BHLD, value_BHLD,
|
| 533 |
attn_mask=attention_mask_4d,
|
|
|
|
| 547 |
def forward(
|
| 548 |
self,
|
| 549 |
hidden_states: torch.Tensor,
|
| 550 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 551 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 552 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 553 |
output_attentions: bool = False,
|
| 554 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 555 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 556 |
attn_output, attn_weights = self.self(
|
| 557 |
hidden_states_ln,
|
|
|
|
| 575 |
def forward(
|
| 576 |
self,
|
| 577 |
hidden_states: torch.Tensor,
|
| 578 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 579 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 580 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 581 |
output_attentions: bool = False,
|
| 582 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 583 |
attention_output, attn_weights = self.attention(
|
| 584 |
hidden_states,
|
| 585 |
attention_mask_2d=attention_mask_2d,
|
|
|
|
| 1214 |
with torch.no_grad():
|
| 1215 |
output = self.infer(sequence)
|
| 1216 |
plddt = output["plddt"]
|
| 1217 |
+
# plddt shape is (batch, L, 37) - per-atom across atom37 types.
|
| 1218 |
+
# Use CA atom (index 1) only, matching PDB B-factor output.
|
| 1219 |
+
if plddt.dim() == 3:
|
| 1220 |
+
mean_plddt = float(plddt[:, :, 1].mean().item())
|
| 1221 |
+
elif plddt.dim() == 2:
|
| 1222 |
+
mean_plddt = float(plddt[:, 1].mean().item())
|
| 1223 |
else:
|
| 1224 |
mean_plddt = float(plddt.mean().item())
|
| 1225 |
result = {
|