""" | |
https://github.com/ProteinDesignLab/protpardelle | |
License: MIT | |
Author: Alex Chu | |
Neural network modules. Many of these are adapted from open source modules. | |
""" | |
from typing import List, Sequence, Optional | |
from einops import rearrange, reduce, repeat | |
from einops.layers.torch import Rearrange | |
import numpy as np | |
from rotary_embedding_torch import RotaryEmbedding | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, EsmModel | |
from core import protein_mpnn | |
from core import residue_constants | |
from core import utils | |
########################################
# Adapted from https://github.com/ermongroup/ddim


def downsample(x):
    return nn.functional.avg_pool2d(x, 2, 2, ceil_mode=True)


def upsample_coords(x, shape):
    new_l, new_w = shape
    return nn.functional.interpolate(x, size=(new_l, new_w), mode="nearest")
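

# Usage sketch (illustrative addition, not part of the original module): round-trip
# a feature map through the two helpers above. The (batch, channels, length, atoms)
# sizes and the helper name are assumptions chosen only for demonstration.
def _example_down_up_sample():
    x = torch.randn(2, 8, 10, 37)
    y = downsample(x)                    # (2, 8, 5, 19) with ceil_mode=True
    z = upsample_coords(y, x.shape[2:])  # back to (2, 8, 10, 37)
    return z.shape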


########################################
# Adapted from https://github.com/aqlaboratory/openfold


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.contiguous().permute(first_inds + [zero_index + i for i in inds])


def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )

    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score
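

# Usage sketch (illustrative addition, not part of the original module): per-position
# lDDT between a predicted and a reference CA trace. The batch/length sizes and the
# mask layout (a trailing singleton atom dimension) are assumptions for this example.
def _example_lddt():
    pred = torch.randn(1, 16, 3)               # (batch, n_positions, 3)
    true = pred + 0.1 * torch.randn_like(pred)
    mask = torch.ones(1, 16, 1)                # (batch, n_positions, 1)
    return lddt(pred, true, mask)              # (1, 16) scores in [0, 1]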


class RelativePositionalEncoding(nn.Module):
    def __init__(self, attn_dim=8, max_rel_idx=32):
        super().__init__()
        self.max_rel_idx = max_rel_idx
        self.n_rel_pos = 2 * self.max_rel_idx + 1
        self.linear = nn.Linear(self.n_rel_pos, attn_dim)

    def forward(self, residue_index):
        d_ij = residue_index[..., None] - residue_index[..., None, :]
        v_bins = torch.arange(self.n_rel_pos).to(d_ij.device) - self.max_rel_idx
        idxs = (d_ij[..., None] - v_bins[None, None]).abs().argmin(-1)
        p_ij = nn.functional.one_hot(idxs, num_classes=self.n_rel_pos)
        embeddings = self.linear(p_ij.float())
        return embeddings
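

# Usage sketch (illustrative addition, not part of the original module): pairwise
# relative-position features for a toy chain. attn_dim/max_rel_idx match the module
# defaults; the sequence length of 16 is an arbitrary assumption.
def _example_relative_positional_encoding():
    relpos = RelativePositionalEncoding(attn_dim=8, max_rel_idx=32)
    residue_index = torch.arange(16)[None]  # (batch=1, n=16)
    return relpos(residue_index).shape      # torch.Size([1, 16, 16, 8])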


########################################
# Adapted from https://github.com/NVlabs/edm


class Noise_Embedding(nn.Module):
    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        freqs = torch.arange(
            start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device
        )
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.outer(freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x
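

# Usage sketch (illustrative addition, not part of the original module): sinusoidal
# embedding of a batch of scalar noise levels. The channel count of 64 and the batch
# size of 4 are arbitrary assumptions.
def _example_noise_embedding():
    emb = Noise_Embedding(num_channels=64)
    noise_level = torch.rand(4)        # one scalar per example
    return emb(noise_level).shape      # torch.Size([4, 64])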


########################################
# Adapted from github.com/lucidrains
# https://github.com/lucidrains/denoising-diffusion-pytorch
# https://github.com/lucidrains/recurrent-interface-network-pytorch


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def posemb_sincos_1d(patches, temperature=10000, residue_index=None):
    _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype
    n = torch.arange(n, device=device) if residue_index is None else residue_index
    assert (dim % 2) == 0, "feature dimension must be multiple of 2 for sincos emb"
    omega = torch.arange(dim // 2, device=device) / (dim // 2 - 1)
    omega = 1.0 / (temperature**omega)

    n = n[..., None] * omega
    pe = torch.cat((n.sin(), n.cos()), dim=-1)
    return pe.type(dtype)
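

# Usage sketch (illustrative addition, not part of the original module): fixed sin/cos
# position embedding for a toy token sequence. The (batch, length, dim) sizes are
# arbitrary assumptions; the returned embedding broadcasts over the batch dimension.
def _example_posemb_sincos_1d():
    tokens = torch.randn(2, 16, 64)
    pe = posemb_sincos_1d(tokens)  # (16, 64)
    return (tokens + pe).shape     # torch.Size([2, 16, 64])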


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


class NoiseConditioningBlock(nn.Module):
    def __init__(self, n_in_channel, n_out_channel):
        super().__init__()
        self.block = nn.Sequential(
            Noise_Embedding(n_in_channel),
            nn.Linear(n_in_channel, n_out_channel),
            nn.SiLU(),
            nn.Linear(n_out_channel, n_out_channel),
            Rearrange("b d -> b 1 d"),
        )

    def forward(self, noise_level):
        return self.block(noise_level)
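

# Usage sketch (illustrative addition, not part of the original module): map scalar
# noise levels to per-example conditioning vectors of shape (batch, 1, n_out_channel).
# The channel sizes are arbitrary assumptions.
def _example_noise_conditioning_block():
    cond_block = NoiseConditioningBlock(n_in_channel=64, n_out_channel=128)
    noise_level = torch.rand(4)
    return cond_block(noise_level).shape  # torch.Size([4, 1, 128])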


class TimeCondResnetBlock(nn.Module):
    def __init__(
        self, nic, noc, cond_nc, conv_layer=nn.Conv2d, dropout=0.1, n_norm_in_groups=4
    ):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(num_groups=nic // n_norm_in_groups, num_channels=nic),
            nn.SiLU(),
            conv_layer(nic, noc, 3, 1, 1),
        )
        self.cond_proj = nn.Linear(cond_nc, noc * 2)
        self.mid_norm = nn.GroupNorm(num_groups=noc // 4, num_channels=noc)
        self.dropout = dropout if dropout is None else nn.Dropout(dropout)
        self.block2 = nn.Sequential(
            nn.GroupNorm(num_groups=noc // 4, num_channels=noc),
            nn.SiLU(),
            conv_layer(noc, noc, 3, 1, 1),
        )
        self.mismatch = False
        if nic != noc:
            self.mismatch = True
            self.conv_match = conv_layer(nic, noc, 1, 1, 0)

    def forward(self, x, time=None):
        h = self.block1(x)
        if time is not None:
            h = self.mid_norm(h)
            scale, shift = self.cond_proj(time).chunk(2, dim=-1)
            h = (h * (utils.expand(scale, h) + 1)) + utils.expand(shift, h)
        if self.dropout is not None:
            h = self.dropout(h)
        h = self.block2(h)
        if self.mismatch:
            x = self.conv_match(x)
        return x + h
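

# Usage sketch (illustrative addition, not part of the original module): run one block
# on a (batch, channels, length, atoms) feature map. time=None skips the scale/shift
# conditioning path (the conditioned path depends on utils.expand, whose expected time
# shape is not assumed here); all sizes are arbitrary assumptions.
def _example_time_cond_resnet_block():
    block = TimeCondResnetBlock(nic=8, noc=16, cond_nc=32)
    x = torch.randn(2, 8, 10, 37)
    return block(x).shape  # torch.Size([2, 16, 10, 37])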


class TimeCondAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_context=None,
        heads=4,
        dim_head=32,
        norm=False,
        norm_context=False,
        time_cond_dim=None,
        attn_bias_dim=None,
        rotary_embedding_module=None,
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        dim_context = default(dim_context, dim)

        self.time_cond = None
        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(nn.SiLU(), nn.Linear(time_cond_dim, dim * 2))

            nn.init.zeros_(self.time_cond[-1].weight)
            nn.init.zeros_(self.time_cond[-1].bias)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.norm = LayerNorm(dim) if norm else nn.Identity()
        self.norm_context = LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attn_bias_proj = None
        if attn_bias_dim is not None:
            self.attn_bias_proj = nn.Sequential(
                Rearrange("b a i j -> b i j a"),
                nn.Linear(attn_bias_dim, heads),
                Rearrange("b i j a -> b a i j"),
            )

        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)
        nn.init.zeros_(self.to_out.weight)

        self.use_rope = False
        if rotary_embedding_module is not None:
            self.use_rope = True
            self.rope = rotary_embedding_module

    def forward(self, x, context=None, time=None, attn_bias=None, seq_mask=None):
        # attn_bias is b, c, i, j
        h = self.heads
        has_context = exists(context)
        context = default(context, x)

        if x.shape[-1] != self.norm.gamma.shape[-1]:
            print(context.shape, x.shape, self.norm.gamma.shape)
        x = self.norm(x)

        if exists(time):
            scale, shift = self.time_cond(time).chunk(2, dim=-1)
            x = (x * (scale + 1)) + shift

        if has_context:
            context = self.norm_context(context)

        if seq_mask is not None:
            x = x * seq_mask[..., None]

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

        q = q * self.scale

        if self.use_rope:
            q = self.rope.rotate_queries_or_keys(q)
            k = self.rope.rotate_queries_or_keys(k)

        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        if attn_bias is not None:
            if self.attn_bias_proj is not None:
                attn_bias = self.attn_bias_proj(attn_bias)
            sim += attn_bias
        if seq_mask is not None:
            attn_mask = torch.einsum("b i, b j -> b i j", seq_mask, seq_mask)[:, None]
            sim -= (1 - attn_mask) * 1e6

        attn = sim.softmax(dim=-1)
        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)

        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)

        if seq_mask is not None:
            out = out * seq_mask[..., None]

        return out
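

# Usage sketch (illustrative addition, not part of the original module): self-attention
# over a toy token sequence with a (batch, 1, time_cond_dim) conditioning vector. All
# sizes are arbitrary assumptions; norm=True is required here so the debug shape check
# in forward() can read self.norm.gamma.
def _example_time_cond_attention():
    attn = TimeCondAttention(dim=64, heads=4, dim_head=16, norm=True, time_cond_dim=128)
    x = torch.randn(2, 16, 64)
    time = torch.randn(2, 1, 128)
    return attn(x, time=time).shape  # torch.Size([2, 16, 64])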


class TimeCondFeedForward(nn.Module):
    def __init__(self, dim, mult=4, dim_out=None, time_cond_dim=None, dropout=0.1):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        self.norm = LayerNorm(dim)

        self.time_cond = None
        self.dropout = None
        inner_dim = int(dim * mult)

        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, inner_dim * 2),
            )

            nn.init.zeros_(self.time_cond[-1].weight)
            nn.init.zeros_(self.time_cond[-1].bias)

        self.linear_in = nn.Linear(dim, inner_dim)
        self.nonlinearity = nn.SiLU()
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        self.linear_out = nn.Linear(inner_dim, dim_out)
        nn.init.zeros_(self.linear_out.weight)
        nn.init.zeros_(self.linear_out.bias)

    def forward(self, x, time=None):
        x = self.norm(x)
        x = self.linear_in(x)
        x = self.nonlinearity(x)

        if exists(time):
            scale, shift = self.time_cond(time).chunk(2, dim=-1)
            x = (x * (scale + 1)) + shift

        if exists(self.dropout):
            x = self.dropout(x)

        return self.linear_out(x)
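

# Usage sketch (illustrative addition, not part of the original module): feed-forward
# block with time-conditioned scale/shift on the hidden activations. All sizes are
# arbitrary assumptions; linear_out is zero-initialized, so a fresh block returns zeros.
def _example_time_cond_feed_forward():
    ff = TimeCondFeedForward(dim=64, mult=4, time_cond_dim=128)
    x = torch.randn(2, 16, 64)
    time = torch.randn(2, 1, 128)
    return ff(x, time=time).shape  # torch.Size([2, 16, 64])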


class TimeCondTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        time_cond_dim,
        attn_bias_dim=None,
        mlp_inner_dim_mult=4,
        position_embedding_type: str = "rotary",
    ):
        super().__init__()

        self.rope = None
        self.pos_emb_type = position_embedding_type
        if position_embedding_type == "rotary":
            self.rope = RotaryEmbedding(dim=32)
        elif position_embedding_type == "relative":
            self.relpos = nn.Sequential(
                RelativePositionalEncoding(attn_dim=heads),
                Rearrange("b i j d -> b d i j"),
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        TimeCondAttention(
                            dim,
                            heads=heads,
                            dim_head=dim_head,
                            norm=True,
                            time_cond_dim=time_cond_dim,
                            attn_bias_dim=attn_bias_dim,
                            rotary_embedding_module=self.rope,
                        ),
                        TimeCondFeedForward(
                            dim, mlp_inner_dim_mult, time_cond_dim=time_cond_dim
                        ),
                    ]
                )
            )

    def forward(
        self,
        x,
        time=None,
        attn_bias=None,
        context=None,
        seq_mask=None,
        residue_index=None,
    ):
        if self.pos_emb_type == "absolute":
            pos_emb = posemb_sincos_1d(x)
            x = x + pos_emb
        elif self.pos_emb_type == "absolute_residx":
            assert residue_index is not None
            pos_emb = posemb_sincos_1d(x, residue_index=residue_index)
            x = x + pos_emb
        elif self.pos_emb_type == "relative":
            assert residue_index is not None
            pos_emb = self.relpos(residue_index)
            attn_bias = pos_emb if attn_bias is None else attn_bias + pos_emb

        if seq_mask is not None:
            x = x * seq_mask[..., None]

        for i, (attn, ff) in enumerate(self.layers):
            x = x + attn(
                x, context=context, time=time, attn_bias=attn_bias, seq_mask=seq_mask
            )
            x = x + ff(x, time=time)
            if seq_mask is not None:
                x = x * seq_mask[..., None]

        return x
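

# Usage sketch (illustrative addition, not part of the original module): a two-layer
# time-conditioned transformer with the default rotary position embeddings. All sizes
# are arbitrary assumptions; dim_head=32 matches the rotary embedding dimension used
# in __init__.
def _example_time_cond_transformer():
    model = TimeCondTransformer(dim=64, depth=2, heads=4, dim_head=32, time_cond_dim=128)
    x = torch.randn(2, 16, 64)
    time = torch.randn(2, 1, 128)
    seq_mask = torch.ones(2, 16)
    return model(x, time=time, seq_mask=seq_mask).shape  # torch.Size([2, 16, 64])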


class TimeCondUViT(nn.Module):
    def __init__(
        self,
        *,
        seq_len: int,
        dim: int,
        patch_size: int = 1,
        depth: int = 6,
        heads: int = 8,
        dim_head: int = 32,
        n_filt_per_layer: List[int] = [],
        n_blocks_per_layer: int = 2,
        n_atoms: int = 37,
        channels_per_atom: int = 6,
        attn_bias_dim: int = None,
        time_cond_dim: int = None,
        conv_skip_connection: bool = False,
        position_embedding_type: str = "rotary",
    ):
        super().__init__()

        # Initialize configuration params
        if time_cond_dim is None:
            time_cond_dim = dim * 4
        self.position_embedding_type = position_embedding_type
        channels = channels_per_atom
        self.n_conv_layers = n_conv_layers = len(n_filt_per_layer)
        if n_conv_layers > 0:
            post_conv_filt = n_filt_per_layer[-1]
        self.conv_skip_connection = conv_skip_connection and n_conv_layers == 1
        transformer_seq_len = seq_len // (2**n_conv_layers)
        assert transformer_seq_len % patch_size == 0
        num_patches = transformer_seq_len // patch_size
        dim_a = post_conv_atom_dim = max(1, n_atoms // (2 ** (n_conv_layers - 1)))
        if n_conv_layers == 0:
            patch_dim = patch_size * n_atoms * channels_per_atom
            patch_dim_out = patch_size * n_atoms * 3
            dim_a = n_atoms
        elif conv_skip_connection and n_conv_layers == 1:
            patch_dim = patch_size * (channels + post_conv_filt) * post_conv_atom_dim
            patch_dim_out = patch_size * post_conv_filt * post_conv_atom_dim
        elif n_conv_layers > 0:
            patch_dim = patch_dim_out = patch_size * post_conv_filt * post_conv_atom_dim

        # Make downsampling conv
        # Downsamples n-1 times where n is n_conv_layers
        down_conv = []
        block_in = channels
        for i, nf in enumerate(n_filt_per_layer):
            block_out = nf
            layer = []
            for j in range(n_blocks_per_layer):
                n_groups = 2 if i == 0 and j == 0 else 4
                layer.append(
                    TimeCondResnetBlock(
                        block_in, block_out, time_cond_dim, n_norm_in_groups=n_groups
                    )
                )
                block_in = block_out
            down_conv.append(nn.ModuleList(layer))
        self.down_conv = nn.ModuleList(down_conv)

        # Make transformer
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (n p) a -> b n (p c a)", p=patch_size),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )
        self.transformer = TimeCondTransformer(
            dim,
            depth,
            heads,
            dim_head,
            time_cond_dim,
            attn_bias_dim=attn_bias_dim,
            position_embedding_type=position_embedding_type,
        )
        self.from_patch = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, patch_dim_out),
            Rearrange("b n (p c a) -> b c (n p) a", p=patch_size, a=dim_a),
        )
        nn.init.zeros_(self.from_patch[-2].weight)
        nn.init.zeros_(self.from_patch[-2].bias)

        # Make upsampling conv
        up_conv = []
        for i, nf in enumerate(reversed(n_filt_per_layer)):
            skip_in = nf
            block_out = nf
            layer = []
            for j in range(n_blocks_per_layer):
                layer.append(
                    TimeCondResnetBlock(block_in + skip_in, block_out, time_cond_dim)
                )
                block_in = block_out
            up_conv.append(nn.ModuleList(layer))
        self.up_conv = nn.ModuleList(up_conv)

        # Conv out
        if n_conv_layers > 0:
            self.conv_out = nn.Sequential(
                nn.GroupNorm(num_groups=block_out // 4, num_channels=block_out),
                nn.SiLU(),
                nn.Conv2d(block_out, channels // 2, 3, 1, 1),
            )

    def forward(
        self, coords, time_cond, pair_bias=None, seq_mask=None, residue_index=None
    ):
        if self.n_conv_layers > 0:  # pad up to even dims
            coords = F.pad(coords, (0, 0, 0, 0, 0, 1, 0, 0))

        x = rearr_coords = rearrange(coords, "b n a c -> b c n a")

        hiddens = []
        for i, layer in enumerate(self.down_conv):
            for block in layer:
                x = block(x, time=time_cond)
                hiddens.append(x)
            if i != self.n_conv_layers - 1:
                x = downsample(x)

        if self.conv_skip_connection:
            x = torch.cat([x, rearr_coords], 1)

        x = self.to_patch_embedding(x)
        # if self.position_embedding_type == 'absolute':
        #     pos_emb = posemb_sincos_1d(x)
        #     x = x + pos_emb
        if seq_mask is not None and x.shape[1] == seq_mask.shape[1]:
            x *= seq_mask[..., None]
        x = self.transformer(
            x,
            time=time_cond,
            attn_bias=pair_bias,
            seq_mask=seq_mask,
            residue_index=residue_index,
        )
        x = self.from_patch(x)

        for i, layer in enumerate(self.up_conv):
            for block in layer:
                x = torch.cat([x, hiddens.pop()], 1)
                x = block(x, time=time_cond)
            if i != self.n_conv_layers - 1:
                x = upsample_coords(x, hiddens[-1].shape[2:])

        if self.n_conv_layers > 0:
            x = self.conv_out(x)
            x = x[..., :-1, :]  # drop even-dims padding
        x = rearrange(x, "b c n a -> b n a c")

        return x
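

# Usage sketch (illustrative addition, not part of the original module): a pure
# transformer configuration of the U-ViT, i.e. no conv down/upsampling stacks
# (n_filt_per_layer left empty). All sizes are arbitrary assumptions; the time
# conditioning shape (batch, 1, dim * 4) matches the default time_cond_dim.
def _example_time_cond_uvit():
    model = TimeCondUViT(seq_len=16, dim=64, depth=2, heads=4, dim_head=32)
    coords = torch.randn(2, 16, 37, 6)     # (batch, length, atoms, channels_per_atom)
    time_cond = torch.randn(2, 1, 64 * 4)  # (batch, 1, time_cond_dim)
    seq_mask = torch.ones(2, 16)
    residue_index = torch.arange(16)[None].repeat(2, 1)
    out = model(coords, time_cond, seq_mask=seq_mask, residue_index=residue_index)
    return out.shape  # torch.Size([2, 16, 37, 3])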


########################################


class LinearWarmupCosineDecay(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer,
        max_lr,
        warmup_steps=1000,
        decay_steps=int(1e6),
        min_lr=1e-6,
        **kwargs,
    ):
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.total_steps = warmup_steps + decay_steps
        super(LinearWarmupCosineDecay, self).__init__(optimizer, **kwargs)

    def get_lr(self):
        # TODO double check for off-by-one errors
        if self.last_epoch < self.warmup_steps:
            curr_lr = self.last_epoch / self.warmup_steps * self.max_lr
            return [curr_lr for group in self.optimizer.param_groups]
        elif self.last_epoch < self.total_steps:
            time = (self.last_epoch - self.warmup_steps) / self.decay_steps * np.pi
            curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (
                1 + np.cos(time)
            )
            return [curr_lr for group in self.optimizer.param_groups]
        else:
            return [self.min_lr for group in self.optimizer.param_groups]
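

# Usage sketch (illustrative addition, not part of the original module): attach the
# warmup/cosine schedule to a throwaway optimizer and step it a few times. The model,
# learning rates, and step counts are arbitrary assumptions.
def _example_lr_scheduler():
    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = LinearWarmupCosineDecay(
        optimizer, max_lr=1e-3, warmup_steps=10, decay_steps=100
    )
    for _ in range(5):
        optimizer.step()
        scheduler.step()
    return scheduler.get_last_lr()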


class NoiseConditionalProteinMPNN(nn.Module):
    def __init__(
        self,
        n_channel=128,
        n_layers=3,
        n_neighbors=32,
        time_cond_dim=None,
        vocab_size=21,
        input_S_is_embeddings=False,
    ):
        super().__init__()
        self.n_channel = n_channel
        self.n_layers = n_layers
        self.n_neighbors = n_neighbors
        self.time_cond_dim = time_cond_dim
        self.vocab_size = vocab_size
        self.bb_idxs_if_atom37 = [
            residue_constants.atom_order[a] for a in ["N", "CA", "C", "O"]
        ]

        self.mpnn = protein_mpnn.ProteinMPNN(
            num_letters=vocab_size,
            node_features=n_channel,
            edge_features=n_channel,
            hidden_dim=n_channel,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers,
            vocab=vocab_size,
            k_neighbors=n_neighbors,
            augment_eps=0.0,
            dropout=0.1,
            ca_only=False,
            time_cond_dim=time_cond_dim,
            input_S_is_embeddings=input_S_is_embeddings,
        )

    def forward(
        self, denoised_coords, noisy_aatype, seq_mask, residue_index, time_cond
    ):
        if denoised_coords.shape[-2] == 37:
            denoised_coords = denoised_coords[:, :, self.bb_idxs_if_atom37]

        node_embs, encoder_embs = self.mpnn(
            X=denoised_coords,
            S=noisy_aatype,
            mask=seq_mask,
            chain_M=seq_mask,
            residue_idx=residue_index,
            chain_encoding_all=seq_mask,
            randn=None,
            use_input_decoding_order=False,
            decoding_order=None,
            causal_mask=False,
            time_cond=time_cond,
            return_node_embs=True,
        )

        return node_embs, encoder_embs
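

# Usage sketch (illustrative comment only, not part of the original module). The exact
# tensor dtypes and shapes expected by the underlying core.protein_mpnn implementation
# are assumptions here, so this is left as a non-executed sketch:
# >>> mpnn = NoiseConditionalProteinMPNN(n_channel=128, time_cond_dim=512)
# >>> node_embs, encoder_embs = mpnn(
# ...     denoised_coords,   # (batch, length, 37, 3) or (batch, length, 4, 3)
# ...     noisy_aatype,      # (batch, length) integer sequence
# ...     seq_mask,          # (batch, length)
# ...     residue_index,     # (batch, length)
# ...     time_cond,         # (batch, 1, time_cond_dim)
# ... )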