# NOTE: removed non-Python upload-page residue ("wilbin's picture / Upload 248 files / 8896a5f verified")
"""
Contact model classes.
"""
import torch
import torch.nn as nn
# Fix: the conventional `F` alias is torch.nn.functional; `torch.functional`
# is a different internal module that lacks relu/sigmoid/etc. `F` is unused
# below, so this correction cannot change runtime behavior.
import torch.nn.functional as F
class FullyConnected(nn.Module):
    """
    Performs part 1 of the Contact Prediction Module. Takes embeddings from the
    Projection module and produces a broadcast tensor.

    Input embeddings of dimension :math:`d` are combined into a :math:`2d` length
    MLP input :math:`z_{cat}`, where :math:`z_{cat} = [z_0 \\ominus z_1 | z_0 \\odot z_1]`
    (elementwise absolute difference concatenated with elementwise product).

    :param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
    :type embed_dim: int
    :param hidden_dim: Hidden dimension :math:`h` [default: 50]
    :type hidden_dim: int
    :param activation: Activation function for broadcast tensor [default: torch.nn.ReLU()]
    :type activation: torch.nn.Module
    """

    def __init__(self, embed_dim=100, hidden_dim=50, activation=nn.ReLU()):
        # Defaults added to match the documented values above; existing
        # positional callers are unaffected.
        super().__init__()
        self.D = embed_dim
        self.H = hidden_dim
        # 1x1 conv acts as a per-position linear layer mapping 2d -> h channels.
        self.conv = nn.Conv2d(2 * self.D, self.H, 1)
        self.batchnorm = nn.BatchNorm2d(self.H)
        self.activation = activation

    def forward(self, z0, z1):
        """
        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted broadcast tensor, channels-first :math:`(b \\times h \\times N \\times M)`
        :rtype: torch.Tensor
        """
        # z0 is (b,N,d), z1 is (b,M,d) -> move d to the channel dimension.
        z0 = z0.transpose(1, 2)
        z1 = z1.transpose(1, 2)
        # z0 is (b,d,N), z1 is (b,d,M); broadcasting yields (b,d,N,M) pairwise maps.
        z_dif = torch.abs(z0.unsqueeze(3) - z1.unsqueeze(2))
        z_mul = z0.unsqueeze(3) * z1.unsqueeze(2)
        # Concatenate along channels: (b, 2d, N, M).
        z_cat = torch.cat([z_dif, z_mul], 1)
        b = self.conv(z_cat)
        # NOTE: activation is applied before batchnorm here (unusual ordering,
        # but preserved intentionally — matches the trained model's graph).
        b = self.activation(b)
        b = self.batchnorm(b)
        return b
class ContactCNN(nn.Module):
    """
    Residue Contact Prediction Module. Takes embeddings from the Projection
    module and produces a contact map.

    :param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
    :type embed_dim: int
    :param hidden_dim: Hidden dimension :math:`h` [default: 50]
    :type hidden_dim: int
    :param width: Width of convolutional filter :math:`2w+1` [default: 7]
    :type width: int
    :param activation: Activation function for final contact map [default: torch.nn.Sigmoid()]
    :type activation: torch.nn.Module
    """

    def __init__(self, embed_dim=100, hidden_dim=50, width=7, activation=nn.Sigmoid()):
        super().__init__()
        self.hidden = FullyConnected(embed_dim, hidden_dim)
        # padding=width//2 keeps the (N, M) spatial size of the input.
        self.conv = nn.Conv2d(hidden_dim, 1, width, padding=width // 2)
        self.batchnorm = nn.BatchNorm2d(1)
        self.activation = activation
        self.clip()

    def clip(self):
        """
        Force the convolutional layer to be transpose invariant, so the
        predicted map is symmetric in its two sequence arguments: each kernel
        is averaged with its own spatial transpose.

        :meta private:
        """
        w = self.conv.weight
        self.conv.weight.data[:] = 0.5 * (w + w.transpose(2, 3))

    def forward(self, z0, z1):
        """
        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted contact map with singleton channel dim :math:`(b \\times 1 \\times N \\times M)`
        :rtype: torch.Tensor
        """
        B = self.broadcast(z0, z1)
        return self.predict(B)

    def broadcast(self, z0, z1):
        """
        Calls `dscript.models.contact.FullyConnected <#module-dscript.models.contact.FullyConnected>`_.

        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted contact broadcast tensor, channels-first :math:`(b \\times h \\times N \\times M)`
        :rtype: torch.Tensor
        """
        B = self.hidden(z0, z1)
        return B

    def predict(self, B):
        """
        Predict contact map from broadcast tensor.

        :param B: Predicted contact broadcast :math:`(b \\times h \\times N \\times M)`
        :type B: torch.Tensor
        :return: Predicted contact map :math:`(b \\times 1 \\times N \\times M)`
        :rtype: torch.Tensor
        """
        C = self.conv(B)
        C = self.batchnorm(C)
        # Sigmoid (by default) maps logits to per-residue-pair contact
        # probabilities in [0, 1].
        C = self.activation(C)
        return C