import os
import signal
import subprocess
import warnings
from contextlib import contextmanager
from datetime import datetime

import numpy as np
import torch
import yaml
from rdkit import Chem
from rdkit.Chem import MolToPDBFile, RemoveHs
from spyrmsd import molecule, rmsd
from torch_geometric.nn.data_parallel import DataParallel

from models.all_atom_score_model import TensorProductScoreModel as AAScoreModel
from models.score_model import TensorProductScoreModel as CGScoreModel
from utils.diffusion_utils import get_timestep_embedding


def get_obrmsd(mol1_path, mol2_path, cache_name=None):
    """Compute per-pose RMSDs between two molecule files with OpenBabel's `obrms` tool."""
    cache_name = datetime.now().strftime('date%d-%m_time%H-%M-%S.%f') if cache_name is None else cache_name
    os.makedirs(".openbabel_cache", exist_ok=True)
    # obrms operates on files, so RDKit molecules are first written to cached PDB files
    if not isinstance(mol1_path, str):
        MolToPDBFile(mol1_path, '.openbabel_cache/obrmsd_mol1_cache.pdb')
        mol1_path = '.openbabel_cache/obrmsd_mol1_cache.pdb'
    if not isinstance(mol2_path, str):
        MolToPDBFile(mol2_path, '.openbabel_cache/obrmsd_mol2_cache.pdb')
        mol2_path = '.openbabel_cache/obrmsd_mol2_cache.pdb'
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return_code = subprocess.run(f"obrms {mol1_path} {mol2_path} > .openbabel_cache/obrmsd_{cache_name}.rmsd",
                                     shell=True)
        print(return_code)
    obrms_output = read_strings_from_txt(f".openbabel_cache/obrmsd_{cache_name}.rmsd")
    rmsds = [line.split(" ")[-1] for line in obrms_output]
    return np.array(rmsds, dtype=float)  # np.float was removed in NumPy 1.24; the builtin float is equivalent here
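

# Minimal usage sketch (the file paths below are hypothetical); the `obrms`
# executable from OpenBabel must be on the PATH for this to work:
#
#     rmsds = get_obrmsd('results/pose_rank1.sdf', 'data/ligand_crystal.sdf')
#     print(rmsds.min(), rmsds.mean())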


def remove_all_hs(mol):
    """Strip all hydrogens from an RDKit molecule using the most aggressive RemoveHs settings."""
    params = Chem.RemoveHsParameters()
    params.removeAndTrackIsotopes = True
    params.removeDefiningBondStereo = True
    params.removeDegreeZero = True
    params.removeDummyNeighbors = True
    params.removeHigherDegrees = True
    params.removeHydrides = True
    params.removeInSGroups = True
    params.removeIsotopes = True
    params.removeMapped = True
    params.removeNonimplicit = True
    params.removeOnlyHNeighbors = True
    params.removeWithQuery = True
    params.removeWithWedgedBond = True
    return RemoveHs(mol, params)
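

# Usage sketch (hypothetical SMILES): unlike a plain Chem.RemoveHs call, this also
# drops hydrogens RDKit would normally keep (isotopes, mapped atoms, etc.):
#
#     mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))
#     mol_noh = remove_all_hs(mol)
#     assert all(a.GetAtomicNum() != 1 for a in mol_noh.GetAtoms())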


def read_strings_from_txt(path):
    # every line will be one element of the returned list
    with open(path) as file:
        lines = file.readlines()
        return [line.rstrip() for line in lines]


def save_yaml_file(path, content):
    assert isinstance(path, str), f'path must be a string, got {path} which is a {type(path)}'
    content = yaml.dump(data=content)
    if '/' in path and os.path.dirname(path) and not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    with open(path, 'w') as f:
        f.write(content)
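

# Usage sketch: dump a plain dict of hyperparameters into a run directory
# (the directory name below is hypothetical); missing parent directories are created:
#
#     save_yaml_file('workdir/run1/model_parameters.yml', {'lr': 1e-3, 'ns': 48})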


def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                                 weight_decay=args.w_decay)

    if args.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=scheduler_mode, factor=0.7,
                                                               patience=args.scheduler_patience, min_lr=args.lr / 100)
    else:
        print('No scheduler')
        scheduler = None

    return optimizer, scheduler
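

# Usage sketch with a hypothetical argparse-style namespace; only the attributes
# read above (lr, w_decay, scheduler, scheduler_patience) need to be present:
#
#     from argparse import Namespace
#     train_args = Namespace(lr=1e-3, w_decay=0.0, scheduler='plateau', scheduler_patience=20)
#     optimizer, scheduler = get_optimizer_and_scheduler(train_args, model)
#     # after each validation epoch: scheduler.step(val_loss) if scheduler is not None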


def get_model(args, device, t_to_sigma, no_parallel=False, confidence_mode=False):
    # choose the all-atom or coarse-grained score model depending on the training configuration
    if 'all_atoms' in args and args.all_atoms:
        model_class = AAScoreModel
    else:
        model_class = CGScoreModel

    timestep_emb_func = get_timestep_embedding(
        embedding_type=args.embedding_type,
        embedding_dim=args.sigma_embed_dim,
        embedding_scale=args.embedding_scale)

    lm_embedding_type = None
    if args.esm_embeddings_path is not None:
        lm_embedding_type = 'esm'

    # confidence mode: one output per RMSD bin (len(cutoffs) + 1) when a list of cutoffs is given, otherwise one
    num_confidence_outputs = len(args.rmsd_classification_cutoff) + 1 \
        if 'rmsd_classification_cutoff' in args and isinstance(args.rmsd_classification_cutoff, list) else 1

    model = model_class(t_to_sigma=t_to_sigma,
                        device=device,
                        no_torsion=args.no_torsion,
                        timestep_emb_func=timestep_emb_func,
                        num_conv_layers=args.num_conv_layers,
                        lig_max_radius=args.max_radius,
                        scale_by_sigma=args.scale_by_sigma,
                        sigma_embed_dim=args.sigma_embed_dim,
                        ns=args.ns, nv=args.nv,
                        distance_embed_dim=args.distance_embed_dim,
                        cross_distance_embed_dim=args.cross_distance_embed_dim,
                        batch_norm=not args.no_batch_norm,
                        dropout=args.dropout,
                        use_second_order_repr=args.use_second_order_repr,
                        cross_max_distance=args.cross_max_distance,
                        dynamic_max_cross=args.dynamic_max_cross,
                        lm_embedding_type=lm_embedding_type,
                        confidence_mode=confidence_mode,
                        num_confidence_outputs=num_confidence_outputs)

    if device.type == 'cuda' and not no_parallel:
        model = DataParallel(model)
    model.to(device)
    return model
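

# Instantiation sketch: `args` is typically the namespace of training hyperparameters,
# e.g. restored from a YAML file written with save_yaml_file. The path, `device`, and
# the `t_to_sigma` helper below are assumptions, not fixed by this module:
#
#     from argparse import Namespace
#     with open('workdir/run1/model_parameters.yml') as f:
#         model_args = Namespace(**yaml.full_load(f))
#     model = get_model(model_args, device, t_to_sigma=t_to_sigma, no_parallel=True)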


def get_symmetry_rmsd(mol, coords1, coords2, mol2=None):
    # symmetry-corrected RMSD via spyrmsd, aborted after 10 seconds by the time_limit context manager
    with time_limit(10):
        mol = molecule.Molecule.from_rdkit(mol)
        mol2 = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else mol2
        mol2_atomicnums = mol2.atomicnums if mol2 is not None else mol.atomicnums
        mol2_adjacency_matrix = mol2.adjacency_matrix if mol2 is not None else mol.adjacency_matrix
        RMSD = rmsd.symmrmsd(
            coords1,
            coords2,
            mol.atomicnums,
            mol2_atomicnums,
            mol.adjacency_matrix,
            mol2_adjacency_matrix,
        )
        return RMSD
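

# Usage sketch: coords1/coords2 are (n_atoms, 3) arrays for the same RDKit molecule,
# e.g. the reference conformer against a predicted pose (`predicted_coords` is hypothetical):
#
#     ref_coords = mol.GetConformer().GetPositions()
#     try:
#         value = get_symmetry_rmsd(mol, ref_coords, predicted_coords)
#     except TimeoutException:
#         value = None  # the symmetry search can be very slow for highly symmetric ligands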


class TimeoutException(Exception):
    pass


@contextmanager
def time_limit(seconds):
    # raise TimeoutException if the wrapped block runs longer than `seconds` (SIGALRM: Unix, main thread only)
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
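

# Usage sketch: time_limit is a context manager, so any code inside the `with`
# block is interrupted once the alarm fires (`expensive_call` is hypothetical):
#
#     try:
#         with time_limit(5):
#             expensive_call()
#     except TimeoutException:
#         print('gave up after 5 seconds')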


class ExponentialMovingAverage:
    """ from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py
    Maintains (exponential) moving average of a set of parameters. """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay.
            use_num_updates: Whether to use number of updates when computing
                averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.
        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict, device):
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = [tensor.to(device) for tensor in state_dict['shadow_params']]
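

# Training-loop sketch (model, optimizer, loaders, and the compute_loss/validate helpers
# are assumptions): keep the EMA shadow weights updated after every optimizer step,
# swap them in for validation, then swap the raw weights back:
#
#     ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
#     for batch in train_loader:
#         loss = compute_loss(model, batch)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.update(model.parameters())
#
#     ema.store(model.parameters())
#     ema.copy_to(model.parameters())
#     validate(model)
#     ema.restore(model.parameters())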