| | from abc import ABC, abstractmethod |
| | from typing import Optional, Dict, Any, Set, List, Union |
| |
|
| | import torch |
| | import numpy as np |
| | from . import vb_const as const |
| | from .vb_potentials_schedules import ( |
| | ParameterSchedule, |
| | ExponentialInterpolation, |
| | PiecewiseStepFunction, |
| | ) |
| | from .vb_loss_diffusionv2 import weighted_rigid_align |
| |
|
| |
|
class Potential(ABC):
    """Base class for the steering potentials in this module.

    A concrete potential combines three pieces, usually via mixins:

    * ``compute_args`` extracts the term index tensor and the function
      arguments from ``feats``. It may also request centre-of-mass pooling,
      reference-coordinate gathering, or negation/union operators.
    * ``compute_variable`` maps coordinates to one scalar per term (a distance,
      a dihedral, ...) and optionally its gradient.
    * ``compute_function`` maps that scalar to an energy and optionally its
      derivative.

    ``compute`` and ``compute_gradient`` chain these together.
    """

    def __init__(
        self,
        parameters: Optional[
            Dict[str, Union[ParameterSchedule, float, int, bool]]
        ] = None,
    ):
        """Store the potential's parameters.

        Args:
            parameters: Mapping from parameter name to a constant or to a
                ``ParameterSchedule``. Schedules are evaluated at a given ``t``
                by ``compute_parameters``.
        """
        self.parameters = parameters

    @staticmethod
    def _center_of_mass(coords, com_index, atom_pad_mask):
        """Average the non-padded coordinates within each ``com_index`` group.

        Returns a tensor of shape ``(*coords.shape[:-2], num_groups, 3)``, where
        ``num_groups = max(com_index[atom_pad_mask]) + 1``.
        """
        unpad_coords = coords[..., atom_pad_mask, :]
        unpad_com_index = com_index[atom_pad_mask]
        num_groups = int(unpad_com_index.max()) + 1
        # include_self=False: with the default (True), the zero initial value
        # is counted as an extra member of every group. The result would then
        # be sum / (n + 1) instead of the mean.
        return torch.zeros(
            (*unpad_coords.shape[:-2], num_groups, 3),
            dtype=unpad_coords.dtype,
            device=coords.device,
        ).scatter_reduce(
            -2,
            unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
            unpad_coords,
            "mean",
            include_self=False,
        )

    @staticmethod
    def _union_sum(values, union_index):
        """Sum ``values`` over its last dim within each ``union_index`` group.

        Returns a tensor of shape ``(*values.shape[:-1], max(union_index) + 1)``.
        """
        return torch.zeros(
            (*values.shape[:-1], int(union_index.max()) + 1),
            dtype=values.dtype,
            device=values.device,
        ).scatter_reduce(-1, union_index.expand_as(values), values, "sum")

    @classmethod
    def _union_softmax(cls, energy, union_index, union_lambda):
        """Return ``exp(-lambda * E) / Z`` weights, normalised within each group."""
        group_min = torch.full(
            (*energy.shape[:-1], int(union_index.max()) + 1),
            float("inf"),
            dtype=energy.dtype,
            device=energy.device,
        ).scatter_reduce(-1, union_index.expand_as(energy), energy, "amin")
        # Shifting by the per-group minimum leaves the weights mathematically
        # unchanged. Without the shift, exp() underflows to 0 for large
        # energies and every weight in the group collapses to zero.
        neg_exp_energy = torch.exp(
            -1 * union_lambda * (energy - group_min[..., union_index])
        )
        Z = cls._union_sum(neg_exp_energy, union_index)[..., union_index]
        softmax_energy = neg_exp_energy / Z
        # Kept for degenerate (e.g. non-finite) energies.
        softmax_energy[Z == 0] = 0
        return softmax_energy

    def _prepare_inputs(self, coords, feats, parameters):
        """Run ``compute_args`` and apply the COM / reference preprocessing.

        Returns ``None`` when the potential has no terms (``index.shape[1] ==
        0``). Otherwise returns the tuple ``(coords, index, args, ref_coords,
        ref_mask, negation_mask, union_index, com_index, ref_token_index)``.
        In that tuple, ``coords`` has been pooled to group centres of mass
        and/or gathered at ``ref_atom_index``, as requested by
        ``compute_args``.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return None

        com_index = None
        if com_args is not None:
            com_index, atom_pad_mask = com_args
            coords = self._center_of_mass(coords, com_index, atom_pad_mask)

        ref_coords = ref_mask = ref_token_index = None
        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
            coords = coords[..., ref_atom_index, :]

        negation_mask = union_index = None
        if operator_args is not None:
            negation_mask, union_index = operator_args

        return (
            coords,
            index,
            args,
            ref_coords,
            ref_mask,
            negation_mask,
            union_index,
            com_index,
            ref_token_index,
        )

    def compute(self, coords, feats, parameters):
        """Evaluate the potential's energy.

        Args:
            coords: Coordinates with atoms on dim -2 and xyz on dim -1.
            feats: Feature dict forwarded to ``compute_args``.
            parameters: Resolved parameters (see ``compute_parameters``).
                ``union_lambda`` is read when ``compute_args`` returns a union
                index.

        Returns:
            For union potentials, ``(energy * weights).sum(-1)``. Otherwise,
            ``energy`` summed over every dim except the first. When there are
            no terms, zeros of shape ``coords.shape[:-2]``.
        """
        prepared = self._prepare_inputs(coords, feats, parameters)
        if prepared is None:
            return torch.zeros(coords.shape[:-2], device=coords.device)
        (
            work_coords,
            index,
            args,
            ref_coords,
            ref_mask,
            negation_mask,
            union_index,
            _,
            _,
        ) = prepared

        value = self.compute_variable(
            work_coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=False,
        )
        energy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=False
        )

        if union_index is not None:
            weights = self._union_softmax(
                energy, union_index, parameters["union_lambda"]
            )
            return (energy * weights).sum(dim=-1)
        return energy.sum(dim=tuple(range(1, energy.dim())))

    def compute_gradient(self, coords, feats, parameters):
        """Return the gradient of ``compute``'s energy with respect to ``coords``.

        Returns ``torch.zeros_like(coords)`` when there are no terms.

        When ``compute_args`` requested centre-of-mass pooling, each atom
        receives its group's full COM gradient; it is not divided by the
        group size. When reference gathering was requested, each atom receives
        the gradient of its token's entry.
        """
        prepared = self._prepare_inputs(coords, feats, parameters)
        if prepared is None:
            return torch.zeros_like(coords)
        (
            work_coords,
            index,
            args,
            ref_coords,
            ref_mask,
            negation_mask,
            union_index,
            com_index,
            ref_token_index,
        ) = prepared

        value, grad_value = self.compute_variable(
            work_coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=True,
        )
        energy, dEnergy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=True
        )

        if union_index is not None:
            union_lambda = parameters["union_lambda"]
            weights = self._union_softmax(energy, union_index, union_lambda)
            soft_min = self._union_sum(energy * weights, union_index)
            # d/dE_j sum_i E_i w_i = w_j * (1 - lambda * (E_j - F)),
            # where w_i = exp(-lambda E_i) / Z and F = sum_i E_i w_i.
            dEnergy = (
                dEnergy
                * weights
                * (1 - union_lambda * (energy - soft_min[..., union_index]))
            )

        # Chain rule. grad_value stacks one d(value)/d(x) slice per atom slot
        # of a term on dim -3. dEnergy is tiled to line up with those flattened
        # slots, and the products are scattered onto the atoms named by
        # ``index``.
        prod = dEnergy.tile(grad_value.shape[-3]).unsqueeze(
            -1
        ) * grad_value.flatten(start_dim=-3, end_dim=-2)
        if prod.dim() > 3:
            prod = prod.sum(dim=list(range(1, prod.dim() - 2)))
        grad_atom = torch.zeros_like(work_coords).scatter_reduce(
            -2,
            index.flatten(start_dim=0, end_dim=1)
            .unsqueeze(-1)
            .expand((*work_coords.shape[:-2], -1, 3)),
            prod,
            "sum",
        )

        if com_index is not None:
            grad_atom = grad_atom[..., com_index, :]
        elif ref_token_index is not None:
            grad_atom = grad_atom[..., ref_token_index, :]
        return grad_atom

    def compute_parameters(self, t):
        """Resolve every ``ParameterSchedule`` at ``t``; constants pass through.

        Returns ``None`` when the potential has no parameters.
        """
        if self.parameters is None:
            return None
        return {
            name: (
                parameter.compute(t)
                if isinstance(parameter, ParameterSchedule)
                else parameter
            )
            for name, parameter in self.parameters.items()
        }

    @abstractmethod
    def compute_function(
        self, value, *args, negation_mask=None, compute_derivative=False
    ):
        """Map per-term variables to energies (and derivatives if requested)."""
        raise NotImplementedError

    @abstractmethod
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Map coordinates to per-term variables (and gradients if requested)."""
        raise NotImplementedError

    @abstractmethod
    def compute_args(self, feats, parameters):
        """Return ``(index, args, com_args, ref_args, operator_args)``."""
        raise NotImplementedError

    def get_reference_coords(self, feats, parameters):
        """Default hook: no reference coordinates."""
        return None, None
