| import math |
| import torch |
| import torch.nn as nn |
| from torch_geometric.nn import GCNConv, GATConv, global_mean_pool |
|
|
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    The table follows the "Attention Is All You Need" formulation::

        pe[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        pe[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    Args:
        d_model: Embedding width. Odd widths are supported; the last column
            then carries only a sine term.
        seq_len: Maximum sequence length covered by the precomputed table.
        dropout: Dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        # (seq_len, 1) column of positions so it broadcasts against div_term.
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, so trim div_term;
        # without the slice an odd d_model raises a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Shape (1, seq_len, d_model): the leading axis broadcasts over the batch.
        # A buffer moves with .to(device) and is saved in state_dict, but is not
        # a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :L])`` with ``L = x.shape[1]``.

        Args:
            x: Tensor laid out as (batch, sequence, d_model); the sequence
                axis is dim 1, matching how the table is sliced.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        length = x.size(1)
        max_len = self.pe.size(1)
        if length > max_len:
            raise ValueError(
                f"sequence length {length} exceeds the maximum {max_len} "
                "supported by PositionalEncoding"
            )
        return self.dropout(x + self.pe[:, :length, :])
|
|
|
|
class LigandGNN(nn.Module):
    """Encode a batch of molecular graphs into one vector per graph.

    Three stacked ``GCNConv`` layers are followed by mean pooling over the
    nodes of each graph.

    Args:
        input_dim: Number of input features per node.
        hidden_channels: Width of every GCN layer and of the pooled output.
        dropout: Drop probability of the single ``nn.Dropout`` module shared
            by the layers that use dropout.
    """

    def __init__(self, input_dim, hidden_channels, dropout):
        super().__init__()
        self.hidden_channels = hidden_channels

        # Attribute names (and registration order) define the state_dict keys.
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Run the GCN stack and pool node embeddings per graph.

        Args:
            x: Node feature matrix passed to the first ``GCNConv``.
            edge_index: Graph connectivity shared by all three convolutions.
            batch: Node-to-graph assignment used by ``global_mean_pool``.

        Returns:
            The per-graph mean of the final node embeddings.
        """
        # NOTE(review): conv2's output is not passed through dropout, and
        # conv3's output gets dropout but no ReLU. This is preserved as-is;
        # confirm it is intentional.
        hidden = self.dropout(torch.relu(self.conv1(x, edge_index)))
        hidden = torch.relu(self.conv2(hidden, edge_index))
        hidden = self.dropout(self.conv3(hidden, edge_index))
        return global_mean_pool(hidden, batch)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class ProteinTransformer(nn.Module):
    """Encode padded protein token sequences into fixed-size vectors.

    Token id ``0`` is treated as padding. Padded positions are excluded from
    attention via ``src_key_padding_mask`` and from the mean pooling.

    Args:
        vocab_size: Number of token ids understood by the embedding table.
        d_model: Embedding / transformer width.
        N: Number of stacked encoder layers.
        h: Number of attention heads (must divide ``d_model``).
        output_dim: Width of the returned per-sequence vector.
        dropout: Dropout probability for the positional encoding AND the
            encoder layers.
        max_len: Longest sequence the positional-encoding table supports.
    """

    def __init__(
        self,
        vocab_size,
        d_model=128,
        N=2,
        h=4,
        output_dim=128,
        dropout=0.2,
        max_len=5000,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, seq_len=max_len, dropout=dropout)
        # Pass dropout explicitly; otherwise the layers silently use
        # PyTorch's default (0.1) instead of the configured rate.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=h, dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)

        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """Embed, encode and mean-pool a batch of token sequences.

        Args:
            x: Integer token ids laid out as (batch, sequence), since the
                encoder is built with ``batch_first=True``; id 0 is padding.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        padding_mask = x == 0
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        x = self.transformer(x, src_key_padding_mask=padding_mask)

        # masked_fill rather than multiplying by a 0/1 mask: a fully padded
        # row can come out of attention as NaN, and NaN * 0 is still NaN.
        x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        # Mean over real tokens only; clamp avoids 0/0 for all-padding rows.
        token_counts = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1)
        pooled = x.sum(dim=1) / token_counts.to(x.dtype)
        return self.fc(pooled)
|
|
|
|
class BindingAffinityModel(nn.Module):
    """Predict a scalar binding affinity from a ligand graph and a protein sequence.

    The ligand is encoded by ``LigandGNN`` and the protein by
    ``ProteinTransformer``. The two vectors are concatenated and passed
    through an MLP head.

    Args:
        num_node_features: Per-node feature width of the ligand graphs.
        hidden_channels: Shared width of both encoders and the MLP head.
        gat_heads: Accepted for backward compatibility; not used by any
            submodule.
        dropout: Dropout probability passed to both encoders and the head.
        vocab_size: Size of the protein token vocabulary.
    """

    def __init__(
        self,
        num_node_features,
        hidden_channels=128,
        gat_heads=4,
        dropout=0.2,
        vocab_size=26,
    ):
        super().__init__()

        self.ligand_gnn = LigandGNN(
            input_dim=num_node_features,
            hidden_channels=hidden_channels,
            dropout=dropout,
        )

        self.protein_transformer = ProteinTransformer(
            vocab_size=vocab_size,
            d_model=hidden_channels,
            output_dim=hidden_channels,
            dropout=dropout,
        )

        self.head = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        )

    def forward(self, x, edge_index, batch, protein_seq):
        """Score each (ligand, protein) pair in the batch.

        Args:
            x, edge_index, batch: Ligand graph tensors forwarded to ``LigandGNN``.
            protein_seq: Protein token ids for every sample, reshaped here to
                (batch_size, -1). Presumably the per-sample sequences were
                concatenated by the data loader — TODO confirm.

        Returns:
            Tensor of shape (batch_size, 1) with the predicted affinities.
        """
        ligand_vec = self.ligand_gnn(x, edge_index, batch)
        # global_mean_pool emits one row per graph (batch.max() + 1 rows), so
        # this equals the old batch.max().item() + 1 without a device->host sync.
        batch_size = ligand_vec.size(0)
        # reshape (not view) also accepts non-contiguous input.
        protein_seq = protein_seq.reshape(batch_size, -1)

        protein_vec = self.protein_transformer(protein_seq)
        combined = torch.cat([ligand_vec, protein_vec], dim=1)
        return self.head(combined)
|
|