|
""" |
|
Contact model classes. |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.functional as F |
|
|
|
|
|
class FullyConnected(nn.Module):
    """
    Performs part 1 of Contact Prediction Module. Takes embeddings from Projection module and produces broadcast tensor.

    Input embeddings of dimension :math:`d` are combined into a :math:`2d` length MLP input :math:`z_{cat}`,
    where :math:`z_{cat} = [z_0 \\ominus z_1 | z_0 \\odot z_1]`. Here :math:`\\ominus` is the element-wise
    absolute difference and :math:`\\odot` the element-wise product, both computed for every pair of
    positions :math:`(i, j)`. The :math:`2d` features of each pair are mapped to :math:`h` channels by a
    1x1 convolution (the same linear map applied to every pair), then passed through ``activation`` and
    finally batch normalization.

    :param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
    :type embed_dim: int
    :param hidden_dim: Hidden dimension :math:`h` [default: 50]
    :type hidden_dim: int
    :param activation: Activation function for broadcast tensor. ``None`` creates a new
        :class:`torch.nn.ReLU` for this instance [default: torch.nn.ReLU()]
    :type activation: torch.nn.Module
    """

    def __init__(self, embed_dim=100, hidden_dim=50, activation=None):
        super(FullyConnected, self).__init__()

        self.D = embed_dim
        self.H = hidden_dim
        # A 1x1 convolution over the pair grid applies one linear layer to every (i, j) pair.
        self.conv = nn.Conv2d(2 * self.D, self.H, 1)
        self.batchnorm = nn.BatchNorm2d(self.H)
        # Build the default here rather than in the signature so instances never share
        # a single module object (a default argument is evaluated only once).
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, z0, z1):
        """
        Build the pairwise broadcast tensor from two embedding sequences.

        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted broadcast tensor :math:`(b \\times h \\times N \\times M)`
            (channels first, as produced by :class:`torch.nn.Conv2d`)
        :rtype: torch.Tensor
        """
        # Move the embedding dimension to axis 1 so it becomes the conv channel axis:
        # (b, d, N) and (b, d, M).
        z0 = z0.transpose(1, 2)
        z1 = z1.transpose(1, 2)

        # Broadcast (b, d, N, 1) against (b, d, 1, M) to get every pair: (b, d, N, M).
        z_dif = torch.abs(z0.unsqueeze(3) - z1.unsqueeze(2))
        z_mul = z0.unsqueeze(3) * z1.unsqueeze(2)
        # Concatenate along channels: (b, 2d, N, M).
        z_cat = torch.cat([z_dif, z_mul], 1)

        # Note: activation is applied before batch normalization (kept as-is so
        # existing trained weights remain valid).
        b = self.conv(z_cat)
        b = self.activation(b)
        b = self.batchnorm(b)

        return b
|
|
|
|
|
class ContactCNN(nn.Module):
    """
    Residue Contact Prediction Module. Takes embeddings from Projection module and produces contact map, output of Contact module.

    :param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
    :type embed_dim: int
    :param hidden_dim: Hidden dimension :math:`h` [default: 50]
    :type hidden_dim: int
    :param width: Width of convolutional filter :math:`2w+1`. Should be odd: the convolution uses
        ``padding=width // 2``, which preserves the :math:`N \\times M` size only for odd widths [default: 7]
    :type width: int
    :param activation: Activation function for final contact map. ``None`` creates a new
        :class:`torch.nn.Sigmoid` for this instance [default: torch.nn.Sigmoid()]
    :type activation: torch.nn.Module
    """

    def __init__(self, embed_dim=100, hidden_dim=50, width=7, activation=None):
        super(ContactCNN, self).__init__()

        self.hidden = FullyConnected(embed_dim, hidden_dim)
        # Single output channel: one contact score per (i, j) pair.
        self.conv = nn.Conv2d(hidden_dim, 1, width, padding=width // 2)
        self.batchnorm = nn.BatchNorm2d(1)
        # Build the default here rather than in the signature so instances never share
        # a single module object (a default argument is evaluated only once).
        self.activation = nn.Sigmoid() if activation is None else activation

        # Start from a symmetric kernel.
        self.clip()

    def clip(self):
        """
        Force the convolutional layer to be transpose invariant.

        Replaces the kernel :math:`W` by :math:`(W + W^T)/2`, where the transpose swaps the two
        spatial kernel axes, making the filter symmetric. The update is written through
        ``.data`` so it is not recorded by autograd.

        :meta private:
        """
        w = self.conv.weight.data
        # The right-hand side is materialized as a new tensor before copy_, so reading
        # and writing w in one statement is safe. Working on .data avoids building an
        # autograd graph for this in-place weight projection.
        w.copy_(0.5 * (w + w.transpose(2, 3)))

    def forward(self, z0, z1):
        """
        Predict a contact map for a pair of embedding sequences.

        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted contact map :math:`(b \\times 1 \\times N \\times M)`
        :rtype: torch.Tensor
        """
        B = self.broadcast(z0, z1)
        return self.predict(B)

    def broadcast(self, z0, z1):
        """
        Calls `dscript.models.contact.FullyConnected <#module-dscript.models.contact.FullyConnected>`_.

        :param z0: Projection module embedding :math:`(b \\times N \\times d)`
        :type z0: torch.Tensor
        :param z1: Projection module embedding :math:`(b \\times M \\times d)`
        :type z1: torch.Tensor
        :return: Predicted contact broadcast tensor :math:`(b \\times h \\times N \\times M)`
        :rtype: torch.Tensor
        """
        B = self.hidden(z0, z1)
        return B

    def predict(self, B):
        """
        Predict contact map from broadcast tensor.

        Applies the symmetric convolution, batch normalization, then ``activation``.

        :param B: Predicted contact broadcast :math:`(b \\times h \\times N \\times M)`
        :type B: torch.Tensor
        :return: Predicted contact map :math:`(b \\times 1 \\times N \\times M)`
        :rtype: torch.Tensor
        """
        C = self.conv(B)
        C = self.batchnorm(C)
        C = self.activation(C)
        return C
|
|