# models/common.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
from torch_scatter import scatter_mean, scatter_add
def split_tensor_by_batch(x, batch, num_graphs=None):
    """
    Args:
        x: (N, ...) node-level tensor.
        batch: (N, ) graph index of each row of x.
    Returns:
        List of per-graph tensors [(N_1, ...), (N_2, ...), ..., (N_B, ...)].
    """
if num_graphs is None:
num_graphs = batch.max().item() + 1
x_split = []
    for i in range(num_graphs):
mask = batch == i
x_split.append(x[mask])
return x_split
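
# A minimal usage sketch (values illustrative, not from the original):
# >>> x = torch.randn(5, 16)
# >>> batch = torch.LongTensor([0, 0, 1, 1, 1])
# >>> [t.size(0) for t in split_tensor_by_batch(x, batch)]
# [2, 3]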
def concat_tensors_to_batch(x_split):
    """Concatenate per-graph tensors and rebuild the batch index (inverse of split_tensor_by_batch)."""
x = torch.cat(x_split, dim=0)
batch = torch.repeat_interleave(
torch.arange(len(x_split)),
repeats=torch.LongTensor([s.size(0) for s in x_split])
).to(device=x.device)
return x, batch
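
# Round-trip sketch with split_tensor_by_batch (values illustrative):
# >>> x, batch = concat_tensors_to_batch([torch.zeros(2, 4), torch.ones(3, 4)])
# >>> batch.tolist()
# [0, 0, 1, 1, 1]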
def split_tensor_to_segments(x, segsize):
    """Split x along dim 0 into consecutive segments of at most segsize rows."""
num_segs = math.ceil(x.size(0) / segsize)
segs = []
for i in range(num_segs):
segs.append(x[i*segsize : (i+1)*segsize])
return segs
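
# Illustrative sketch: 7 rows in segments of 3 leave a final segment of 1.
# >>> [s.size(0) for s in split_tensor_to_segments(torch.arange(7), 3)]
# [3, 3, 1]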
def split_tensor_by_lengths(x, lengths):
    """Split x along dim 0 into consecutive chunks with the given lengths."""
segs = []
for l in lengths:
segs.append(x[:l])
x = x[l:]
return segs
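
# Illustrative sketch: consecutive chunks of the requested lengths.
# >>> [s.tolist() for s in split_tensor_by_lengths(torch.arange(6), [2, 4])]
# [[0, 1], [2, 3, 4, 5]]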
def batch_intersection_mask(batch, batch_filter):
    """Boolean mask over batch that is True where the graph index appears in batch_filter."""
batch_filter = batch_filter.unique()
mask = (batch.view(-1, 1) == batch_filter.view(1, -1)).any(dim=1)
return mask
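
# Illustrative sketch: keep only nodes whose graph id is in the filter.
# >>> batch = torch.LongTensor([0, 0, 1, 2, 2])
# >>> batch_intersection_mask(batch, torch.LongTensor([0, 2])).tolist()
# [True, True, False, True, True]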
class MeanReadout(nn.Module):
"""Mean readout operator over graphs with variadic sizes."""
def forward(self, input, batch, num_graphs):
"""
Perform readout over the graph(s).
Parameters:
data (torch_geometric.data.Data): batched graph
input (Tensor): node representations
Returns:
Tensor: graph representations
"""
output = scatter_mean(input, batch, dim=0, dim_size=num_graphs)
return output
class SumReadout(nn.Module):
"""Sum readout operator over graphs with variadic sizes."""
def forward(self, input, batch, num_graphs):
"""
Perform readout over the graph(s).
Parameters:
data (torch_geometric.data.Data): batched graph
input (Tensor): node representations
Returns:
Tensor: graph representations
"""
output = scatter_add(input, batch, dim=0, dim_size=num_graphs)
return output
class MultiLayerPerceptron(nn.Module):
"""
Multi-layer Perceptron.
Note there is no activation or dropout in the last layer.
Parameters:
input_dim (int): input dimension
        hidden_dims (list of int): hidden dimensions
activation (str or function, optional): activation function
dropout (float, optional): dropout rate
"""
def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
super(MultiLayerPerceptron, self).__init__()
self.dims = [input_dim] + hidden_dims
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation  # accept a callable directly, per the docstring; None disables it
if dropout:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
self.layers = nn.ModuleList()
for i in range(len(self.dims) - 1):
self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
def forward(self, input):
""""""
x = input
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
if self.activation:
x = self.activation(x)
if self.dropout:
x = self.dropout(x)
return x
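
# Hedged sketch (dimensions illustrative): 16 -> 32 -> 32 -> 4, no activation after the last layer.
# >>> mlp = MultiLayerPerceptron(16, [32, 32, 4], activation="relu", dropout=0.1)
# >>> mlp(torch.randn(10, 16)).shape
# torch.Size([10, 4])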
class SmoothCrossEntropyLoss(_WeightedLoss):
def __init__(self, weight=None, reduction='mean', smoothing=0.0):
super().__init__(weight=weight, reduction=reduction)
self.smoothing = smoothing
        # weight and reduction are already stored by _WeightedLoss.__init__
@staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0):
        assert 0 <= smoothing < 1
        with torch.no_grad():
            # Off-target classes get smoothing/(n_classes-1); the target class gets 1-smoothing.
            smoothed = torch.empty(size=(targets.size(0), n_classes), device=targets.device)
            smoothed.fill_(smoothing / (n_classes - 1))
            smoothed.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
        return smoothed
def forward(self, inputs, targets):
targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
self.smoothing)
lsm = F.log_softmax(inputs, -1)
if self.weight is not None:
lsm = lsm * self.weight.unsqueeze(0)
loss = -(targets * lsm).sum(-1)
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss
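
# Hedged sketch (values illustrative); with smoothing=0.0 this reduces to standard cross-entropy.
# >>> logits, labels = torch.randn(4, 7), torch.LongTensor([0, 1, 2, 3])
# >>> criterion = SmoothCrossEntropyLoss(smoothing=0.1)
# >>> loss = criterion(logits, labels)  # scalar, since reduction='mean'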
class GaussianSmearing(nn.Module):
    """Expand scalar distances into a radial basis of evenly spaced Gaussians."""
    def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
super().__init__()
offset = torch.linspace(start, stop, num_gaussians)
self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
self.register_buffer('offset', offset)
def forward(self, dist):
dist = dist.view(-1, 1) - self.offset.view(1, -1)
return torch.exp(self.coeff * torch.pow(dist, 2))
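
# Illustrative sketch: each distance expands to num_gaussians RBF features.
# >>> GaussianSmearing(start=0.0, stop=10.0, num_gaussians=50)(torch.tensor([1.5, 3.0])).shape
# torch.Size([2, 50])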
class ShiftedSoftplus(nn.Module):
    """softplus(x) - log(2), shifted so the activation is zero at x = 0."""
    def __init__(self):
super().__init__()
self.shift = torch.log(torch.tensor(2.0)).item()
def forward(self, x):
return F.softplus(x) - self.shift
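
# Illustrative check: the shift makes the activation vanish at zero.
# >>> ShiftedSoftplus()(torch.zeros(1)).item()
# 0.0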
def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """Merge protein and ligand nodes into a single context, sorted by graph index.
    Note: argsort is not guaranteed to be stable, so node order within a graph may
    change; use compose_context_stable below when the original order must be kept.
    """
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    sort_idx = batch_ctx.argsort()
mask_protein = torch.cat([
torch.ones([batch_protein.size(0)], device=batch_protein.device).bool(),
torch.zeros([batch_ligand.size(0)], device=batch_ligand.device).bool(),
], dim=0)[sort_idx]
batch_ctx = batch_ctx[sort_idx]
h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, H)
pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, 3)
    return h_ctx, pos_ctx, batch_ctx, mask_protein
def get_complete_graph(batch):
"""
Args:
batch: Batch index.
Returns:
edge_index: (2, N_1 + N_2 + ... + N_{B-1}), where N_i is the number of nodes of the i-th graph.
neighbors: (B, ), number of edges per graph.
"""
    natoms = scatter_add(torch.ones_like(batch), index=batch, dim=0)  # nodes per graph, (B, )
    natoms_sqr = (natoms ** 2).long()
    num_atom_pairs = torch.sum(natoms_sqr)  # total ordered pairs, including self-pairs
    natoms_expand = torch.repeat_interleave(natoms, natoms_sqr)
    index_offset = torch.cumsum(natoms, dim=0) - natoms  # first node index of each graph
    index_offset_expand = torch.repeat_interleave(index_offset, natoms_sqr)
    index_sqr_offset = torch.cumsum(natoms_sqr, dim=0) - natoms_sqr  # first pair index of each graph
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, natoms_sqr)
    atom_count_sqr = torch.arange(num_atom_pairs, device=batch.device) - index_sqr_offset  # pair id within its graph
    index1 = (atom_count_sqr // natoms_expand).long() + index_offset_expand  # source node
    index2 = (atom_count_sqr % natoms_expand).long() + index_offset_expand   # target node
edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)])
    mask = index1 != index2  # drop self-loop pairs
edge_index = edge_index[:, mask]
num_edges = natoms_sqr - natoms # Number of edges per graph
return edge_index, num_edges
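
# Illustrative sketch on two graphs of 2 and 3 nodes: 2*1 + 3*2 = 8 directed edges.
# >>> edge_index, num_edges = get_complete_graph(torch.LongTensor([0, 0, 1, 1, 1]))
# >>> edge_index.size(1), num_edges.tolist()
# (8, [2, 6])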
def compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """Order-preserving variant of compose_context; also returns the boolean protein mask."""
    num_graphs = batch_protein.max().item() + 1
batch_ctx = []
h_ctx = []
pos_ctx = []
mask_protein = []
for i in range(num_graphs):
mask_p, mask_l = (batch_protein == i), (batch_ligand == i)
batch_p, batch_l = batch_protein[mask_p], batch_ligand[mask_l]
batch_ctx += [batch_p, batch_l]
h_ctx += [h_protein[mask_p], h_ligand[mask_l]]
pos_ctx += [pos_protein[mask_p], pos_ligand[mask_l]]
mask_protein += [
torch.ones([batch_p.size(0)], device=batch_p.device, dtype=torch.bool),
torch.zeros([batch_l.size(0)], device=batch_l.device, dtype=torch.bool),
]
batch_ctx = torch.cat(batch_ctx, dim=0)
h_ctx = torch.cat(h_ctx, dim=0)
pos_ctx = torch.cat(pos_ctx, dim=0)
mask_protein = torch.cat(mask_protein, dim=0)
return h_ctx, pos_ctx, batch_ctx, mask_protein
if __name__ == '__main__':
h_protein = torch.randn([60, 64])
h_ligand = -torch.randn([33, 64])
pos_protein = torch.clamp(torch.randn([60, 3]), 0, float('inf'))
pos_ligand = torch.clamp(torch.randn([33, 3]), float('-inf'), 0)
batch_protein = torch.LongTensor([0]*10 + [1]*20 + [2]*30)
batch_ligand = torch.LongTensor([0]*11 + [1]*11 + [2]*11)
h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand)
assert (batch_ctx[mask_protein] == batch_protein).all()
assert (batch_ctx[torch.logical_not(mask_protein)] == batch_ligand).all()
assert torch.allclose(h_ctx[torch.logical_not(mask_protein)], h_ligand)
assert torch.allclose(h_ctx[mask_protein], h_protein)
assert torch.allclose(pos_ctx[torch.logical_not(mask_protein)], pos_ligand)
assert torch.allclose(pos_ctx[mask_protein], pos_protein)