DockFormer / dockformer /model /triangular_multiplicative_update.py
bshor's picture
add code
bca3a49
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from dockformer.model.primitives import Linear, LayerNorm
from dockformer.utils.precision_utils import is_fp16_enabled
from dockformer.utils.tensor_utils import permute_final_dims
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
"""
Implements Algorithms 11 and 12.
"""
@abstractmethod
def __init__(self, c_z, c_hidden, _outgoing):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(BaseTriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self._outgoing = _outgoing
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
self.layer_norm_in = LayerNorm(self.c_z)
self.layer_norm_out = LayerNorm(self.c_hidden)
self.sigmoid = nn.Sigmoid()
def _combine_projections(self,
a: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
if(self._outgoing):
a = permute_final_dims(a, (2, 0, 1))
b = permute_final_dims(b, (2, 1, 0))
else:
a = permute_final_dims(a, (2, 1, 0))
b = permute_final_dims(b, (2, 0, 1))
p = torch.matmul(a, b)
return permute_final_dims(p, (1, 2, 0))
@abstractmethod
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False
) -> torch.Tensor:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
pass
class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
c:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
c_hidden=c_hidden,
_outgoing=_outgoing)
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
_add_with_inplace: bool = False,
) -> torch.Tensor:
"""
Args:
x:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
a = mask
a = a * self.sigmoid(self.linear_a_g(z))
a = a * self.linear_a_p(z)
b = mask
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a_std = a.std()
b_std = b.std()
if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std()
b = b / b.std()
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
x = self._combine_projections(a, b)
del a, b
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
x = x * g
return x
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
__init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)