Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from lightning import LightningModule | |
class GraphDTA(LightningModule):
    """Drug-target feature extractor following GraphDTA.

    From GraphDTA (Nguyen et al., 2020; https://doi.org/10.1093/bioinformatics/btaa921).

    The drug input is encoded by the injected ``gnn``. The protein token
    sequence is embedded, passed through a single 1D convolution, flattened
    and projected to ``output_dim``. Both representations are concatenated and
    sent through two fully connected layers (each followed by ReLU and
    dropout). No final regression layer is applied: ``forward`` returns the
    512-wide output of ``fc2``.

    Args:
        gnn: Drug encoder. ``gnn(v_d)`` must return a 2D tensor with the same
            number of rows as the protein features, since the two are
            concatenated along dim 1.
        num_features_protein: Vocabulary size of the protein embedding.
        n_filters: Number of output channels of the protein convolution.
        embed_dim: Width of the protein embedding.
        output_dim: Width of the projected protein representation.
        dropout: Dropout probability applied after ``fc1`` and ``fc2``.
        kernel_size: Kernel size of the protein convolution (default 8, the
            value used by the original GraphDTA).

    Raises:
        ValueError: If ``embed_dim`` is smaller than ``kernel_size``, which
            would leave the convolution with no valid output positions.
    """

    def __init__(
        self,
        gnn: nn.Module,
        num_features_protein: int,
        n_filters: int,
        embed_dim: int,
        output_dim: int,
        dropout: float,
        kernel_size: int = 8,
    ):
        super().__init__()
        self.gnn = gnn

        # Output length of an unpadded, stride-1 convolution over the
        # embedding axis (see conv_forward_xt for why that is the length axis).
        conv_out_len = embed_dim - kernel_size + 1
        if conv_out_len < 1:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be >= kernel_size ({kernel_size})"
            )

        # Protein sequence encoder (1D conv). in_channels is inferred lazily
        # on the first forward pass.
        self.embedding_xt = nn.Embedding(num_features_protein, embed_dim)
        self.conv_xt = nn.LazyConv1d(out_channels=n_filters, kernel_size=kernel_size)
        # Flattened conv output width. With the GraphDTA defaults
        # (n_filters=32, embed_dim=128, kernel_size=8) this is 32 * 121,
        # the value that used to be hard-coded here.
        self.fc1_xt = nn.Linear(n_filters * conv_out_len, output_dim)

        # Combined layers.
        # NOTE(review): fc1 expects the concatenated (drug, protein) features
        # to be 256 wide, i.e. the gnn output width plus output_dim must sum
        # to 256 (e.g. 128 + 128). Confirm against the configured gnn.
        self.fc1 = nn.Linear(256, 1024)
        self.fc2 = nn.Linear(1024, 512)

        # Activation and regularization.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def conv_forward_xt(self, v_p: torch.Tensor) -> torch.Tensor:
        """Encode protein token indices into ``output_dim``-wide features.

        Args:
            v_p: Protein token indices; cast to ``long`` before embedding.
                Presumably shaped (batch, seq_len). TODO confirm with the
                data pipeline.

        Returns:
            A 2D tensor with ``output_dim`` columns, one row per sample for
            batched input.
        """
        # Assuming v_p is (batch, seq_len), the embedding yields
        # (batch, seq_len, embed_dim). Conv1d reads (batch, channels, length),
        # so, as in the original GraphDTA, sequence positions act as input
        # channels and the kernel slides along the embedding axis.
        v_p = self.embedding_xt(v_p.long())
        v_p = self.conv_xt(v_p)
        # Flatten to one row per sample. The row width equals
        # n_filters * conv output length, so samples are never folded
        # across the batch dimension.
        v_p = v_p.view(-1, self.fc1_xt.in_features)
        return self.fc1_xt(v_p)

    def forward(self, v_d, v_p: torch.Tensor) -> torch.Tensor:
        """Compute joint drug-protein features.

        Args:
            v_d: Drug input, passed unchanged to ``self.gnn``.
            v_p: Protein token indices (see ``conv_forward_xt``).

        Returns:
            The 512-wide output of ``fc2`` after ReLU and dropout. No final
            regression layer is applied in this module.
        """
        v_d = self.gnn(v_d)
        v_p = self.conv_forward_xt(v_p)
        # Concatenate drug and protein features along the feature axis.
        v_f = torch.cat((v_d, v_p), 1)
        # Dense layers: Linear -> ReLU -> Dropout, twice.
        v_f = self.fc1(v_f)
        v_f = self.relu(v_f)
        v_f = self.dropout(v_f)
        v_f = self.fc2(v_f)
        v_f = self.relu(v_f)
        v_f = self.dropout(v_f)
        return v_f