# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import one_hot
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
        # Relative positional encoding (Algorithm 4)
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
    def relpos(self, ri: torch.Tensor) -> torch.Tensor:
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
            ri:
                "residue_index" features of shape [*, N]
        Returns:
            [*, N, N, C_z] relative positional encoding
        """
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
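    # A worked example of relpos (illustrative values): with relpos_k = 3,
    # boundaries = [-3, -2, -1, 0, 1, 2, 3] (7 bins). For ri = [0, 1, 5],
    # the offset d[0][2] = 0 - 5 = -5 falls outside the window; one_hot
    # assigns it to the nearest boundary, -3, so offsets are effectively
    # clipped to [-relpos_k, relpos_k] before the linear projection.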
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
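        # Outer sum via broadcasting: [*, N, 1, C_z] + [*, 1, N, C_z]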
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
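# Usage sketch for InputEmbedder (the dimensions below are illustrative
# AlphaFold-style values, not constants defined in this module):
#
#   embedder = InputEmbedder(tf_dim=22, msa_dim=49, c_z=128, c_m=256, relpos_k=32)
#   tf = torch.rand(1, 64, 22)           # [*, N_res, tf_dim]
#   ri = torch.arange(64).unsqueeze(0)   # [*, N_res]
#   msa = torch.rand(1, 8, 64, 49)       # [*, N_clust, N_res, msa_dim]
#   msa_emb, pair_emb = embedder(tf, ri, msa)
#   # msa_emb: [1, 8, 64, 256]; pair_emb: [1, 64, 64, 128]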
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
                Number of distogram bins
            inf:
                Large value used as the upper bound of the final distance bin
        """
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.bins = None
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
                [*, N_res, C_m] first row of the MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
        # Lazily build the distance bins; rebuild them if the module has
        # since been moved to a different device or dtype.
        if (
            self.bins is None
            or self.bins.device != x.device
            or self.bins.dtype != x.dtype
        ):
            self.bins = torch.linspace(
                self.min_bin,
                self.max_bin,
                self.no_bins,
                dtype=x.dtype,
                device=x.device,
                requires_grad=False,
            )
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = self.bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
        )
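        # Assign each squared distance to the one bin interval it falls in,
        # i.e. strictly between squared_bins[k] and upper[k].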
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
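# Usage sketch for RecyclingEmbedder (bin settings are illustrative
# AlphaFold-style values):
#
#   recycler = RecyclingEmbedder(
#       c_m=256, c_z=128, min_bin=3.25, max_bin=20.75, no_bins=15
#   )
#   m = torch.rand(1, 64, 256)       # first MSA row from the last iteration
#   z = torch.rand(1, 64, 64, 128)   # pair embedding from the last iteration
#   x = torch.rand(1, 64, 3)         # predicted C-beta coordinates
#   m_update, z_update = recycler(m, z, x)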
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
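# Usage sketch for TemplateAngleEmbedder (c_in is illustrative; the true
# width of "template_angle_feat" is set by the feature pipeline):
#
#   angle_embedder = TemplateAngleEmbedder(c_in=57, c_out=256)
#   a = torch.rand(1, 4, 64, 57)   # [*, N_templ, N_res, c_in]
#   a_emb = angle_embedder(a)      # [1, 4, 64, 256]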
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
            c_in:
                Input channel dimension
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
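# Usage sketch for TemplatePairEmbedder (c_in is illustrative):
#
#   pair_embedder = TemplatePairEmbedder(c_in=88, c_out=64)
#   t = torch.rand(1, 4, 64, 64, 88)   # [*, N_templ, N_res, N_res, c_in]
#   t_emb = pair_embedder(t)           # [1, 4, 64, 64, 64]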
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
    Implements Algorithm 2, line 15.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
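# Usage sketch for ExtraMSAEmbedder (c_in is illustrative; "extra_msa_feat"
# is considerably narrower than the clustered "msa_feat"):
#
#   extra_embedder = ExtraMSAEmbedder(c_in=25, c_out=64)
#   e = torch.rand(1, 512, 64, 25)   # [*, N_extra_seq, N_res, c_in]
#   e_emb = extra_embedder(e)        # [1, 512, 64, 64]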