ICLR_FLAG / utils /dihedral_utils.py
zaixizhang
renew
10efe81
raw
history blame
14.7 kB
import torch
import torch_geometric as tg
from torch_geometric.utils import degree
import networkx as nx
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
angle_mask_ref = torch.LongTensor([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1]]).to(device)
angle_combos = torch.LongTensor([[0, 1],
[0, 2],
[1, 2],
[0, 3],
[1, 3],
[2, 3]]).to(device)
def get_neighbor_ids(data):
"""
Takes the edge indices and returns dictionary mapping atom index to neighbor indices
Note: this only includes atoms with degree > 1
"""
# start, end = edge_index
# idxs, vals = torch.unique(start, return_counts=True)
# vs = torch.split_with_sizes(end, tuple(vals))
# return {k.item(): v for k, v in zip(idxs, vs) if len(v) > 1}
neighbors = data.neighbors.pop(0)
n_atoms_per_mol = data.batch.bincount()
n_atoms_prev_mol = 0
for i, n_dict in enumerate(data.neighbors):
new_dict = {}
n_atoms_prev_mol += n_atoms_per_mol[i].item()
for k, v in n_dict.items():
new_dict[k + n_atoms_prev_mol] = v + n_atoms_prev_mol
neighbors.update(new_dict)
return neighbors
def get_neighbor_bonds(edge_index, bond_type):
"""
Takes the edge indices and bond type and returns dictionary mapping atom index to neighbor bond types
Note: this only includes atoms with degree > 1
"""
start, end = edge_index
idxs, vals = torch.unique(start, return_counts=True)
vs = torch.split_with_sizes(bond_type, tuple(vals))
return {k.item(): v for k, v in zip(idxs, vs) if len(v) > 1}
def get_leaf_hydrogens(neighbors, x):
"""
Takes the edge indices and atom features and returns dictionary mapping atom index to neighbors, indicating true
for hydrogens that are leaf nodes
Note: this only works because degree = 1 and hydrogen atomic number = 1 (checks when 1 == 1)
Note: we use the 5th feature index bc this corresponds to the atomic number
"""
# start, end = edge_index
# degrees = degree(end)
# idxs, vals = torch.unique(start, return_counts=True)
# vs = torch.split_with_sizes(end, tuple(vals))
# return {k.item(): degrees[v] == x[v, 5] for k, v in zip(idxs, vs) if len(v) > 1}
leaf_hydrogens = {}
h_mask = x[:, 0] == 1
for k, v in neighbors.items():
leaf_hydrogens[k] = h_mask[neighbors[k]]
return leaf_hydrogens
def get_dihedral_pairs(edge_index, data):
"""
Given edge indices, return pairs of indices that we must calculate dihedrals for
"""
start, end = edge_index
degrees = degree(end)
dihedral_pairs_true = torch.nonzero(torch.logical_and(degrees[start] > 1, degrees[end] > 1))
dihedral_pairs = edge_index[:, dihedral_pairs_true].squeeze(-1)
# # first method which removes one (pseudo) random edge from a cycle
dihedral_idxs = torch.nonzero(dihedral_pairs.sort(dim=0).indices[0, :] == 0).squeeze().detach().cpu().numpy()
# prioritize rings for assigning dihedrals
dihedral_pairs = dihedral_pairs.t()[dihedral_idxs]
G = nx.to_undirected(tg.utils.to_networkx(data))
cycles = nx.cycle_basis(G)
keep, sorted_keep = [], []
if len(dihedral_pairs.shape) == 1:
dihedral_pairs = dihedral_pairs.unsqueeze(0)
for pair in dihedral_pairs:
x, y = pair
if sorted(pair) in sorted_keep:
continue
y_cycle_check = [y in cycle for cycle in cycles]
x_cycle_check = [x in cycle for cycle in cycles]
if any(x_cycle_check) and any(y_cycle_check): # both in new cycle
cycle_indices = get_current_cycle_indices(cycles, x_cycle_check, x)
keep.extend(cycle_indices)
sorted_keep.extend([sorted(c) for c in cycle_indices])
continue
if any(y_cycle_check):
cycle_indices = get_current_cycle_indices(cycles, y_cycle_check, y)
keep.append(pair)
keep.extend(cycle_indices)
sorted_keep.append(sorted(pair))
sorted_keep.extend([sorted(c) for c in cycle_indices])
continue
keep.append(pair)
keep = [t.to(device) for t in keep]
return torch.stack(keep).t()
def batch_distance_metrics_from_coords(coords, mask):
"""
Given coordinates of neighboring atoms, compute bond
distances and 2-hop distances in local neighborhood
"""
d_mat_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
if coords.dim() == 4:
two_dop_d_mat = torch.square(coords.unsqueeze(1) - coords.unsqueeze(2) + 1e-10).sum(dim=-1).sqrt() * d_mat_mask.unsqueeze(-1)
one_hop_ds = torch.linalg.norm(torch.zeros_like(coords[0]).unsqueeze(0) - coords, dim=-1)
elif coords.dim() == 5:
two_dop_d_mat = torch.square(coords.unsqueeze(2) - coords.unsqueeze(3) + 1e-10).sum(dim=-1).sqrt() * d_mat_mask.unsqueeze(-1).unsqueeze(1)
one_hop_ds = torch.linalg.norm(torch.zeros_like(coords[0]).unsqueeze(0) - coords, dim=-1)
return one_hop_ds, two_dop_d_mat
def batch_angle_between_vectors(a, b):
"""
Compute angle between two batches of input vectors
"""
inner_product = (a * b).sum(dim=-1)
# norms
a_norm = torch.linalg.norm(a, dim=-1)
b_norm = torch.linalg.norm(b, dim=-1)
# protect denominator during division
den = a_norm * b_norm + 1e-10
cos = inner_product / den
return cos
def batch_angles_from_coords(coords, mask):
"""
Given coordinates, compute all local neighborhood angles
"""
if coords.dim() == 4:
all_possible_combos = coords[:, angle_combos]
v_a, v_b = all_possible_combos.split(1, dim=2) # does one of these need to be negative?
angle_mask = angle_mask_ref[mask.sum(dim=1).long()]
angles = batch_angle_between_vectors(v_a.squeeze(2), v_b.squeeze(2)) * angle_mask.unsqueeze(-1)
elif coords.dim() == 5:
all_possible_combos = coords[:, :, angle_combos]
v_a, v_b = all_possible_combos.split(1, dim=3) # does one of these need to be negative?
angle_mask = angle_mask_ref[mask.sum(dim=1).long()]
angles = batch_angle_between_vectors(v_a.squeeze(3), v_b.squeeze(3)) * angle_mask.unsqueeze(-1).unsqueeze(-1)
return angles
def batch_local_stats_from_coords(coords, mask):
"""
Given neighborhood neighbor coordinates, compute bond distances,
2-hop distances, and angles in local neighborhood (this assumes
the central atom has coordinates at the origin)
"""
one_hop_ds, two_dop_d_mat = batch_distance_metrics_from_coords(coords, mask)
angles = batch_angles_from_coords(coords, mask)
return one_hop_ds, two_dop_d_mat, angles
def batch_dihedrals(p0, p1, p2, p3, angle=False):
s1 = p1 - p0
s2 = p2 - p1
s3 = p3 - p2
sin_d_ = torch.linalg.norm(s2, dim=-1) * torch.sum(s1 * torch.cross(s2, s3, dim=-1), dim=-1)
cos_d_ = torch.sum(torch.cross(s1, s2, dim=-1) * torch.cross(s2, s3, dim=-1), dim=-1)
if angle:
return torch.atan2(sin_d_, cos_d_ + 1e-10)
else:
den = torch.linalg.norm(torch.cross(s1, s2, dim=-1), dim=-1) * torch.linalg.norm(torch.cross(s2, s3, dim=-1), dim=-1) + 1e-10
return sin_d_/den, cos_d_/den
def batch_vector_angles(xn, x, y, yn):
uT = xn.view(-1, 3)
uX = x.view(-1, 3)
uY = y.view(-1, 3)
uZ = yn.view(-1, 3)
b1 = uT - uX
b2 = uZ - uY
num = torch.bmm(b1.view(-1, 1, 3), b2.view(-1, 3, 1)).squeeze(-1).squeeze(-1)
den = torch.linalg.norm(b1, dim=-1) * torch.linalg.norm(b2, dim=-1) + 1e-10
return (num / den).view(-1, 9)
def von_Mises_loss(a, b, a_sin=None, b_sin=None):
"""
:param a: cos of first angle
:param b: cos of second angle
:return: difference of cosines
"""
if torch.is_tensor(a_sin):
out = a * b + a_sin * b_sin
else:
out = a * b + torch.sqrt(1-a**2 + 1e-5) * torch.sqrt(1-b**2 + 1e-5)
return out
def rotation_matrix(neighbor_coords, neighbor_mask, neighbor_map, mu=None):
"""
Given predicted neighbor coordinates from model, return rotation matrix
:param neighbor_coords: neighbor coordinates for each edge as defined by dihedral_pairs
(n_dihedral_pairs, 4, n_generated_confs, 3)
:param neighbor_mask: mask describing which atoms are present (n_dihedral_pairs, 4)
:param neighbor_map: mask describing which neighbor corresponds to the other central dihedral atom
(n_dihedral_pairs, 4) each entry in neighbor_map should have one TRUE entry with the rest as FALSE
:return: rotation matrix (n_dihedral_pairs, n_model_confs, 3, 3)
"""
if not torch.is_tensor(mu):
# mu = neighbor_coords.sum(dim=1, keepdim=True) / (neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + 1e-10)
mu_num = neighbor_coords[~neighbor_map.bool()].view(neighbor_coords.size(0), 3, neighbor_coords.size(2), -1).sum(dim=1)
mu_den = (neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1) - 1 + 1e-10)
mu = mu_num / mu_den # (n_dihedral_pairs, n_model_confs, 10)
mu = mu.squeeze(1) # (n_dihedral_pairs, n_model_confs, 10)
p_Y = neighbor_coords[neighbor_map.bool(), :]
h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10) # (n_dihedral_pairs, n_model_confs, 10)
h3_1 = torch.cross(p_Y, mu, dim=-1)
h3 = h3_1 / (torch.linalg.norm(h3_1, dim=-1, keepdim=True) + 1e-10) # (n_dihedral_pairs, n_model_confs, 10)
h2 = -torch.cross(h1, h3, dim=-1) # (n_dihedral_pairs, n_model_confs, 10)
H = torch.cat([h1.unsqueeze(-2),
h2.unsqueeze(-2),
h3.unsqueeze(-2)], dim=-2)
return H
def rotation_matrix_v2(neighbor_coords):
"""
Given predicted neighbor coordinates from model, return rotation matrix
:param neighbor_coords: y or x coordinates for the x or y center node
(n_dihedral_pairs, 3)
:return: rotation matrix (n_dihedral_pairs, 3, 3)
"""
p_Y = neighbor_coords
eta_1 = torch.rand_like(p_Y)
eta_2 = eta_1 - torch.sum(eta_1 * p_Y, dim=-1, keepdim=True) / (torch.linalg.norm(p_Y, dim=-1, keepdim=True)**2 + 1e-10) * p_Y
eta = eta_2 / torch.linalg.norm(eta_2, dim=-1, keepdim=True)
h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10) # (n_dihedral_pairs, n_model_confs, 10)
h3_1 = torch.cross(p_Y, eta, dim=-1)
h3 = h3_1 / (torch.linalg.norm(h3_1, dim=-1, keepdim=True) + 1e-10) # (n_dihedral_pairs, n_model_confs, 10)
h2 = -torch.cross(h1, h3, dim=-1) # (n_dihedral_pairs, n_model_confs, 10)
H = torch.cat([h1.unsqueeze(-2),
h2.unsqueeze(-2),
h3.unsqueeze(-2)], dim=-2)
return H
def signed_volume(local_coords):
"""
Compute signed volume given ordered neighbor local coordinates
:param local_coords: (n_tetrahedral_chiral_centers, 4, n_generated_confs, 3)
:return: signed volume of each tetrahedral center (n_tetrahedral_chiral_centers, n_generated_confs)
"""
v1 = local_coords[:, 0] - local_coords[:, 3]
v2 = local_coords[:, 1] - local_coords[:, 3]
v3 = local_coords[:, 2] - local_coords[:, 3]
cp = v2.cross(v3, dim=-1)
vol = torch.sum(v1 * cp, dim=-1)
return torch.sign(vol)
def rotation_matrix_inf(neighbor_coords, neighbor_mask, neighbor_map):
"""
Given predicted neighbor coordinates from model, return rotation matrix
:param neighbor_coords: neighbor coordinates for each edge as defined by dihedral_pairs (4, n_model_confs, 3)
:param neighbor_mask: mask describing which atoms are present (4)
:param neighbor_map: mask describing which neighbor corresponds to the other central dihedral atom (4)
each entry in neighbor_map should have one TRUE entry with the rest as FALSE
:return: rotation matrix (3, 3)
"""
mu = neighbor_coords.sum(dim=0, keepdim=True) / (neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1) + 1e-10)
mu = mu.squeeze(0)
p_Y = neighbor_coords[neighbor_map.bool(), :].squeeze(0)
h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)
h3_1 = torch.cross(p_Y, mu, dim=-1)
h3 = h3_1 / (torch.linalg.norm(h3_1, dim=-1, keepdim=True) + 1e-10)
h2 = -torch.cross(h1, h3, dim=-1)
H = torch.cat([h1.unsqueeze(-2),
h2.unsqueeze(-2),
h3.unsqueeze(-2)], dim=-2)
return H
def build_alpha_rotation_inf(alpha, n_model_confs):
H_alpha = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]).repeat(n_model_confs, 1, 1)
H_alpha[:, 1, 1] = torch.cos(alpha)
H_alpha[:, 1, 2] = -torch.sin(alpha)
H_alpha[:, 2, 1] = torch.sin(alpha)
H_alpha[:, 2, 2] = torch.cos(alpha)
return H_alpha
def random_rotation_matrix(dim):
yaw = torch.rand(dim)
pitch = torch.rand(dim)
roll = torch.rand(dim)
R = torch.stack([torch.stack([torch.cos(yaw) * torch.cos(pitch),
torch.cos(yaw) * torch.sin(pitch) * torch.sin(roll) - torch.sin(yaw) * torch.cos(
roll),
torch.cos(yaw) * torch.sin(pitch) * torch.cos(roll) + torch.sin(yaw) * torch.sin(
roll)], dim=-1),
torch.stack([torch.sin(yaw) * torch.cos(pitch),
torch.sin(yaw) * torch.sin(pitch) * torch.sin(roll) + torch.cos(yaw) * torch.cos(
roll),
torch.sin(yaw) * torch.sin(pitch) * torch.cos(roll) - torch.cos(yaw) * torch.sin(
roll)], dim=-1),
torch.stack([-torch.sin(pitch),
torch.cos(pitch) * torch.sin(roll),
torch.cos(pitch) * torch.cos(roll)], dim=-1)], dim=-2)
return R
def length_to_mask(length, max_len=None, dtype=None):
"""length: B.
return B x max_len.
If max_len is None, then max of length will be used.
"""
assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
max_len = max_len or length.max().item()
mask = torch.arange(max_len, device=length.device,
dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
if dtype is not None:
mask = torch.as_tensor(mask, dtype=dtype, device=length.device)
return mask