| |
|
| |
|
class FlatBottomPotential(Potential):
    """Function mixin: zero energy inside ``[lower, upper]``, linear outside."""

    def compute_function(
        self,
        value,
        k,
        lower_bounds,
        upper_bounds,
        negation_mask=None,
        compute_derivative=False,
    ):
        """Flat-bottom penalty with slope ``k`` outside the bounds.

        Args:
            value: Per-term variable.
            k: Slope, broadcastable to ``value``.
            lower_bounds: Lower bounds, or ``None`` for ``-inf``.
            upper_bounds: Upper bounds, or ``None`` for ``+inf``.
            negation_mask: Optional mask. Where it is False, a one-sided bound
                is moved to the other side (e.g. ``x <= u`` becomes
                ``x >= u``). Such terms must have at least one infinite bound.
            compute_derivative: Also return ``dEnergy/dvalue``.

        Returns:
            ``energy``, or ``(energy, dEnergy)`` when ``compute_derivative``.
        """
        if lower_bounds is None:
            lower_bounds = torch.full_like(value, float("-inf"))
        if upper_bounds is None:
            upper_bounds = torch.full_like(value, float("inf"))
        lower = lower_bounds.expand_as(value).clone()
        upper = upper_bounds.expand_as(value).clone()

        if negation_mask is not None:
            no_lower = torch.isneginf(lower)
            no_upper = torch.isposinf(upper)
            assert torch.all((no_lower + no_upper) + negation_mask)
            # Finite upper bound on a negated term -> it becomes a lower bound.
            move_up = ~no_upper * ~negation_mask
            lower[move_up] = upper[move_up]
            upper[move_up] = float("inf")
            # Finite lower bound on a negated term -> it becomes an upper bound.
            move_down = ~no_lower * ~negation_mask
            upper[move_down] = lower[move_down]
            lower[move_down] = float("-inf")

        below = value < lower
        above = value > upper

        energy = torch.zeros_like(value)
        energy[below] = (k * (lower - value))[below]
        energy[above] = (k * (value - upper))[above]
        if not compute_derivative:
            return energy

        slope = torch.zeros_like(value)
        slope[below] = -1 * k.expand_as(below)[below]
        slope[above] = 1 * k.expand_as(above)[above]
        return energy, slope
