Spaces:
Runtime error
Runtime error
File size: 8,550 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm
from diffab.modules.common.layers import mask_zero, LayerNorm
from diffab.utils.protein.constants import BBHeavyAtom
def _alpha_from_logits(logits, mask, inf=1e5):
    """Turn attention logits into masked attention weights.

    Args:
        logits: Logit matrices, (N, L_i, L_j, num_heads).
        mask: Masks, (N, L). Boolean, or any dtype where non-zero means
            "valid"; it is cast to bool before use.
        inf: Large constant subtracted from the logits of masked pairs so
            that they receive (numerically) zero weight after the softmax.
    Returns:
        alpha: Attention weights, (N, L_i, L_j, num_heads). Softmax is taken
            over the L_j axis; rows whose query position is masked out are
            set to zero.
    """
    N, L, _, _ = logits.size()
    # torch.where only accepts boolean conditions, so normalise the dtype once.
    mask = mask.bool()
    mask_row = mask.view(N, L, 1, 1).expand_as(logits)  # (N, L, *, *)
    # A pair (i, j) is valid only when both position i and position j are valid.
    mask_pair = mask_row & mask_row.permute(0, 2, 1, 3)  # (N, L, L, *)

    logits = torch.where(mask_pair, logits, logits - inf)
    alpha = torch.softmax(logits, dim=2)  # (N, L, L, num_heads)
    # A fully masked row would otherwise come out of the softmax as uniform.
    alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
    return alpha
def _heads(x, n_heads, n_ch):
    """Split the trailing axis of ``x`` into a head axis and a channel axis.

    Args:
        x: (..., num_heads * num_channels)
        n_heads: Size of the new head axis.
        n_ch: Size of the new per-head channel axis.
    Returns:
        A view of ``x`` with shape (..., num_heads, num_channels).
    """
    *leading_dims, _ = x.shape
    return x.view(*leading_dims, n_heads, n_ch)
