import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
from torch_scatter import scatter_mean, scatter_add


def split_tensor_by_batch(x, batch, num_graphs=None):
    """
    Split a batched tensor into per-graph tensors.
    Args:
        x: (N, ...)
        batch: (N, ), graph index of each row in ``x``.
    Returns:
        [(N_1, ...), (N_2, ...), ..., (N_B, ...)]
    """
    if num_graphs is None:
        num_graphs = batch.max().item() + 1
    x_split = []
    for i in range(num_graphs):
        mask = batch == i
        x_split.append(x[mask])
    return x_split


def concat_tensors_to_batch(x_split):
    """Inverse of ``split_tensor_by_batch``: concatenate per-graph tensors and rebuild the batch index."""
    x = torch.cat(x_split, dim=0)
    batch = torch.repeat_interleave(
        torch.arange(len(x_split)),
        repeats=torch.LongTensor([s.size(0) for s in x_split])
    ).to(device=x.device)
    return x, batch


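# Illustrative sketch, not part of the original module: a round trip through
# split_tensor_by_batch and concat_tensors_to_batch on hypothetical shapes.
# The helper name `_demo_split_concat_roundtrip` is introduced here purely for
# documentation purposes.
def _demo_split_concat_roundtrip():
    x = torch.randn(5, 16)
    batch = torch.LongTensor([0, 0, 0, 1, 1])
    parts = split_tensor_by_batch(x, batch)            # shapes (3, 16) and (2, 16)
    x_rec, batch_rec = concat_tensors_to_batch(parts)  # back to (5, 16) and (5, )
    assert torch.allclose(x_rec, x) and (batch_rec == batch).all()

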
def split_tensor_to_segments(x, segsize):
    num_segs = math.ceil(x.size(0) / segsize)
    segs = []
    for i in range(num_segs):
        segs.append(x[i*segsize : (i+1)*segsize])
    return segs


def split_tensor_by_lengths(x, lengths):
    segs = []
    for l in lengths:
        segs.append(x[:l])
        x = x[l:]
    return segs


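# Illustrative sketch, not part of the original module: splitting the same
# tensor by fixed segment size and by explicit lengths. The helper name
# `_demo_splitting_helpers` is hypothetical.
def _demo_splitting_helpers():
    x = torch.arange(10)
    assert [s.tolist() for s in split_tensor_to_segments(x, 4)] == \
        [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    assert [s.tolist() for s in split_tensor_by_lengths(x, [2, 3, 5])] == \
        [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]

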
def batch_intersection_mask(batch, batch_filter):
    """Boolean mask selecting the rows of ``batch`` whose graph index appears in ``batch_filter``."""
    batch_filter = batch_filter.unique()
    mask = (batch.view(-1, 1) == batch_filter.view(1, -1)).any(dim=1)
    return mask


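# Illustrative sketch, not part of the original module: keep only the nodes
# that belong to a chosen subset of graphs. The helper name is hypothetical.
def _demo_batch_intersection_mask():
    batch = torch.LongTensor([0, 0, 1, 2, 2, 2])
    keep = torch.LongTensor([0, 2])
    mask = batch_intersection_mask(batch, keep)
    assert mask.tolist() == [True, True, False, True, True, True]

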
class MeanReadout(nn.Module):
    """Mean readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Perform readout over the graph(s).
        Parameters:
            input (Tensor): node representations, shape (N, ...)
            batch (LongTensor): graph index of each node, shape (N, )
            num_graphs (int): number of graphs in the batch
        Returns:
            Tensor: graph representations, shape (num_graphs, ...)
        """
        output = scatter_mean(input, batch, dim=0, dim_size=num_graphs)
        return output


class SumReadout(nn.Module):
    """Sum readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Perform readout over the graph(s).
        Parameters:
            input (Tensor): node representations, shape (N, ...)
            batch (LongTensor): graph index of each node, shape (N, )
            num_graphs (int): number of graphs in the batch
        Returns:
            Tensor: graph representations, shape (num_graphs, ...)
        """
        output = scatter_add(input, batch, dim=0, dim_size=num_graphs)
        return output


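# Illustrative sketch, not part of the original module: pooling node features
# into per-graph features with the readout modules above. The helper name is
# hypothetical.
def _demo_readout():
    h = torch.ones(5, 8)
    batch = torch.LongTensor([0, 0, 0, 1, 1])
    mean_h = MeanReadout()(h, batch, num_graphs=2)  # (2, 8), all ones
    sum_h = SumReadout()(h, batch, num_graphs=2)    # (2, 8), rows of 3s and 2s
    assert mean_h.shape == (2, 8) and sum_h.shape == (2, 8)

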
class MultiLayerPerceptron(nn.Module):
    """
    Multi-layer Perceptron.
    Note there is no activation or dropout in the last layer.

    Parameters:
        input_dim (int): input dimension
        hidden_dims (list of int): hidden dimensions
        activation (str or function, optional): activation function
        dropout (float, optional): dropout rate
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()

        self.dims = [input_dim] + hidden_dims
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            # Allow a callable activation to be passed directly (may be None).
            self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))

    def forward(self, input):
        """Apply the MLP; activation and dropout are skipped after the last layer."""
        x = input
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                if self.activation:
                    x = self.activation(x)
                if self.dropout:
                    x = self.dropout(x)
        return x


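# Illustrative sketch, not part of the original module: a small MLP mapping
# 16-dim features to 4-dim outputs through one 32-dim hidden layer. The helper
# name and the chosen dimensions are hypothetical.
def _demo_mlp():
    mlp = MultiLayerPerceptron(16, [32, 4], activation="relu", dropout=0.1)
    out = mlp(torch.randn(10, 16))
    assert out.shape == (10, 4)

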
class SmoothCrossEntropyLoss(_WeightedLoss):
    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0):
        # Soft targets: the true class gets (1 - smoothing), the remaining mass
        # is shared uniformly among the other classes.
        assert 0 <= smoothing < 1
        with torch.no_grad():
            targets = torch.empty(size=(targets.size(0), n_classes),
                                  device=targets.device) \
                .fill_(smoothing / (n_classes - 1)) \
                .scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
        return targets

    def forward(self, inputs, targets):
        targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
                                                         self.smoothing)
        lsm = F.log_softmax(inputs, -1)

        if self.weight is not None:
            lsm = lsm * self.weight.unsqueeze(0)

        loss = -(targets * lsm).sum(-1)

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss


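# Illustrative sketch, not part of the original module: label-smoothed
# cross-entropy on random logits; with smoothing=0 this reduces to the standard
# cross-entropy. The helper name is hypothetical.
def _demo_smooth_ce():
    logits = torch.randn(4, 6)
    labels = torch.LongTensor([0, 2, 5, 3])
    loss = SmoothCrossEntropyLoss(smoothing=0.1)(logits, labels)
    loss_ref = F.cross_entropy(logits, labels)
    assert loss.shape == () and loss_ref.shape == ()

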
class GaussianSmearing(nn.Module):
    def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        # Expand each distance into Gaussian RBF features centered at the offsets.
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


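# Illustrative sketch, not part of the original module: expanding a vector of
# pairwise distances into radial-basis features. The helper name is hypothetical.
def _demo_gaussian_smearing():
    smearing = GaussianSmearing(start=0.0, stop=5.0, num_gaussians=20)
    dist = torch.tensor([0.5, 1.2, 3.7])
    feat = smearing(dist)
    assert feat.shape == (3, 20)

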
class ShiftedSoftplus(nn.Module):
    def __init__(self):
        super().__init__()
        # Shift by log(2) so that the activation is zero at x = 0.
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift


def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    # Note: argsort is not guaranteed to be stable, so the relative order of
    # protein and ligand nodes within a graph may vary; see
    # compose_context_stable below for a deterministic variant.
    sort_idx = batch_ctx.argsort()

    mask_protein = torch.cat([
        torch.ones([batch_protein.size(0)], device=batch_protein.device).bool(),
        torch.zeros([batch_ligand.size(0)], device=batch_ligand.device).bool(),
    ], dim=0)[sort_idx]

    batch_ctx = batch_ctx[sort_idx]
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]

    return h_ctx, pos_ctx, batch_ctx


def get_complete_graph(batch):
    """
    Args:
        batch: Batch index, (N, ).
    Returns:
        edge_index: (2, E), where E = sum_i N_i * (N_i - 1) and N_i is the number of nodes of the i-th graph.
        num_edges: (B, ), number of edges per graph.
    """
    natoms = scatter_add(torch.ones_like(batch), index=batch, dim=0)

    # Enumerate all ordered node pairs within each graph, then drop self-loops.
    natoms_sqr = (natoms ** 2).long()
    num_atom_pairs = torch.sum(natoms_sqr)
    natoms_expand = torch.repeat_interleave(natoms, natoms_sqr)

    index_offset = torch.cumsum(natoms, dim=0) - natoms
    index_offset_expand = torch.repeat_interleave(index_offset, natoms_sqr)

    index_sqr_offset = torch.cumsum(natoms_sqr, dim=0) - natoms_sqr
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, natoms_sqr)

    atom_count_sqr = torch.arange(num_atom_pairs, device=num_atom_pairs.device) - index_sqr_offset

    index1 = (atom_count_sqr // natoms_expand).long() + index_offset_expand
    index2 = (atom_count_sqr % natoms_expand).long() + index_offset_expand
    edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)])
    mask = torch.logical_not(index1 == index2)
    edge_index = edge_index[:, mask]

    num_edges = natoms_sqr - natoms  # N_i * (N_i - 1) directed edges per graph

    return edge_index, num_edges


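# Illustrative sketch, not part of the original module: complete graphs for a
# batch of a 3-node and a 2-node graph give 3*2 + 2*1 = 8 directed edges. The
# helper name is hypothetical.
def _demo_complete_graph():
    batch = torch.LongTensor([0, 0, 0, 1, 1])
    edge_index, num_edges = get_complete_graph(batch)
    assert edge_index.shape == (2, 8)
    assert num_edges.tolist() == [6, 2]

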
def compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    num_graphs = batch_protein.max().item() + 1

    batch_ctx = []
    h_ctx = []
    pos_ctx = []
    mask_protein = []

    for i in range(num_graphs):
        mask_p, mask_l = (batch_protein == i), (batch_ligand == i)
        batch_p, batch_l = batch_protein[mask_p], batch_ligand[mask_l]

        batch_ctx += [batch_p, batch_l]
        h_ctx += [h_protein[mask_p], h_ligand[mask_l]]
        pos_ctx += [pos_protein[mask_p], pos_ligand[mask_l]]
        mask_protein += [
            torch.ones([batch_p.size(0)], device=batch_p.device, dtype=torch.bool),
            torch.zeros([batch_l.size(0)], device=batch_l.device, dtype=torch.bool),
        ]

    batch_ctx = torch.cat(batch_ctx, dim=0)
    h_ctx = torch.cat(h_ctx, dim=0)
    pos_ctx = torch.cat(pos_ctx, dim=0)
    mask_protein = torch.cat(mask_protein, dim=0)

    return h_ctx, pos_ctx, batch_ctx, mask_protein


if __name__ == '__main__':
    h_protein = torch.randn([60, 64])
    h_ligand = -torch.randn([33, 64])
    pos_protein = torch.clamp(torch.randn([60, 3]), 0, float('inf'))
    pos_ligand = torch.clamp(torch.randn([33, 3]), float('-inf'), 0)
    batch_protein = torch.LongTensor([0]*10 + [1]*20 + [2]*30)
    batch_ligand = torch.LongTensor([0]*11 + [1]*11 + [2]*11)

    h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand)

    assert (batch_ctx[mask_protein] == batch_protein).all()
    assert (batch_ctx[torch.logical_not(mask_protein)] == batch_ligand).all()

    assert torch.allclose(h_ctx[torch.logical_not(mask_protein)], h_ligand)
    assert torch.allclose(h_ctx[mask_protein], h_protein)

    assert torch.allclose(pos_ctx[torch.logical_not(mask_protein)], pos_ligand)
    assert torch.allclose(pos_ctx[mask_protein], pos_protein)