| |
|
| |
|
class ReferencePotential(Potential):
    """Variable mixin: distance of each atom from a rigidly aligned reference."""

    def compute_variable(
        self, coords, index, ref_coords, ref_mask, compute_gradient=False
    ):
        """Align ``ref_coords`` onto ``coords[:, index]`` and return residual norms.

        Args:
            coords: Coordinates, indexed as ``coords[:, index]``.
            index: Term index tensor.
            ref_coords: Reference coordinates passed to ``weighted_rigid_align``.
            ref_mask: Weights/mask for alignment. Also multiplies the gradient.
            compute_gradient: Also return the gradient.

        Returns:
            ``r_norm``, or ``(r_norm, grad)``. ``grad`` is the unit residual
            times ``ref_mask`` with a singleton dim inserted at dim 1. The
            gradient treats the aligned reference as constant; it does not
            differentiate through the alignment.
        """
        selected = coords[:, index]
        aligned_ref_coords = weighted_rigid_align(
            ref_coords.float(),
            selected.float(),
            ref_mask,
            ref_mask,
        )

        r = selected - aligned_ref_coords
        r_norm = torch.linalg.norm(r, dim=-1)
        if not compute_gradient:
            return r_norm

        # A zero residual would give 0/0 = NaN, and NaN * ref_mask stays NaN.
        # Clamping the denominator makes the direction 0 there instead.
        safe_norm = r_norm.clamp_min(torch.finfo(r_norm.dtype).tiny)
        r_hat = r / safe_norm.unsqueeze(-1)
        grad = (r_hat * ref_mask.unsqueeze(-1)).unsqueeze(1)
        return r_norm, grad
