""" Embedding model classes. """ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import PackedSequence class IdentityEmbed(nn.Module): """ Does not reduce the dimension of the language model embeddings, just passes them through to the contact model. """ def forward(self, x): """ :param x: Input language model embedding :math:`(b \\times N \\times d_0)` :type x: torch.Tensor :return: Same embedding :rtype: torch.Tensor """ return x class FullyConnectedEmbed(nn.Module): """ Protein Projection Module. Takes embedding from language model and outputs low-dimensional interaction aware projection. :param nin: Size of language model output :type nin: int :param nout: Dimension of projection :type nout: int :param dropout: Proportion of weights to drop out [default: 0.5] :type dropout: float :param activation: Activation for linear projection model :type activation: torch.nn.Module """ def __init__(self, nin, nout, dropout=0.5, activation=nn.ReLU()): super(FullyConnectedEmbed, self).__init__() self.nin = nin self.nout = nout self.dropout_p = dropout self.transform = nn.Linear(nin, nout) self.drop = nn.Dropout(p=self.dropout_p) self.activation = activation def forward(self, x): """ :param x: Input language model embedding :math:`(b \\times N \\times d_0)` :type x: torch.Tensor :return: Low dimensional projection of embedding :rtype: torch.Tensor """ t = self.transform(x) t = self.activation(t) t = self.drop(t) return t class SkipLSTM(nn.Module): """ Language model from `Bepler & Berger `_. Loaded with pre-trained weights in embedding function. :param nin: Input dimension of amino acid one-hot [default: 21] :type nin: int :param nout: Output dimension of final layer [default: 100] :type nout: int :param hidden_dim: Size of hidden dimension [default: 1024] :type hidden_dim: int :param num_layers: Number of stacked LSTM models [default: 3] :type num_layers: int :param dropout: Proportion of weights to drop out [default: 0] :type dropout: float :param bidirectional: Whether to use biLSTM vs. LSTM :type bidirectional: bool """ def __init__(self, nin=21, nout=100, hidden_dim=1024, num_layers=3, dropout=0, bidirectional=True): super(SkipLSTM, self).__init__() self.nin = nin self.nout = nout self.dropout = nn.Dropout(p=dropout) self.layers = nn.ModuleList() dim = nin for i in range(num_layers): f = nn.LSTM(dim, hidden_dim, 1, batch_first=True, bidirectional=bidirectional) self.layers.append(f) if bidirectional: dim = 2 * hidden_dim else: dim = hidden_dim n = hidden_dim * num_layers + nin if bidirectional: n = 2 * hidden_dim * num_layers + nin self.proj = nn.Linear(n, nout) def to_one_hot(self, x): """ Transform numeric encoded amino acid vector to one-hot encoded vector :param x: Input numeric amino acid encoding :math:`(N)` :type x: torch.Tensor :return: One-hot encoding vector :math:`(N \\times n_{in})` :rtype: torch.Tensor """ packed = type(x) is PackedSequence if packed: one_hot = x.data.new(x.data.size(0), self.nin).float().zero_() one_hot.scatter_(1, x.data.unsqueeze(1), 1) one_hot = PackedSequence(one_hot, x.batch_sizes) else: one_hot = x.new(x.size(0), x.size(1), self.nin).float().zero_() one_hot.scatter_(2, x.unsqueeze(2), 1) return one_hot def transform(self, x): """ :param x: Input numeric amino acid encoding :math:`(N)` :type x: torch.Tensor :return: Concatenation of all hidden layers :math:`(N \\times (n_{in} + 2 \\times \\text{num_layers} \\times \\text{hidden_dim}))` :rtype: torch.Tensor """ one_hot = self.to_one_hot(x) hs = [one_hot] # [] h_ = one_hot for f in self.layers: h, _ = f(h_) # h = self.dropout(h) hs.append(h) h_ = h if type(x) is PackedSequence: h = torch.cat([z.data for z in hs], 1) h = PackedSequence(h, x.batch_sizes) else: h = torch.cat([z for z in hs], 2) return h def forward(self, x): """ :meta private: """ one_hot = self.to_one_hot(x) hs = [one_hot] h_ = one_hot for f in self.layers: h, _ = f(h_) # h = self.dropout(h) hs.append(h) h_ = h if type(x) is PackedSequence: h = torch.cat([z.data for z in hs], 1) z = self.proj(h) z = PackedSequence(z, x.batch_sizes) else: h = torch.cat([z for z in hs], 2) z = self.proj(h.view(-1, h.size(2))) z = z.view(x.size(0), x.size(1), -1) return z