class GABlock(nn.Module):
    """Multi-head attention block mixing node, pair and spatial logits.

    Attention logits are the sum of three terms: a scaled dot product of
    per-node query/key projections, a linear projection of the pair features,
    and a negative squared distance between per-node query/key points placed
    via each residue frame ``(R, t)``. The resulting weights aggregate pair
    features, node values and value points, which are projected back to
    ``node_feat_dim`` and followed by a residual MLP transition.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
                 num_value_points=8, num_heads=12, bias=False):
        """
        Args:
            node_feat_dim: Channel size of node-wise features (F).
            pair_feat_dim: Channel size of pair-wise features (C).
            value_dim: Per-head channel size of the node value projection.
            query_key_dim: Per-head channel size of node queries and keys.
            num_query_points: Number of 3D query/key points per head.
            num_value_points: Number of 3D value points per head.
            num_heads: Number of attention heads.
            bias: Whether the query/key/value/pair/point projections use a bias.
        """
        super().__init__()
        self.node_feat_dim = node_feat_dim
        self.pair_feat_dim = pair_feat_dim
        self.value_dim = value_dim
        self.query_key_dim = query_key_dim
        self.num_query_points = num_query_points
        self.num_value_points = num_value_points
        self.num_heads = num_heads

        # Node
        self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)

        # Pair
        self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)

        # Spatial
        # Initialised to log(e - 1) so that softplus(spatial_coef) starts at 1.
        self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
                                         requires_grad=True)
        self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)

        # Output
        # Input is the concatenation of: aggregated pair features (heads * C),
        # aggregated node values (heads * value_dim), and per value point
        # 3 coordinates + 1 distance + 3 direction components.
        self.out_transform = nn.Linear(
            in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
                    num_heads * num_value_points * (3 + 3 + 1)),
            out_features=node_feat_dim,
        )

        self.layer_norm_1 = LayerNorm(node_feat_dim)
        self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim))
        self.layer_norm_2 = LayerNorm(node_feat_dim)

    def _node_logits(self, x):
        """Scaled dot-product logits between node queries and keys.

        Args:
            x: Node-wise features, (N, L, F).
        Returns:
            Logits, (N, L, L, num_heads).
        """
        query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
                       (1 / np.sqrt(self.query_key_dim))).sum(-1)  # (N, L, L, num_heads)
        return logits_node

    def _pair_logits(self, z):
        """Project pair features to one logit per head.

        Args:
            z: Pair-wise features, (N, L, L, C).
        Returns:
            Logits, (N, L, L, num_heads).
        """
        logits_pair = self.proj_pair_bias(z)
        return logits_pair

    def _spatial_logits(self, R, t, x):
        """Distance-based logits between per-node query and key points.

        Args:
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3).
            x: Node-wise features, (N, L, F).
        Returns:
            Logits, (N, L, L, num_heads); more negative for larger summed
            squared distances between query and key points.
        """
        N, L, _ = t.size()

        # Query
        query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
                              3)  # (N, L, n_heads * n_pnts, 3)
        query_points = local_to_global(R, t, query_points)  # Global query coordinates, (N, L, n_heads * n_pnts, 3)
        query_s = query_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)

        # Key
        key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
                            3)  # (N, L, n_heads * n_pnts, 3)
        key_points = local_to_global(R, t, key_points)  # Global key coordinates, (N, L, n_heads * n_pnts, 3)
        key_s = key_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)

        # Q-K Product
        sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1)  # (N, L, L, n_heads)
        # softplus keeps the learned per-head coefficient positive.
        gamma = F.softplus(self.spatial_coef)
        logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
                                        / 2)  # (N, L, L, n_heads)
        return logits_spatial

    def _pair_aggregation(self, alpha, z):
        """Attention-weighted sum of pair features for every node and head.

        Args:
            alpha: Attention weights, (N, L, L, num_heads).
            z: Pair-wise features, (N, L, L, C).
        Returns:
            Aggregated features, (N, L, num_heads * C).
        """
        N, L = z.shape[:2]
        feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2)  # (N, L, L, n_heads, C)
        feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
        return feat_p2n.reshape(N, L, -1)

    def _node_aggregation(self, alpha, x):
        """Attention-weighted sum of node value vectors.

        Args:
            alpha: Attention weights, (N, L, L, num_heads).
            x: Node-wise features, (N, L, F).
        Returns:
            Aggregated features, (N, L, num_heads * value_dim).
        """
        N, L = x.shape[:2]
        # proj_value emits value_dim channels per head, so the split must use
        # value_dim (not query_key_dim) or it fails whenever the two differ.
        value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim)  # (N, L, n_heads, v_ch)
        feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
        feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
        return feat_node.reshape(N, L, -1)

    def _spatial_aggregation(self, alpha, R, t, x):
        """Attention-weighted sum of value points, expressed per query frame.

        Args:
            alpha: Attention weights, (N, L, L, num_heads).
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3).
            x: Node-wise features, (N, L, F).
        Returns:
            Concatenation of aggregated point coordinates, their norms and
            their normalized directions, (N, L, num_heads * n_v_pnts * 7).
        """
        N, L, _ = t.size()
        value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
                              3)  # (N, L, n_heads * n_v_pnts, 3)
        value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
                                                                  3))  # (N, L, n_heads, n_v_pnts, 3)
        aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
                      value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
        aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)

        # Bring the aggregated points back into each query residue's frame.
        feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
        feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
        feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)

        feat_spatial = torch.cat([
            feat_points.reshape(N, L, -1),
            feat_distance.reshape(N, L, -1),
            feat_direction.reshape(N, L, -1),
        ], dim=-1)

        return feat_spatial

    def forward(self, R, t, x, z, mask):
        """
        Args:
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3).
            x: Node-wise features, (N, L, F).
            z: Pair-wise features, (N, L, L, C).
            mask: Masks, (N, L).
        Returns:
            x': Updated node-wise features, (N, L, F).
        """
        # Attention logits
        logits_node = self._node_logits(x)
        logits_pair = self._pair_logits(z)
        logits_spatial = self._spatial_logits(R, t, x)
        # Summing logits up and apply `softmax`.
        # The sqrt(1 / 3) factor rescales the sum of the three logit terms.
        logits_sum = logits_node + logits_pair + logits_spatial
        alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask)  # (N, L, L, n_heads)

        # Aggregate features
        feat_p2n = self._pair_aggregation(alpha, z)
        feat_node = self._node_aggregation(alpha, x)
        feat_spatial = self._spatial_aggregation(alpha, R, t, x)

        # Finally
        feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1))  # (N, L, F)
        feat_all = mask_zero(mask.unsqueeze(-1), feat_all)
        x_updated = self.layer_norm_1(x + feat_all)
        x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
        return x_updated
class GAEncoder(nn.Module):
    """Stack of ``GABlock`` layers applied one after another to node features."""

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt=None):
        """
        Args:
            node_feat_dim: Channel size of node-wise features.
            pair_feat_dim: Channel size of pair-wise features.
            num_layers: Number of ``GABlock`` layers to stack.
            ga_block_opt: Optional dict of extra keyword arguments forwarded
                to every ``GABlock``; ``None`` means no extra arguments.
        """
        super().__init__()
        # A None sentinel avoids sharing one mutable default dict across instances.
        block_kwargs = {} if ga_block_opt is None else ga_block_opt
        self.blocks = nn.ModuleList([
            GABlock(node_feat_dim, pair_feat_dim, **block_kwargs)
            for _ in range(num_layers)
        ])

    def forward(self, R, t, res_feat, pair_feat, mask):
        """Run every block in order, feeding each block's output to the next.

        Args:
            R: Frame basis matrices, passed unchanged to every block.
            t: Frame coordinates, passed unchanged to every block.
            res_feat: Node-wise features; replaced by each block's output.
            pair_feat: Pair-wise features, passed unchanged to every block.
            mask: Masks, passed unchanged to every block.
        Returns:
            Node-wise features after the last block (the input itself when
            there are no blocks).
        """
        for layer in self.blocks:
            res_feat = layer(R, t, res_feat, pair_feat, mask)
        return res_feat
|