| |
|
| |
|
class DistancePotential(Potential):
    """Variable mixin: Euclidean distance between atom pairs."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return ``|x_i - x_j|`` for pairs ``(index[0], index[1])``.

        Returns:
            ``r_ij_norm``, or ``(r_ij_norm, grad)``. ``grad`` stacks
            ``(r_hat_ij, -r_hat_ij)`` on dim 1: the derivative with respect to
            atom i and atom j.
        """
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_ij_norm = torch.linalg.norm(r_ij, dim=-1)
        if not compute_gradient:
            return r_ij_norm

        # Coincident atoms would give 0/0 = NaN. Use a zero direction instead.
        safe_norm = r_ij_norm.clamp_min(torch.finfo(r_ij_norm.dtype).tiny)
        r_hat_ij = r_ij / safe_norm.unsqueeze(-1)
        grad = torch.stack((r_hat_ij, -1 * r_hat_ij), dim=1)
        return r_ij_norm, grad
| |
|
| |
|
class DihedralPotential(Potential):
    """Variable mixin: signed dihedral angle of atom quadruplets (i, j, k, l)."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return the dihedral angle ``phi`` in ``[-pi, pi]``.

        Returns:
            ``phi``, or ``(phi, grad)``. ``grad`` stacks the analytic
            derivatives with respect to atoms i, j, k, l on dim 1.
        """
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_kj = coords.index_select(-2, index[2]) - coords.index_select(-2, index[1])
        r_kl = coords.index_select(-2, index[2]) - coords.index_select(-2, index[3])

        # Normals of the (i, j, k) and (j, k, l) planes.
        n_ijk = torch.cross(r_ij, r_kj, dim=-1)
        n_jkl = torch.cross(r_kj, r_kl, dim=-1)

        r_kj_norm = torch.linalg.norm(r_kj, dim=-1)
        n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1)
        n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1)

        sign_phi = torch.sign(
            r_kj.unsqueeze(-2) @ torch.cross(n_ijk, n_jkl, dim=-1).unsqueeze(-1)
        ).squeeze(-1, -2)
        # torch.sign returns 0 when the normals are exactly (anti)parallel.
        # At phi = pi that would collapse the angle to 0. A zero sign is
        # therefore treated as positive; at phi = 0 the magnitude is ~0 anyway.
        sign_phi = torch.where(sign_phi == 0, torch.ones_like(sign_phi), sign_phi)
        # NOTE(review): the 1e-8 margin rounds away in float32, so this clamp
        # effectively only guards against |cos| > 1.
        phi = sign_phi * torch.arccos(
            torch.clamp(
                (n_ijk.unsqueeze(-2) @ n_jkl.unsqueeze(-1)).squeeze(-1, -2)
                / (n_ijk_norm * n_jkl_norm),
                -1 + 1e-8,
                1 - 1e-8,
            )
        )

        if not compute_gradient:
            return phi

        a = (
            (r_ij.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
        ).unsqueeze(-1)
        b = (
            (r_kl.unsqueeze(-2) @ r_kj.unsqueeze(-1)).squeeze(-1, -2) / (r_kj_norm**2)
        ).unsqueeze(-1)

        grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1)
        grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1)
        grad_j = (a - 1) * grad_i - b * grad_l
        grad_k = (b - 1) * grad_l - a * grad_i
        grad = torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
        return phi, grad
| |
|
| |
|
class AbsDihedralPotential(DihedralPotential):
    """Variable mixin: absolute dihedral angle ``|phi|``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return ``|phi|``, or ``(|phi|, grad)`` when ``compute_gradient``.

        Wherever the signed angle is negative, the gradient from the signed
        dihedral is negated for every atom slot of that term.
        """
        if compute_gradient:
            signed_phi, grad = super().compute_variable(
                coords, index, compute_gradient=compute_gradient
            )
            flip = (signed_phi < 0)[..., None, :, None].expand_as(grad)
            grad[flip] *= -1
            return torch.abs(signed_phi), grad

        signed_phi = super().compute_variable(
            coords, index, compute_gradient=compute_gradient
        )
        return torch.abs(signed_phi)
| |
|
| |
|
class PoseBustersPotential(FlatBottomPotential, DistancePotential):
    """Flat-bottom distance restraints derived from RDKit bounds."""

    def compute_args(self, feats, parameters):
        """Build buffered pairwise distance bounds.

        Args:
            feats: Feature dict. Reads ``rdkit_bounds_index``,
                ``rdkit_lower_bounds``, ``rdkit_upper_bounds``,
                ``rdkit_bounds_bond_mask``, ``rdkit_bounds_angle_mask`` and
                ``ref_element`` (first batch entry). Assumes the two masks are
                boolean tensors (TODO confirm).
            parameters: Resolved parameters. ``bond_buffer``, ``angle_buffer``
                and ``clash_buffer`` scale the bounds by ``1 -/+ buffer``.

        Returns:
            ``(pair_index, (k, lower_bounds, upper_bounds), None, None, None)``.
        """
        pair_index = feats["rdkit_bounds_index"][0]
        lower_bounds = feats["rdkit_lower_bounds"][0].clone()
        upper_bounds = feats["rdkit_upper_bounds"][0].clone()
        bond_mask = feats["rdkit_bounds_bond_mask"][0]
        angle_mask = feats["rdkit_bounds_angle_mask"][0]

        # Relative widening. Pairs flagged as both bond and angle use the
        # smaller of the two buffers.
        lower_bounds[bond_mask * ~angle_mask] *= 1.0 - parameters["bond_buffer"]
        upper_bounds[bond_mask * ~angle_mask] *= 1.0 + parameters["bond_buffer"]
        lower_bounds[~bond_mask * angle_mask] *= 1.0 - parameters["angle_buffer"]
        upper_bounds[~bond_mask * angle_mask] *= 1.0 + parameters["angle_buffer"]
        lower_bounds[bond_mask * angle_mask] *= 1.0 - min(
            parameters["bond_buffer"], parameters["angle_buffer"]
        )
        upper_bounds[bond_mask * angle_mask] *= 1.0 + min(
            parameters["bond_buffer"], parameters["angle_buffer"]
        )
        # Pairs that are neither bond nor angle keep only a relaxed lower bound.
        lower_bounds[~bond_mask * ~angle_mask] *= 1.0 - parameters["clash_buffer"]
        upper_bounds[~bond_mask * ~angle_mask] = float("inf")

        # Per-element vdW radius table. Entry 0 stays 0; const.vdw_radii fills
        # entries 1.. (sized from the list rather than a hard-coded 119).
        vdw_values = const.vdw_radii
        vdw_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=pair_index.device
        )
        vdw_radii[1 : 1 + len(vdw_values)] = torch.tensor(
            vdw_values, dtype=torch.float32, device=pair_index.device
        )
        # Presumably ref_element is one-hot over elements, making this a radius
        # lookup per atom — TODO confirm.
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]
        bond_cutoffs = 0.35 + atom_vdw_radii[pair_index].mean(dim=0)
        # Non-bonded pairs: at least the cutoff apart. Bonded pairs: at most it.
        lower_bounds[~bond_mask] = torch.max(
            lower_bounds[~bond_mask], bond_cutoffs[~bond_mask]
        )
        upper_bounds[bond_mask] = torch.min(
            upper_bounds[bond_mask], bond_cutoffs[bond_mask]
        )

        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, upper_bounds), None, None, None
| |
|
| |
|
class ConnectionsPotential(FlatBottomPotential, DistancePotential):
    """Keep connected atom pairs within ``parameters["buffer"]`` of each other."""

    def compute_args(self, feats, parameters):
        """Return an upper-bound-only distance restraint per connected pair."""
        connected_pairs = feats["connected_atom_index"][0]
        num_pairs = connected_pairs.shape[1]
        max_distance = torch.full(
            (num_pairs,), parameters["buffer"], device=connected_pairs.device
        )
        stiffness = torch.ones_like(max_distance)
        return connected_pairs, (stiffness, None, max_distance), None, None, None
| |
|
| |
|
class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
    """Penalise atom pairs from unconnected chains closer than their vdW sum."""

    def compute_args(self, feats, parameters):
        """Build a lower-bound distance restraint for inter-chain atom pairs.

        Pairs are kept when both atoms are non-padded, both belong to chains
        with more than one non-padded atom, and the two chains are neither the
        same chain nor listed in ``connected_chain_index``.

        Returns:
            ``(pair_index, (k, lower_bounds, None), None, None, None)``, where
            ``lower_bounds = (r_i + r_j) * (1 - buffer)``.
        """
        atom_chain_id = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        # True for atoms whose chain has more than one non-padded atom.
        multi_atom_chain_mask = (chain_sizes > 1)[atom_chain_id]

        vdw_values = const.vdw_radii
        vdw_radii = torch.zeros(
            const.num_elements, dtype=torch.float32, device=atom_chain_id.device
        )
        # Sized from the list rather than a hard-coded 119; entry 0 stays 0.
        vdw_radii[1 : 1 + len(vdw_values)] = torch.tensor(
            vdw_values, dtype=torch.float32, device=atom_chain_id.device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]

        # Every unordered pair i < j: N * (N - 1) / 2 pairs for N atoms.
        pair_index = torch.triu_indices(
            atom_chain_id.shape[0],
            atom_chain_id.shape[0],
            1,
            device=atom_chain_id.device,
        )

        pair_pad_mask = atom_pad_mask[pair_index].all(dim=0)
        pair_ion_mask = (
            multi_atom_chain_mask[pair_index[0]] * multi_atom_chain_mask[pair_index[1]]
        )

        num_chains = int(atom_chain_id.max()) + 1
        connected_chain_index = feats["connected_chain_index"][0]
        # Identity diagonal: a chain counts as connected to itself.
        connected_chain_matrix = torch.eye(
            num_chains, device=atom_chain_id.device, dtype=torch.bool
        )
        connected_chain_matrix[connected_chain_index[0], connected_chain_index[1]] = (
            True
        )
        connected_chain_matrix[connected_chain_index[1], connected_chain_index[0]] = (
            True
        )
        connected_chain_mask = connected_chain_matrix[
            atom_chain_id[pair_index[0]], atom_chain_id[pair_index[1]]
        ]

        pair_index = pair_index[
            :, pair_pad_mask * pair_ion_mask * ~connected_chain_mask
        ]

        lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
            1.0 - parameters["buffer"]
        )
        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, None), None, None, None
| |
|
| |
|
class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
    """Keep centres of mass of symmetric chains at least ``buffer`` apart."""

    def compute_args(self, feats, parameters):
        """Return chain-pair indices, bounds, and the COM pooling request.

        Pairs involving a chain with at most one non-padded atom are dropped.
        The third element ``(atom_chain_id, atom_pad_mask)`` asks the base
        class to pool atoms into per-chain centres of mass.
        """
        chain_projection = torch.bmm(
            feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
        )
        atom_chain_id = chain_projection.squeeze(-1).long()[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        is_multi_atom_chain = torch.bincount(atom_chain_id[atom_pad_mask]) > 1

        chain_pairs = feats["symmetric_chain_index"][0]
        keep = is_multi_atom_chain[chain_pairs[0]] * is_multi_atom_chain[chain_pairs[1]]
        chain_pairs = chain_pairs[:, keep]

        min_distance = torch.full(
            (chain_pairs.shape[1],),
            parameters["buffer"],
            dtype=torch.float32,
            device=chain_pairs.device,
        )
        stiffness = torch.ones_like(min_distance)

        return (
            chain_pairs,
            (stiffness, min_distance, None),
            (atom_chain_id, atom_pad_mask),
            None,
            None,
        )
| |
|
| |
|
class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Restrain ``|phi|`` of stereo bonds according to their orientation flag."""

    def compute_args(self, feats, parameters):
        """Build bounds on ``|phi|`` for each stereo bond.

        Orientation True requires ``|phi| >= pi - buffer``. Orientation False
        requires ``|phi| <= buffer``.
        """
        stereo_bond_index = feats["stereo_bond_index"][0]
        orientation = feats["stereo_bond_orientations"][0].bool()
        buffer = parameters["buffer"]
        shape, device = orientation.shape, orientation.device

        lower_bounds = torch.zeros(shape, device=device)
        upper_bounds = torch.zeros(shape, device=device)

        flagged, unflagged = orientation, ~orientation
        lower_bounds[flagged] = torch.pi - buffer
        upper_bounds[flagged] = float("inf")
        lower_bounds[unflagged] = float("-inf")
        upper_bounds[unflagged] = buffer

        stiffness = torch.ones_like(lower_bounds)
        return stereo_bond_index, (stiffness, lower_bounds, upper_bounds), None, None, None
| |
|
| |
|
class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
    """Restrain the signed dihedral of chiral-atom quadruplets by orientation."""

    def compute_args(self, feats, parameters):
        """Build bounds on the signed ``phi`` of each chiral quadruplet.

        Orientation True requires ``phi >= buffer``. Orientation False requires
        ``phi <= -buffer``.
        """
        chiral_atom_index = feats["chiral_atom_index"][0]
        orientation = feats["chiral_atom_orientations"][0].bool()
        buffer = parameters["buffer"]
        shape, device = orientation.shape, orientation.device

        lower_bounds = torch.zeros(shape, device=device)
        upper_bounds = torch.zeros(shape, device=device)

        positive, negative = orientation, ~orientation
        lower_bounds[positive] = buffer
        upper_bounds[positive] = float("inf")
        lower_bounds[negative] = float("-inf")
        upper_bounds[negative] = -1 * buffer

        stiffness = torch.ones_like(lower_bounds)
        return chiral_atom_index, (stiffness, lower_bounds, upper_bounds), None, None, None
| |
|
| |
|
class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Keep improper dihedrals around planar bonds within ``buffer`` of 0."""

    def compute_args(self, feats, parameters):
        """Expand each 6-atom planar-bond record into two improper quadruplets.

        Quadruplet 0 uses record positions (1, 2, 3, 0); quadruplet 1 uses
        (4, 5, 0, 3). The two quadruplets of each bond are interleaved along
        the last dim of the returned index.
        """
        bond_atoms = feats["planar_bond_index"][0].T
        # Row q lists, for every quadruplet slot, the record position to use.
        slot_positions = torch.tensor(
            [[1, 4], [2, 5], [3, 0], [0, 3]],
            device=bond_atoms.device,
        )
        gathered = bond_atoms[:, slot_positions]
        improper_index = gathered.transpose(0, 1).flatten(start_dim=1)

        max_abs_angle = torch.full(
            (improper_index.shape[1],),
            parameters["buffer"],
            device=improper_index.device,
        )
        stiffness = torch.ones_like(max_abs_angle)
        return improper_index, (stiffness, None, max_abs_angle), None, None, None
| |
|
| |
|
class TemplateReferencePotential(FlatBottomPotential, ReferencePotential):
    """Keep representative atoms near forced template CB coordinates."""

    def compute_args(self, feats, parameters):
        """Build per-token reference restraints from the forced templates.

        Returns an empty ``(1, 0)`` index when template features are missing
        or no template is forced. Otherwise the upper bound of each term is
        the forced template's ``template_force_threshold`` where
        ``template_mask_cb`` is set, and ``inf`` elsewhere. The fourth tuple
        element carries ``(ref_coords, ref_mask, ref_atom_index,
        ref_token_index)`` for the base class.
        """
        required_keys = ("template_mask_cb", "template_force")
        if not all(key in feats for key in required_keys):
            return torch.empty([1, 0]), None, None, None, None

        forced = feats["template_force"]
        template_mask = feats["template_mask_cb"][forced]
        if template_mask.shape[0] == 0:
            return torch.empty([1, 0]), None, None, None, None

        ref_coords = feats["template_cb"][forced].clone()
        ref_mask = template_mask.clone()

        pad_mask = feats["atom_pad_mask"]
        atom_positions = torch.arange(
            pad_mask.shape[1], device=pad_mask.device, dtype=torch.float32
        )[None, :, None]
        ref_atom_index = (
            torch.bmm(feats["token_to_rep_atom"].float(), atom_positions)
            .squeeze(-1)
            .long()[0]
        )
        token_positions = feats["token_index"].unsqueeze(-1).float()
        ref_token_index = (
            torch.bmm(feats["atom_to_token"].float(), token_positions)
            .squeeze(-1)
            .long()[0]
        )

        index = torch.arange(
            template_mask.shape[-1], dtype=torch.long, device=template_mask.device
        )[None]
        upper_bounds = torch.full(
            template_mask.shape, float("inf"), device=index.device, dtype=torch.float32
        )
        ref_idxs = torch.argwhere(template_mask).T
        forced_thresholds = feats["template_force_threshold"][forced]
        upper_bounds[ref_idxs.unbind()] = forced_thresholds[ref_idxs[0]]

        stiffness = torch.ones_like(upper_bounds)
        return (
            index,
            (stiffness, None, upper_bounds),
            None,
            (ref_coords, ref_mask, ref_atom_index, ref_token_index),
            None,
        )
| |
|
| |
|
class ContactPotentital(FlatBottomPotential, DistancePotential):
    """Contact distance restraints combined per union group.

    NOTE(review): the class name is misspelled ("Potentital"). It is kept
    because ``get_potentials`` refers to it by this name.
    """

    def compute_args(self, feats, parameters):
        """Return contact pairs, upper-bound thresholds and operator args.

        The fifth element ``(negation_mask, union_index)`` enables the base
        class's negation and union-softmin handling.
        """
        pair_index = feats["contact_pair_index"][0]
        union_index = feats["contact_union_index"][0]
        negation_mask = feats["contact_negation_mask"][0]
        thresholds = feats["contact_thresholds"][0].clone()
        stiffness = torch.ones_like(thresholds)
        return (
            pair_index,
            (stiffness, None, thresholds),
            None,
            None,
            (negation_mask, union_index),
        )
| |
|
| |
|
def get_potentials(steering_args, boltz2=False):
    """Build the steering potentials enabled by ``steering_args``.

    Args:
        steering_args: Mapping read for ``fk_steering`` and
            ``physical_guidance_update``. It is also read for
            ``contact_guidance_update``, but only when ``boltz2`` is true.
        boltz2: Whether to also consider the contact and template potentials.

    Returns:
        A list of potentials. The physical potentials are included when FK
        steering or physical guidance is on. The contact/template potentials
        are included when ``boltz2`` and (FK steering or contact guidance).
        Guidance weights are 0.0 when the matching guidance flag is off.
    """
    fk_steering = steering_args["fk_steering"]
    physical = steering_args["physical_guidance_update"]

    selected = []
    if fk_steering or physical:
        selected += [
            SymmetricChainCOMPotential(
                parameters={
                    "guidance_interval": 4,
                    "guidance_weight": 0.5 if physical else 0.0,
                    "resampling_weight": 0.5,
                    "buffer": ExponentialInterpolation(
                        start=1.0, end=5.0, alpha=-2.0
                    ),
                }
            ),
            VDWOverlapPotential(
                parameters={
                    "guidance_interval": 5,
                    "guidance_weight": (
                        PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0])
                        if physical
                        else 0.0
                    ),
                    "resampling_weight": PiecewiseStepFunction(
                        thresholds=[0.6], values=[0.01, 0.0]
                    ),
                    "buffer": 0.225,
                }
            ),
            ConnectionsPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.15 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 2.0,
                }
            ),
            PoseBustersPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.01 if physical else 0.0,
                    "resampling_weight": 0.1,
                    "bond_buffer": 0.125,
                    "angle_buffer": 0.125,
                    "clash_buffer": 0.10,
                }
            ),
            ChiralAtomPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.1 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            StereoBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.05 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            PlanarBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.05 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.26180,
                }
            ),
        ]

    if boltz2:
        contact = steering_args["contact_guidance_update"]
        if fk_steering or contact:
            selected += [
                ContactPotentital(
                    parameters={
                        "guidance_interval": 4,
                        "guidance_weight": (
                            PiecewiseStepFunction(
                                thresholds=[0.25, 0.75], values=[0.0, 0.5, 1.0]
                            )
                            if contact
                            else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "union_lambda": ExponentialInterpolation(
                            start=8.0, end=0.0, alpha=-2.0
                        ),
                    }
                ),
                TemplateReferencePotential(
                    parameters={
                        "guidance_interval": 2,
                        "guidance_weight": 0.1 if contact else 0.0,
                        "resampling_weight": 1.0,
                    }
                ),
            ]
    return selected
| |
|