import logging | |
import shutil | |
import sys | |
import tempfile | |
from argparse import ArgumentParser, Namespace, FileType | |
import copy | |
import itertools | |
import os | |
import subprocess | |
from datetime import datetime | |
from pathlib import Path | |
from functools import partial, cache | |
import warnings | |
import yaml | |
from Bio.PDB import PDBParser | |
from prody import parsePDB, parsePQR | |
from sklearn.cluster import DBSCAN | |
from openbabel import openbabel as ob | |
from src import const | |
from src.datasets import ( | |
collate_with_fragment_without_pocket_edges, collate_with_fragment_edges, get_dataloader, get_one_hot, parse_molecule | |
) | |
from src.lightning import DDPM | |
from src.linker_size_lightning import SizeClassifier | |
from src.utils import set_deterministic, FoundNaNException | |
# Ignore pandas deprecation warning around pyarrow | |
warnings.filterwarnings("ignore", category=DeprecationWarning, | |
message="(?s).*Pyarrow will become a required dependency of pandas.*") | |
import numpy as np | |
import pandas as pd | |
from pandarallel import pandarallel | |
import torch | |
from torch_geometric.loader import DataLoader | |
from Bio import SeqIO | |
from rdkit import RDLogger, Chem | |
from rdkit.Chem import RemoveAllHs, PandasTools | |
# TODO imports are a little odd, utils seems to shadow things | |
from utils.logging_utils import configure_logger, get_logger | |
from datasets.process_mols import create_mol_with_coords, read_molecule | |
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule | |
from utils.inference_utils import InferenceDataset | |
from utils.sampling import randomize_position, sampling | |
from utils.utils import get_model | |
from tqdm import tqdm | |
configure_logger() | |
log = get_logger() | |
RDLogger.DisableLog('rdApp.*') | |
ob.obErrorLog.SetOutputLevel(0) | |
warnings.filterwarnings("ignore", category=UserWarning, | |
message="The TorchScript type system doesn't support instance-level annotations on" | |
" empty non-base types in `__init__`") | |
# Prody logging is very verbose by default | |
prody_logger = logging.getLogger(".prody") | |
prody_logger.setLevel(logging.ERROR) | |
# Pandarallel initialization | |
nb_workers = os.cpu_count() | |
progress_bar = False | |
if hasattr(sys, 'gettrace') and sys.gettrace() is not None: # Debug mode | |
nb_workers = 1 | |
progress_bar = True | |
pandarallel.initialize(nb_workers=nb_workers, progress_bar=progress_bar) | |
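# Pipeline overview (as implemented below): dock_fragments() docks a fragment library against one or more proteins
# with DiffDock score and confidence models, select_fragment_pairs() (or the older process_docking_results())
# selects docked pose pairs that share a pocket and lie within a linkable distance, and generate_linkers() joins
# each pair with DiffLinker. The CLI block at the bottom wires these stages together.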
def read_fragment_library(file_path): | |
if file_path is None: | |
return pd.DataFrame(columns=['X1', 'ID1', 'mol']) | |
file_path = Path(file_path) | |
if file_path.suffix == '.csv': | |
df = pd.read_csv(file_path) | |
# Validate columns | |
for col in ['X1', 'ID1']: | |
if col not in df.columns: | |
raise ValueError(f"Column '{col}' not found in CSV file.") | |
PandasTools.AddMoleculeColumnToFrame(df, smilesCol='X1', molCol='mol') | |
elif file_path.suffix == '.sdf': | |
        df = PandasTools.LoadSDF(str(file_path), smilesName='X1', molColName='mol')
id_cols = [col for col in df.columns if 'ID' in col] | |
if id_cols: | |
df['ID1'] = df[id_cols[0]] | |
else: | |
raise ValueError(f"Unsupported file format: {file_path.suffix}") | |
if 'ID1' not in df.columns: | |
df['ID1'] = None | |
# Use InChiKey as ID1 if None | |
df.loc[df['ID1'].isna(), 'ID1'] = df.loc[ | |
df['ID1'].isna(), 'mol' | |
].apply(Chem.MolToInchiKey) | |
return df[['X1', 'ID1', 'mol']] | |
def read_protein_library(file_path): | |
df = None | |
if file_path.suffix == '.csv': | |
df = pd.read_csv(file_path) | |
elif file_path.suffix == '.fasta': | |
records = list(SeqIO.parse(file_path, 'fasta')) | |
df = pd.DataFrame([{'X2': str(record.seq), 'ID2': record.id} for record in records]) | |
return df | |
def remove_halogens(mol): | |
if mol is None: | |
return None | |
halogens = ['F', 'Cl', 'Br', 'I', 'At'] | |
# Enable editing | |
rw_mol = Chem.RWMol(mol) | |
for atom in rw_mol.GetAtoms(): | |
if atom.GetSymbol() in halogens: | |
# Replace with hydrogen | |
atom.SetAtomicNum(1) | |
mol_no_halogens = Chem.Mol(rw_mol) | |
# Make hydrogen implicit | |
mol_no_halogens = Chem.RemoveHs(mol_no_halogens) | |
return mol_no_halogens | |
def process_fragment_library(df, dehalogenate=True, discard_inorganic=True): | |
""" | |
SMILES strings with separators (e.g., .) represent distinct molecular entities, such as ligands, ions, or | |
co-crystallized molecules. Splitting them ensures that each entity is treated individually, allowing focused | |
analysis of their roles in binding. Single atom fragments (e.g., counterions like [I-] or [Cl-] are irrelevant in | |
docking and are to be removed. This filtering focuses on structurally relevant fragments. | |
""" | |
# Remove fragments with invalid SMILES | |
df['mol'] = df['X1'].apply(read_molecule, remove_confs=True) | |
df = df.dropna(subset=['mol']) | |
df['X1'] = df['mol'].apply(Chem.MolToSmiles) | |
# Get subset of rows with SMILES containing separators | |
fragmented_rows = df['X1'].str.contains('.', regex=False) | |
df_fragmented = df[fragmented_rows].copy() | |
    # Split SMILES into lists and expand
    df_fragmented['X1'] = df_fragmented['X1'].str.split('.')
    df_fragmented = df_fragmented.explode('X1').reset_index(drop=True)
    # Rebuild the mol column so each row holds its individual fragment rather than the original multi-component molecule
    df_fragmented['mol'] = df_fragmented['X1'].apply(Chem.MolFromSmiles)
    # Append a fragment letter code (A, B, C, ..., AA, AB, ...) to ID1 for rows with fragmented SMILES
    df_fragmented['ID1'] = df_fragmented.groupby('ID1').cumcount().apply(num_to_letter_code).radd(
        df_fragmented['ID1'] + '_')
df = pd.concat([df[~fragmented_rows], df_fragmented]) | |
# Remove single atom fragments | |
df = df[df['mol'].apply(lambda mol: mol.GetNumAtoms() > 1)] | |
if discard_inorganic: | |
df = df[df['mol'].apply(lambda mol: any(atom.GetSymbol() == 'C' for atom in mol.GetAtoms()))] | |
if dehalogenate: | |
df['mol'] = df['mol'].apply(remove_halogens) | |
# Deduplicate fragments and canonicalize SMILES | |
df = df.groupby(['X1']).first().reset_index() | |
df['X1'] = df['mol'].apply(lambda x: Chem.MolToSmiles(x)) | |
return df | |
def check_one_to_one(df, ID_column, X_column): | |
# Check for multiple X values for the same ID | |
id_to_x_conflicts = df.groupby(ID_column)[X_column].nunique() | |
conflicting_ids = id_to_x_conflicts[id_to_x_conflicts > 1] | |
# Check for multiple ID values for the same X | |
x_to_id_conflicts = df.groupby(X_column)[ID_column].nunique() | |
conflicting_xs = x_to_id_conflicts[x_to_id_conflicts > 1] | |
# Print conflicting mappings | |
if not conflicting_ids.empty: | |
print(f"Conflicting {ID_column} -> multiple {X_column}:") | |
for idx in conflicting_ids.index: | |
print(f"{ID_column}: {idx}, {X_column} values: {df[df[ID_column] == idx][X_column].unique()}") | |
if not conflicting_xs.empty: | |
print(f"Conflicting {X_column} -> multiple {ID_column}:") | |
for x in conflicting_xs.index: | |
print(f"{X_column}: {x}, {ID_column} values: {df[df[X_column] == x][ID_column].unique()}") | |
# Return whether the mappings are one-to-one | |
return conflicting_ids.empty and conflicting_xs.empty | |
def save_sdf(path, one_hot, positions, node_mask, is_geom): | |
# Select atom mapping based on whether geometry or generic atoms are used | |
idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM | |
# Identify valid atoms based on the mask | |
mask = node_mask.squeeze() | |
atom_indices = torch.where(mask)[0] | |
obMol = ob.OBMol() | |
# Add atoms to OpenBabel molecule | |
atoms = torch.argmax(one_hot, dim=1) | |
for atom_i in atom_indices: | |
atom = atoms[atom_i].item() | |
atom_symbol = idx2atom[atom] | |
obAtom = obMol.NewAtom() | |
obAtom.SetAtomicNum(Chem.GetPeriodicTable().GetAtomicNumber(atom_symbol)) # Set atomic number | |
# Set atomic positions | |
pos = positions[atom_i] | |
obAtom.SetVector(pos[0].item(), pos[1].item(), pos[2].item()) | |
# Infer bonds with OpenBabel | |
obMol.ConnectTheDots() | |
obMol.PerceiveBondOrders() | |
# Convert OpenBabel molecule to SDF | |
obConversion = ob.OBConversion() | |
obConversion.SetOutFormat("sdf") | |
sdf_string = obConversion.WriteString(obMol) | |
# Save SDF file | |
with open(path, "w") as f: | |
f.write(sdf_string) | |
# Generate SMILES | |
rdkit_mol = Chem.MolFromMolBlock(sdf_string) | |
if rdkit_mol is not None: | |
smiles = Chem.MolToSmiles(rdkit_mol) | |
else: | |
# Use OpenBabel to generate SMILES if RDKit fails | |
obConversion.SetOutFormat("can") | |
smiles = obConversion.WriteString(obMol).strip() | |
return smiles | |
def num_to_letter_code(n): | |
result = '' | |
while n >= 0: | |
result = chr(65 + (n % 26)) + result | |
n = n // 26 - 1 | |
return result | |
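# num_to_letter_code maps 0 -> 'A', 1 -> 'B', ..., 25 -> 'Z', 26 -> 'AA', 27 -> 'AB': a spreadsheet-column style
# base-26 encoding used to suffix fragment IDs after SMILES splitting in process_fragment_library.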
def dock_fragments( | |
out_dir, | |
score_ckpt, confidence_ckpt, device, | |
inference_steps, n_poses, initial_noise_std_proportion, docking_batch_size, | |
no_final_step_noise, | |
temp_sampling_tr, temp_sampling_rot, temp_sampling_tor, | |
temp_psi_tr, temp_psi_rot, temp_psi_tor, | |
temp_sigma_data_tr, temp_sigma_data_rot,temp_sigma_data_tor, | |
save_docking, | |
df=None, protein_ligand_csv=None, fragment_library=None, protein_library=None, | |
): | |
with open(Path(score_ckpt).parent / 'model_parameters.yml') as f: | |
score_model_args = Namespace(**yaml.full_load(f)) | |
with open(Path(confidence_ckpt).parent / 'model_parameters.yml') as f: | |
confidence_args = Namespace(**yaml.full_load(f)) | |
docking_out_dir = Path(out_dir, 'docking') | |
docking_out_dir.mkdir(parents=True, exist_ok=True) | |
if df is None: | |
if protein_ligand_csv is not None: | |
csv_path = Path(protein_ligand_csv) | |
assert csv_path.is_file(), f"File {protein_ligand_csv} does not exist" | |
df = pd.read_csv(csv_path) | |
df = process_fragment_library(df) | |
else: | |
assert fragment_library is not None and protein_library is not None, "Either a .csv file or `X1` and `X2` must be provided." | |
compound_df = pd.DataFrame(columns=['X1', 'ID1']) | |
if Path(fragment_library).is_file(): | |
compound_path = Path(fragment_library) | |
if compound_path.suffix in ['.csv', '.sdf']: | |
compound_df[['X1', 'ID1']] = read_fragment_library(compound_path)[['X1', 'ID1']] | |
else: | |
compound_df['X1'] = [compound_path] | |
compound_df['ID1'] = [compound_path.stem] | |
else: | |
compound_df['X1'] = [fragment_library] | |
compound_df['ID1'] = 'compound_0' | |
compound_df.dropna(subset=['X1'], inplace=True) | |
compound_df.loc[compound_df['ID1'].isna(), 'ID1'] = compound_df.loc[compound_df['ID1'].isna(), 'X1'].apply( | |
lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x)) | |
) | |
protein_df = pd.DataFrame(columns=['X2', 'ID2']) | |
if Path(protein_library).is_file(): | |
protein_path = Path(protein_library) | |
if protein_path.suffix in ['.csv', '.fasta']: | |
protein_df[['X2', 'ID2']] = read_protein_library(protein_path)[['X2', 'ID2']] | |
else: | |
protein_df['X2'] = [protein_path] | |
protein_df['ID2'] = [protein_path.stem] | |
else: | |
protein_df['X2'] = [protein_library] | |
protein_df['ID2'] = 'protein_0' | |
protein_df.dropna(subset=['X2'], inplace=True) | |
protein_df.loc[protein_df['ID2'].isna(), 'ID2'] = [ | |
f"protein_{i}" for i in range(protein_df['ID2'].isna().sum()) | |
] | |
compound_df = process_fragment_library(compound_df) | |
df = compound_df.merge(protein_df, how='cross') | |
# Identify duplicates based on 'X1' and 'X2' | |
duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)] | |
if not duplicates.empty: | |
print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['ID1', 'X1', 'ID2', 'X2']]) | |
print("Keeping the first occurrence of each duplicate.") | |
df.drop_duplicates(subset=['X1', 'X2'], inplace=True) | |
df['name'] = df['ID2'] + '-' + df['ID1'] | |
df = df.replace({pd.NA: None}) | |
# Check unique mappings between IDn and Xn | |
assert check_one_to_one(df, 'ID1', 'X1'), "ID1-X1 mapping is not one-to-one." | |
assert check_one_to_one(df, 'ID2', 'X2'), "ID2-X2 mapping is not one-to-one." | |
""" | |
Docking phase | |
""" | |
# preprocessing of complexes into geometric graphs | |
test_dataset = InferenceDataset( | |
df=df, out_dir=out_dir, | |
lm_embeddings=True, | |
receptor_radius=score_model_args.receptor_radius, | |
remove_hs=True, # score_model_args.remove_hs, | |
c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors, | |
all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius, | |
atom_max_neighbors=score_model_args.atom_max_neighbors, | |
knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') | |
else not score_model_args.not_knn_only_graph | |
) | |
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) | |
confidence_test_dataset = InferenceDataset( | |
df=df, out_dir=out_dir, | |
lm_embeddings=True, | |
receptor_radius=confidence_args.receptor_radius, | |
remove_hs=True, # confidence_args.remove_hs, | |
c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors, | |
all_atoms=confidence_args.all_atoms, | |
atom_radius=confidence_args.atom_radius, | |
atom_max_neighbors=confidence_args.atom_max_neighbors, | |
precomputed_lm_embeddings=test_dataset.lm_embeddings, | |
knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') | |
else not score_model_args.not_knn_only_graph | |
) | |
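    # The confidence dataset reuses the language-model embeddings already computed for the score-model dataset
    # (precomputed_lm_embeddings), so protein embeddings are only generated once per complex.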
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args) | |
model = get_model( | |
score_model_args, device, | |
t_to_sigma=t_to_sigma, no_parallel=True | |
) | |
state_dict = torch.load(Path(score_ckpt), map_location='cpu', weights_only=True) | |
model.load_state_dict(state_dict, strict=True) | |
model = model.to(device) | |
model.eval() | |
confidence_model = get_model( | |
confidence_args, device, | |
t_to_sigma=t_to_sigma, no_parallel=True, confidence_mode=True, old=True | |
) | |
state_dict = torch.load(Path(confidence_ckpt), map_location='cpu', weights_only=True) | |
confidence_model.load_state_dict(state_dict, strict=True) | |
confidence_model = confidence_model.to(device) | |
confidence_model.eval() | |
tr_schedule = get_t_schedule(inference_steps=inference_steps, sigma_schedule='expbeta') | |
failures, skipped = 0, 0 | |
samples_per_complex = n_poses | |
test_ds_size = len(test_dataset) | |
df = test_loader.dataset.df | |
docking_dfs = [] | |
log.info(f'Size of fragment dataset: {test_ds_size}') | |
for idx, orig_complex_graph in tqdm(enumerate(test_loader), total=test_ds_size): | |
if not orig_complex_graph.success[0]: | |
skipped += 1 | |
log.warning( | |
f"The test dataset did not contain {df['name'].iloc[idx]}" | |
f" for {df['X1'].iloc[idx]} and {df['X2'].iloc[idx]}. We are skipping this complex.") | |
continue | |
try: | |
if confidence_test_dataset is not None: | |
confidence_complex_graph = confidence_test_dataset[idx] | |
if not confidence_complex_graph.success: | |
skipped += 1 | |
log.warning( | |
f"The confidence dataset did not contain {orig_complex_graph.name}. We are skipping this complex.") | |
continue | |
confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(samples_per_complex)] | |
else: | |
confidence_data_list = None | |
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(samples_per_complex)] | |
randomize_position( | |
data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max, | |
initial_noise_std_proportion=initial_noise_std_proportion, choose_residue=False | |
) | |
# run reverse diffusion | |
# TODO How to make full use of VRAM? seems the best way to create another loop for each fragment | |
''' | |
File "DiffFragDock/utils/sampling.py", line 142, in sampling | |
tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z) | |
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 0 | |
''' | |
# TODO It seems molecules of different sizes cannot be in the same batch in inference | |
if n_poses <= docking_batch_size: | |
batch_size = n_poses | |
elif n_poses % docking_batch_size == 0: | |
batch_size = docking_batch_size | |
            else:
                raise ValueError(
                    f'n_poses ({n_poses}) must not exceed docking_batch_size ({docking_batch_size}) '
                    f'or must be an exact multiple of it.'
                )
data_list, confidence = sampling( | |
data_list=data_list, model=model, | |
inference_steps=inference_steps, | |
tr_schedule=tr_schedule, rot_schedule=tr_schedule, | |
tor_schedule=tr_schedule, | |
device=device, t_to_sigma=t_to_sigma, model_args=score_model_args, | |
visualization_list=None, confidence_model=confidence_model, | |
confidence_data_list=confidence_data_list, | |
confidence_model_args=confidence_args, | |
batch_size=batch_size, no_final_step_noise=no_final_step_noise, | |
temp_sampling=[temp_sampling_tr, temp_sampling_rot, temp_sampling_tor], | |
temp_psi=[temp_psi_tr, temp_psi_rot, temp_psi_tor], | |
temp_sigma_data=[temp_sigma_data_tr, temp_sigma_data_rot,temp_sigma_data_tor] | |
) | |
ligand_pos = np.asarray( | |
[complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for | |
complex_graph in data_list] | |
) | |
# save predictions | |
n_samples = len(confidence) | |
sample_df = pd.DataFrame([df.iloc[idx]] * n_samples) | |
sample_df['ID1'] = [f"{df['ID1'].iloc[idx]}_{i}" for i in range(n_samples)] | |
confidence = confidence[:, 0].cpu().numpy() | |
sample_df['confidence'] = confidence | |
lig = orig_complex_graph.mol[0] | |
# TODO Use index instead of confidence in filename | |
if save_docking: | |
sample_df['ligand_conf_path'] = [ | |
f"{df['name'].iloc[idx]}_{i}-confidence{confidence[i]:.2f}.sdf" for i in range(n_samples) | |
] | |
sample_df['ligand_mol']= [ | |
create_mol_with_coords( | |
mol=RemoveAllHs(copy.deepcopy(lig)), | |
new_coords=pos, | |
path=Path(docking_out_dir, sample_df['ligand_conf_path'].iloc[i]) if save_docking else None | |
) for i, pos in enumerate(ligand_pos) | |
] | |
# sample_df['ligand_pos'] = list(ligand_pos) | |
docking_dfs.append(sample_df) | |
# write_dir = f"{args.out_dir}/{df['name'].iloc[idx]}" | |
# for rank, pos in enumerate(ligand_pos): | |
# mol_pred = copy.deepcopy(lig) | |
# if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred) | |
# if rank == 0: write_mol_with_coords(mol_pred, pos, Path(write_dir, f'rank{rank + 1}.sdf')) | |
# write_mol_with_coords(mol_pred, pos, | |
# Path(write_dir, f'rank{rank + 1}_confidence{confidence[rank]:.2f}.sdf')) | |
except Exception as e: | |
log.warning("Failed on", orig_complex_graph["name"], e) | |
failures += 1 | |
# Teardown | |
del model | |
if confidence_model is not None: | |
del confidence_model | |
del test_dataset | |
if confidence_test_dataset is not None: | |
del confidence_test_dataset | |
del test_loader | |
torch.cuda.empty_cache() | |
docking_df = pd.concat(docking_dfs, ignore_index=True) | |
# Save intermediate docking results | |
if save_docking: | |
docking_df[ | |
['name', 'ID2', 'protein_path', 'ID1', 'X1', 'confidence', 'ligand_conf_path'] | |
].to_csv(Path(out_dir, 'docking_summary.csv'), index=False) | |
result_msg = f""" | |
Failed for {failures} / {test_ds_size} complexes. | |
Skipped {skipped} / {test_ds_size} complexes. | |
""" | |
if failures or skipped: | |
log.warning(result_msg) | |
else: | |
log.info(result_msg) | |
log.info(f"Docking results saved to {docking_out_dir}") | |
return docking_df | |
def calculate_mol_atomic_distances(mol1, mol2, distance_type='min'): | |
mol1_coords = [ | |
mol1.GetConformer().GetAtomPosition(i) for i in range(mol1.GetNumAtoms()) | |
] | |
mol2_coords = [ | |
mol2.GetConformer().GetAtomPosition(i) for i in range(mol2.GetNumAtoms()) | |
] | |
# Ensure numpy arrays | |
mol1_coords = np.array(mol1_coords) | |
mol2_coords = np.array(mol2_coords) | |
# Compute pairwise distances between carbon atoms | |
atom_pairwise_distances = np.linalg.norm(mol1_coords[:, None, :] - mol2_coords[None, :, :], axis=-1) | |
# if np.any(np.isnan(atom_pairwise_distances)): | |
# import pdb | |
# pdb.set_trace() # Trigger a breakpoint if NaN is found | |
if distance_type == 'min': | |
return atom_pairwise_distances.min() | |
elif distance_type == 'mean': | |
return atom_pairwise_distances.mean() | |
elif distance_type is None: | |
return atom_pairwise_distances | |
else: | |
raise ValueError(f"Unsupported distance_type: {distance_type}") | |
def process_docking_results( | |
df, | |
eps=5, # Distance threshold for DBSCAN clustering | |
min_samples=5, # Minimum number of samples for a cluster (enrichment) | |
frag_dist_range=(2, 5), # Distance range for fragment linking | |
distance_type='min', # Type of distance to compute between fragments | |
): | |
assert len(frag_dist_range) == 2, 'Distance range must be a tuple of two values in Angstroms (Å).' | |
frag_dist_range = sorted(frag_dist_range) | |
# The mols in df should have been processed to have no explicit hydrogens, except heavy hydrogen isotopes. | |
docking_summaries = [] # For saving intermediate docking results | |
fragment_combos = [] # Fragment pairs for the linking step | |
# 1. Cluster docking poses | |
# Compute pairwise distances of molecules defined by the closest non-heavy atoms | |
for protein, protein_df in df.groupby('X2'): | |
protein_id = protein_df['ID2'].iloc[0] | |
protein_path = protein_df['protein_path'].iloc[0] | |
        protein_df = protein_df.reset_index(drop=True)
        protein_df['index'] = protein_df.index  # Positional index aligned with dist_matrix computed below
log.info(f'Processing docking results for {protein_id}...') | |
dist_matrix = np.stack( | |
protein_df['ligand_mol'].parallel_apply( | |
lambda mol1: [ | |
calculate_mol_atomic_distances(mol1, mol2, distance_type=distance_type) | |
for mol2 in protein_df['ligand_mol'] | |
] | |
) | |
) | |
# Perform DBSCAN clustering | |
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed') | |
protein_df['cluster'] = dbscan.fit_predict(dist_matrix) | |
protein_df = protein_df.sort_values( | |
by=['X1', 'cluster', 'confidence'], ascending=[True, True, False] | |
) | |
# Add conformer number to ID1 | |
protein_df['ID1'] = protein_df.groupby('ID1').cumcount().astype(str).radd(protein_df['ID1'] + '_') | |
if args.save_docking: | |
docking_summaries.append( | |
protein_df[['name', 'ID2', 'X2', 'ID1', 'X1', 'cluster', 'confidence', 'ligand_conf_path']] | |
) | |
# Filter out outlier poses | |
protein_df = protein_df[protein_df['cluster'] != -1] | |
# Keep only the highest confidence pose per protein per ligand per cluster | |
protein_df = protein_df.groupby(['X1', 'cluster']).first().reset_index() | |
# 2. Create fragment-linking pairs | |
fragment_path = None | |
protein_fragment_combos = [] | |
for cluster, cluster_df in protein_df.groupby('cluster'): | |
if len(cluster_df) > 1: # Skip clusters with only one pose | |
pairs = list(itertools.combinations(cluster_df['index'], 2)) | |
for i, j in pairs: | |
row1 = cluster_df[cluster_df['index'] == i].iloc[0] | |
row2 = cluster_df[cluster_df['index'] == j].iloc[0] | |
dist = dist_matrix[i, j] | |
# Check if intermolecular distance is within the range | |
if frag_dist_range[0] < dist < frag_dist_range[1]: | |
combined_smiles = f"{row1['X1']}.{row2['X1']}" | |
combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol']) | |
complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}" | |
if 'ligand_conf_path' in row1 and 'ligand_conf_path' in row2: | |
fragment_path = [str(row1['ligand_conf_path']), str(row2['ligand_conf_path'])] | |
protein_fragment_combos.append( | |
(complex_name, protein, protein_path, combined_smiles, fragment_path, combined_mol, dist) | |
) | |
log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_combos)}.') | |
fragment_combos.extend(protein_fragment_combos) | |
# Save intermediate docking results | |
if args.save_docking: | |
docking_summary_df = pd.concat(docking_summaries, ignore_index=True) | |
docking_summary_df.to_csv(Path(args.out_dir, 'docking_summary.csv'), index=False) | |
log.info(f'Saved intermediate docking results to {args.out_dir}') | |
# Convert fragment pair results to DataFrame | |
if fragment_combos: | |
linking_df = pd.DataFrame( | |
fragment_combos, columns=['name', 'X2', 'protein_path', 'X1', 'fragment_path', 'fragment_mol', 'distance'] | |
) | |
if linking_df['fragment_path'].isnull().all(): | |
linking_df.drop(columns=['fragment_path'], inplace=True) | |
linking_df.drop(columns=['fragment_mol']).to_csv(Path(args.out_dir, 'linking_summary.csv'), index=False) | |
return linking_df | |
else: | |
raise ValueError('No eligible fragment pose pairs found for linking.') | |
def extract_pockets(protein_path, ligand_residue=None, top_pockets=None): | |
protein_path = Path(protein_path) | |
if ligand_residue: | |
top_pockets = 1 | |
# Copy the protein file to a temporary directory to avoid overwriting pocket files in different runs | |
tmp_dir = tempfile.mkdtemp() | |
tmp_protein_path = Path(tmp_dir) / protein_path.name | |
shutil.copy(protein_path, tmp_protein_path) | |
# Run fpocket | |
distance = 2.5 | |
min_size = 30 | |
    fpocket_cmd = ['./fpocket', '-d', '-f', tmp_protein_path, '-D', str(distance), '-i', str(min_size)]
    if ligand_residue is not None:
        fpocket_cmd += ['-r', ligand_residue]
    log.debug(f'Running fpocket: {fpocket_cmd}')
    subprocess.run(fpocket_cmd, stdout=subprocess.DEVNULL)
fpocket_out_path = Path(str(tmp_protein_path.with_suffix('')) + '_out') | |
if not fpocket_out_path.is_dir(): | |
raise ValueError(f"fpocket output directory not found: {fpocket_out_path}") | |
pocket_alpha_sphere_path_dict = {} | |
if top_pockets is not None: | |
pocket_names = [f'pocket{i}' for i in range(1, top_pockets + 1)] | |
for name in pocket_names: | |
pocket_path = Path(fpocket_out_path, f'{name}_vert.pqr').resolve() | |
if pocket_path.is_file(): | |
pocket_alpha_sphere_path_dict[name] = str(pocket_path) | |
else: | |
# use fpocket_out_path.glob('*_vert.pqr') | |
pocket_alpha_sphere_path_dict = { | |
pocket_path.stem.split('_')[0]: str(pocket_path) for pocket_path in fpocket_out_path.glob('*_vert.pqr') | |
} | |
return pocket_alpha_sphere_path_dict | |
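# Note: fpocket writes results next to its input in a '<input stem>_out' directory with one 'pocketN_vert.pqr'
# file of alpha-sphere centres and radii per pocket; the top_pockets handling above assumes the usual fpocket
# convention that pockets are numbered by rank, i.e. pocket1 is the highest-scored pocket.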
def check_pocket_overlap(mol, pocket_as): | |
mol_coords = [ | |
mol.GetConformer().GetAtomPosition(i) for i in range(mol.GetNumAtoms()) | |
] | |
for as_coords, as_radii in zip(pocket_as['coord'], pocket_as['radii']): | |
for atom_coord in mol_coords: | |
if np.linalg.norm(as_coords - atom_coord) < as_radii: | |
return True | |
return False | |
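# check_pocket_overlap is an early-exit point-in-sphere test: it returns True as soon as any ligand atom lies
# inside any of the pocket's alpha spheres (distance to a sphere centre smaller than that sphere's radius).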
def deduplicate_conformers(fragment_df, rmsd_threshold=1.5):
    if len(fragment_df) > 1:
        mol_list = fragment_df['ligand_mol'].tolist()
        positions_to_drop = set()  # Positional indices into mol_list / fragment_df
        for i, mol1 in enumerate(mol_list):
            if i in positions_to_drop:  # Skip already marked duplicates
                continue
            for j, mol2 in enumerate(mol_list):
                if i < j and j not in positions_to_drop:  # Do not compare against already removed molecules
                    rmsd = Chem.rdMolAlign.CalcRMS(mol1, mol2)
                    if rmsd < rmsd_threshold:
                        positions_to_drop.add(j)  # Mark duplicate for removal
        fragment_df = fragment_df.drop(fragment_df.index[list(positions_to_drop)])
    return fragment_df
def select_fragment_pairs( | |
df, | |
pocket_path_dict=None, | |
top_pockets=3, | |
frag_dist_range=(2, 5), # Distance range for fragment linking | |
confidence_threshold=-1.5, | |
rmsd_threshold=1.5, | |
method='fpocket', | |
out_dir=Path('.'), | |
ligand_residue=None, | |
): | |
df = df[df['confidence'] > confidence_threshold].copy() | |
if 'ligand_conf_path' in df.columns: | |
df['ligand_conf_path'] = df['ligand_conf_path'].apply(Path) | |
if 'ligand_mol' not in df.columns: | |
df['ligand_mol'] = df['ligand_conf_path'].apply(read_molecule) | |
# Given pocket_path_dict for single protein case | |
if pocket_path_dict is not None: | |
pocket_names = list(pocket_path_dict.keys()) | |
top_pockets = len(pocket_names) | |
else: | |
pocket_names = [f'pocket{i}' for i in range(1, top_pockets + 1)] | |
# Add pocket columns to DataFrame | |
for name in pocket_names: | |
df[name] = False | |
fragment_conf_pairs = [] | |
for protein_path, protein_df in df.groupby('protein_path'): | |
protein_path = Path(protein_path) | |
protein_fragment_conf_pairs = [] | |
fragment_path = None | |
protein_id = protein_df['ID2'].iloc[0] | |
match method: | |
            case 'fpocket':
                # TODO: avoid rerunning fpocket when proper job management is implemented
                # Reuse the provided pocket_path_dict only in the single-protein case; otherwise run fpocket per protein
                protein_pockets = pocket_path_dict if pocket_path_dict is not None else extract_pockets(
                    protein_path, ligand_residue, top_pockets
                )
                # Read pocket PQRs
                for name in pocket_names:
                    pocket_as = read_pocket_alpha_spheres(protein_pockets[name])
# Check if any atom in a fragment conformer falls within pocket volume of alpha spheres | |
protein_df[name] = protein_df['ligand_mol'].parallel_apply( | |
check_pocket_overlap, pocket_as=pocket_as | |
) | |
case 'clustering': | |
# Clustering-based pocket finding | |
pass | |
# Filter out fragment conformers that do not overlap with any pocket | |
protein_df = protein_df[protein_df[pocket_names].any(axis=1)] | |
# Select fragment conformer pairs for linking per pocket based on distance range | |
for name in pocket_names: | |
pocket_df = protein_df[protein_df[name] == True].copy() | |
if len(pocket_df) > 1: | |
# pocket_path = pocket_alpha_sphere_path_dict[name] | |
# Deduplicate similar conformers with RDKit Chem.rdMolAlign.CalcRMS | |
pocket_df = pocket_df.groupby('X1', group_keys=False).parallel_apply( | |
deduplicate_conformers, rmsd_threshold=rmsd_threshold | |
).reset_index(drop=True) | |
pairs = list(itertools.combinations(pocket_df.index, 2)) | |
dist_matrix = np.stack( | |
pocket_df['ligand_mol'].parallel_apply( | |
lambda mol1: [ | |
calculate_mol_atomic_distances(mol1, mol2, distance_type='min') | |
for mol2 in pocket_df['ligand_mol'] | |
] | |
) | |
) | |
for i, j in pairs: | |
dist = dist_matrix[i, j] | |
if frag_dist_range[0] < dist < frag_dist_range[1]: | |
row1 = pocket_df.loc[i] | |
row2 = pocket_df.loc[j] | |
combined_smiles = f"{row1['X1']}.{row2['X1']}" | |
combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol']) | |
complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}" | |
if 'ligand_conf_path' in row1 and 'ligand_conf_path' in row2: | |
fragment_path = [row1['ligand_conf_path'].name, row2['ligand_conf_path'].name] | |
protein_fragment_conf_pairs.append( | |
(complex_name, protein_path, # pocket_path, | |
combined_smiles, fragment_path, combined_mol, dist) | |
) | |
log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_conf_pairs)}.') | |
fragment_conf_pairs.extend(protein_fragment_conf_pairs) | |
# Convert fragment pair results to DataFrame | |
if fragment_conf_pairs: | |
linking_df = pd.DataFrame( | |
fragment_conf_pairs, | |
columns=[ | |
'name', 'protein_path', # 'pocket_path', | |
'X1', 'fragment_path', 'fragment_mol', 'distance' | |
] | |
) | |
if linking_df['fragment_path'].isnull().all(): | |
linking_df.drop(columns=['fragment_path'], inplace=True) | |
linking_df.drop(columns=['fragment_mol']).to_csv(Path(out_dir, 'linking_summary.csv'), index=False) | |
return linking_df | |
else: | |
return None | |
def process_linking_results(): | |
pass | |
def get_pocket(mol, pdb_path, backbone_atoms_only=False): | |
struct = PDBParser().get_structure('', pdb_path) | |
residue_ids = [] | |
atom_coords = [] | |
for residue in struct.get_residues(): | |
resid = residue.get_id()[1] | |
for atom in residue.get_atoms(): | |
atom_coords.append(atom.get_coord()) | |
residue_ids.append(resid) | |
residue_ids = np.array(residue_ids) | |
atom_coords = np.array(atom_coords) | |
mol_atom_coords = mol.GetConformer().GetPositions() | |
distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1) | |
contact_residues = np.unique(residue_ids[np.where(distances.min(1) <= 6)[0]]) | |
pocket_coords = [] | |
pocket_types = [] | |
for residue in struct.get_residues(): | |
resid = residue.get_id()[1] | |
if resid not in contact_residues: | |
continue | |
for atom in residue.get_atoms(): | |
atom_name = atom.get_name() | |
atom_type = atom.element.upper() | |
atom_coord = atom.get_coord() | |
if backbone_atoms_only and atom_name not in {'N', 'CA', 'C', 'O'}: | |
continue | |
pocket_coords.append(atom_coord.tolist()) | |
pocket_types.append(atom_type) | |
pocket_pos = [] | |
pocket_one_hot = [] | |
pocket_charges = [] | |
for coord, atom_type in zip(pocket_coords, pocket_types): | |
if atom_type not in const.GEOM_ATOM2IDX.keys(): | |
continue | |
pocket_pos.append(coord) | |
pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX)) | |
pocket_charges.append(const.GEOM_CHARGES[atom_type]) | |
pocket_pos = np.array(pocket_pos) | |
pocket_one_hot = np.array(pocket_one_hot) | |
pocket_charges = np.array(pocket_charges) | |
return pocket_pos, pocket_one_hot, pocket_charges | |
def read_pocket(path, backbone_atoms_only): | |
pocket_coords = [] | |
pocket_types = [] | |
struct = PDBParser().get_structure('', path) | |
for residue in struct.get_residues(): | |
for atom in residue.get_atoms(): | |
atom_name = atom.get_name() | |
atom_type = atom.element.upper() | |
atom_coord = atom.get_coord() | |
if backbone_atoms_only and atom_name not in {'N', 'CA', 'C', 'O'}: | |
continue | |
pocket_coords.append(atom_coord.tolist()) | |
pocket_types.append(atom_type) | |
return { | |
'coord': np.array(pocket_coords), | |
'types': np.array(pocket_types), | |
} | |
def read_pocket_alpha_spheres(path): | |
ag = parsePQR(path) | |
as_coords = [] | |
as_radii = [] | |
for atom in ag: | |
as_coords.append(atom.getCoords()) | |
as_radii.append(atom.getRadius()) | |
return { | |
'coord': np.array(as_coords), | |
'radii': np.array(as_radii), | |
} | |
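# read_pocket_alpha_spheres parses an fpocket '*_vert.pqr' file with ProDy: each PQR "atom" is an alpha-sphere
# centre, and its radius column is the sphere radius, which is exactly what check_pocket_overlap consumes.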
def generate_linkers( | |
df, backbone_atoms_only, | |
output_dir, n_samples, n_steps, linker_size, anchors, max_batch_size, random_seed, robust, | |
linker_ckpt, size_ckpt, linker_condition, device, | |
): | |
# Model setup | |
pocket_conditioned = linker_condition in ['protein', 'pocket'] | |
if 'X2' in df.columns and pocket_conditioned: | |
if backbone_atoms_only: | |
linker_ckpt = linker_ckpt['pocket_bb'] | |
else: | |
linker_ckpt = linker_ckpt['pocket_full'] | |
else: | |
linker_ckpt = linker_ckpt['geom'] | |
ddpm = DDPM.load_from_checkpoint( | |
linker_ckpt, | |
robust=robust, torch_device=device, map_location=device | |
).eval().to(device) | |
is_geom = ddpm.is_geom | |
if random_seed is not None: | |
set_deterministic(random_seed) | |
output_dir = Path(output_dir, 'linking') | |
output_dir.mkdir(exist_ok=True, parents=True) | |
linker_size = str(linker_size) | |
if linker_size == '0': | |
log.info(f'Will generate linkers with sampled numbers of atoms') | |
size_classifier = SizeClassifier.load_from_checkpoint(size_ckpt, map_location=device).eval().to(device) | |
def sample_fn(_data): | |
# TODO Improve efficiency: do not repeat sampling for the same fragment(-pocket) samples | |
out, _ = size_classifier.forward( | |
_data, return_loss=False, with_pocket=pocket_conditioned, adjust_shape=True | |
) | |
probabilities = torch.softmax(out, dim=1) | |
distribution = torch.distributions.Categorical(probs=probabilities) | |
samples = distribution.sample() | |
sizes = [] | |
for label in samples.detach().cpu().numpy(): | |
sizes.append(size_classifier.linker_id2size[label]) | |
sizes = torch.tensor(sizes, device=samples.device, dtype=const.TORCH_INT) | |
return sizes | |
elif linker_size.isdigit(): | |
log.info(f'Will generate linkers with {linker_size} atoms') | |
linker_size = int(linker_size) | |
def sample_fn(_data): | |
return torch.ones(_data['positions'].shape[0], device=device, dtype=const.TORCH_INT) * linker_size | |
    else:
        boundaries = [x.strip() for x in linker_size.split(',')]
        if len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit():
            left = int(boundaries[0])
            right = int(boundaries[1])
            log.info(f'Will generate linkers with numbers of atoms sampled from U({left}, {right})')
            def sample_fn(_data):
                shape = len(_data['positions']),
                return torch.randint(left, right + 1, shape, device=device, dtype=const.TORCH_INT)
        else:
            raise ValueError(f"Invalid linker_size value: '{linker_size}'")
if n_steps is not None: | |
ddpm.edm.T = n_steps | |
if ddpm.center_of_mass == 'anchors' and anchors is None: | |
        log.warning(
            "Using an anchor-conditioned DiffLinker checkpoint without providing anchors. "
            "Forcing the model's `center_of_mass` to 'fragments'."
        )
ddpm.center_of_mass = 'fragments' | |
# # Apply the mapping to fill NaN values in ID1 and ID2 | |
# if 'ID1' not in df.columns: | |
# df['ID1'] = None | |
# if 'ID2' not in df.columns: | |
# df['ID2'] = None | |
# df.loc[df['ID1'].isna(), 'ID1'] = df.loc[df['ID1'].isna(), 'X1'].apply( | |
# lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x)) | |
# ) | |
# df.loc[df['ID2'].isna(), 'ID2'] = df.loc[df['ID2'].isna(), 'X2'].map({ | |
# x2_value: f"protein_{i}" | |
# for i, x2_value in enumerate(df.loc[df['ID2'].isna(), 'X2'].unique()) | |
# }) | |
# # Identify duplicates based on 'X1' and 'X2' | |
# duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)] | |
# if not duplicates.empty: | |
# print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['X1', 'X2']]) | |
# print("Keeping the first occurrence of each duplicate.") | |
# df = df.drop_duplicates(subset=['X1', 'X2']) | |
# Dataset setup | |
if 'fragment_path' not in df.columns: | |
df['fragment_path'] = df['X1'] | |
if 'fragment_mol' not in df.columns: | |
df['fragment_mol'] = df['fragment_path'].parallel_apply(read_molecule, remove_hs=True, remove_confs=False) | |
if 'protein_path' not in df.columns: | |
df['protein_path'] = df['X2'] | |
if 'name' not in df.columns and 'ID1' in df.columns and 'ID2' in df.columns: | |
df['name'] = df['ID1'] + '-' + df['ID2'] | |
df.dropna(subset=['fragment_mol', 'protein_path'], inplace=True) | |
cached_parse_molecule = cache(parse_molecule) | |
dataset = [] | |
optional_keys = ['X2', 'protein_path'] | |
for row in df.itertuples(): | |
mol = row.fragment_mol # Hs already removed | |
# Parsing fragments data | |
frag_pos, frag_one_hot, frag_charges = cached_parse_molecule(mol, is_geom=is_geom) | |
# Parsing pocket data | |
if pocket_conditioned: | |
if linker_condition == 'protein': | |
pocket_pos, pocket_one_hot, pocket_charges = get_pocket(mol, row.protein_path, backbone_atoms_only) | |
elif linker_condition == 'pocket': | |
pocket_data = read_pocket(row.protein_path, backbone_atoms_only) | |
pocket_pos = pocket_data['coord'] | |
pocket_one_hot = [] | |
pocket_charges = [] | |
for atom_type in pocket_data['types']: | |
pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX)) | |
pocket_charges.append(const.GEOM_CHARGES[atom_type]) | |
pocket_one_hot = np.array(pocket_one_hot) | |
pocket_charges = np.array(pocket_charges) | |
positions = np.concatenate([frag_pos, pocket_pos], axis=0) | |
one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0) | |
charges = np.concatenate([frag_charges, pocket_charges], axis=0) | |
fragment_only_mask = np.concatenate([ | |
np.ones_like(frag_charges), | |
np.zeros_like(pocket_charges), | |
]) | |
pocket_mask = np.concatenate([ | |
np.zeros_like(frag_charges), | |
np.ones_like(pocket_charges), | |
]) | |
linker_mask = np.concatenate([ | |
np.zeros_like(frag_charges), | |
np.zeros_like(pocket_charges), | |
]) | |
fragment_mask = np.concatenate([ | |
np.ones_like(frag_charges), | |
np.ones_like(pocket_charges), | |
]) | |
else: | |
positions = frag_pos | |
one_hot = frag_one_hot | |
charges = frag_charges | |
fragment_only_mask = np.ones_like(charges) | |
pocket_mask = np.zeros_like(charges) | |
linker_mask = np.zeros_like(charges) | |
fragment_mask = np.ones_like(charges) | |
anchor_flags = np.zeros_like(charges) | |
if anchors is not None: | |
for anchor in anchors.split(','): | |
anchor_flags[int(anchor.strip()) - 1] = 1 | |
data = { | |
'name': row.name, | |
'X1': row.X1, | |
'fragment_path': row.fragment_path, | |
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device), | |
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device), | |
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device), | |
'anchors': torch.tensor(anchor_flags, dtype=const.TORCH_FLOAT, device=device), | |
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device), | |
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device), | |
'num_atoms': len(positions) | |
} | |
for k in optional_keys: | |
if hasattr(row, k): | |
data[k] = getattr(row, k) | |
if pocket_conditioned: | |
data |= { | |
'X2': row.X2, | |
'protein_path': row.protein_path, | |
'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device), | |
'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device), | |
} | |
dataset.extend([data] * n_samples) | |
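    # Each prepared row is replicated n_samples times above, so every fragment(-pocket) input yields n_samples linkers.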
ddpm.val_dataset = dataset | |
    global_batch_size = min(n_samples, max_batch_size) if max_batch_size else n_samples
log.info(f'DiffLinker global batch size: {global_batch_size}') | |
dataloader = get_dataloader( | |
dataset, batch_size=global_batch_size, | |
collate_fn=collate_with_fragment_without_pocket_edges if pocket_conditioned else collate_with_fragment_edges | |
) | |
# df.drop(columns=['ligand_mol', 'protein_path'], inplace=True) | |
linking_dfs = [] | |
# Sampling | |
print('Sampling...') | |
# TODO: update linking_summary.csv per batch | |
for batch_i, data in tqdm(enumerate(dataloader), total=len(dataloader)): | |
effective_batch_size = len(data['positions']) | |
batch_data = { | |
'name': data['name'], | |
'X1': data['X1'], | |
'fragment_path': data['fragment_path'], | |
} | |
for k in optional_keys: | |
if k in data: | |
batch_data[k] = data[k] | |
if pocket_conditioned: | |
batch_data |= { | |
'X2': data['X2'], | |
'protein_path': data['protein_path'], | |
} | |
batch_df = pd.DataFrame(batch_data) | |
chain = None | |
node_mask = None | |
for i in range(5): | |
try: | |
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1) | |
break | |
except FoundNaNException: | |
continue | |
if chain is None: | |
log.warning(f'Could not generate linker for batch {batch_i} in 5 attempts') | |
continue | |
x = chain[0][:, :, :ddpm.n_dims] | |
h = chain[0][:, :, ddpm.n_dims:] | |
# Put the molecule back to the initial orientation | |
if ddpm.center_of_mass == 'fragments': | |
if pocket_conditioned: | |
com_mask = data['fragment_only_mask'] | |
else: | |
com_mask = data['fragment_mask'] | |
else: | |
com_mask = data['anchors'] | |
pos_masked = data['positions'] * com_mask | |
N = com_mask.sum(1, keepdims=True) | |
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N | |
x = x + mean * node_mask | |
if pocket_conditioned: | |
node_mask[torch.where(data['pocket_mask'])] = 0 | |
batch_df['one_hot'] = list(h.cpu()) | |
batch_df['positions'] = list(x.cpu()) | |
batch_df['node_mask'] = list(node_mask.cpu()) | |
linking_dfs.append(batch_df) | |
# for i in range(effective_batch_size): | |
# # # Save XYZ file and generate SMILES | |
# # out_xyz = Path(output_dir, f'{name}_{offset_idx+i}.xyz') | |
# # smiles = save_xyz_files(out_xyz, h[i], x[i], node_mask[i], is_geom=is_geom) | |
# # # Convert XYZ to SDF | |
# # out_sdf = Path(output_dir, name, f'output_{offset_idx+i}.sdf') | |
# # with open(os.devnull, 'w') as devnull: | |
# # subprocess.run(f'obabel {out_xyz} -O {out_sdf} -q', shell=True, stdout=devnull) | |
# # Save SDF file and generate SMILES | |
# out_sdf = Path(output_dir, f'{data["name"][i]}.sdf') | |
# smiles = save_sdf(out_sdf, h[i], x[i], node_mask[i], is_geom=is_geom) | |
# | |
# # Add experiment summary info | |
# batch_df['X1^'] = smiles | |
# batch_df['out_path'] = str(out_sdf) | |
# linking_dfs.append(batch_df) | |
# Teardown | |
del ddpm | |
torch.cuda.empty_cache() | |
if linking_dfs: | |
linking_summary_df = pd.concat(linking_dfs, ignore_index=True) | |
        # Build output file names as '<name>_<zero-padded per-name sample index>.sdf'
        sample_counts = linking_summary_df.groupby('name').cumcount()
        pad_width = len(str(sample_counts.max()))
        linking_summary_df['out_path'] = (
            linking_summary_df['name'] + '_' + sample_counts.apply(lambda x: f"{x:0{pad_width}d}") + '.sdf'
        )
linking_summary_df['X1^'] = linking_summary_df.parallel_apply( # parallel_apply bug | |
lambda x: save_sdf( | |
output_dir / x['out_path'], x['one_hot'], x['positions'], x['node_mask'], is_geom=is_geom | |
), axis=1 | |
) | |
# TODO add 'pocket_path' and 'distance' | |
linking_summary_df[ | |
['name', 'protein_path', 'fragment_path', 'X1', 'X1^', 'out_path'] | |
].to_csv(Path(output_dir.parent, 'linking_summary.csv'), index=False) | |
print(f'Saved experiment summary and generated molecules to {output_dir}') | |
else: | |
raise ValueError('No linkers generated.') | |
if __name__ == "__main__": | |
parser = ArgumentParser() | |
# Fragment docking settings | |
parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml') | |
parser.add_argument('--protein_ligand_csv', type=str, default=None, | |
help='Path to a .csv file specifying the input as described in the README. ' | |
'If this is not None, it will be used instead of the `X1` and `X2` parameters') | |
parser.add_argument('-n', '--name', type=str, default=None, | |
help='Name that the experiment will be saved with') | |
parser.add_argument('--X1', type=str, | |
help='Either a SMILES string or the path of a molecule file that rdkit can read') | |
parser.add_argument('--X2', type=str, | |
help='Either a FASTA sequence or the path of a protein for ESMFold') | |
parser.add_argument('-l', '--log', '--loglevel', type=str, default='INFO', dest="loglevel", | |
help='Log level. Default %(default)s') | |
parser.add_argument('--out_dir', type=str, default='results/', | |
help='Directory where the outputs will be written to') | |
parser.add_argument('--save_docking', action='store_true', default=True, | |
help='Save the intermediate docking results including SDF files and a summary CSV.') | |
parser.add_argument('--save_visualisation', action='store_true', default=False, | |
help='Save a pdb file with all of the steps of the reverse diffusion') | |
parser.add_argument('--samples_per_complex', type=int, default=10, | |
help='Number of samples to generate') | |
# parser.add_argument('--model_dir', type=str, default=None, | |
# help='Path to folder with trained score model and hyperparameters') | |
parser.add_argument('--score_ckpt', type=str, default='best_ema_inference_epoch_model.pt', | |
help='Checkpoint to use for the score model') | |
# parser.add_argument('--confidence_model_dir', type=str, default=None, | |
# help='Path to folder with trained confidence model and hyperparameters') | |
parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', | |
help='Checkpoint to use for the confidence model') | |
    parser.add_argument('--n_poses', type=int, default=10, help='Number of docking poses to sample per complex')
parser.add_argument('--no_final_step_noise', action='store_true', default=True, | |
help='Use no noise in the final step of the reverse diffusion') | |
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps') | |
parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, | |
help='Initial noise std proportion') | |
parser.add_argument('--choose_residue', action='store_true', default=False, help='') | |
parser.add_argument('--temp_sampling_tr', type=float, default=1.0) | |
parser.add_argument('--temp_psi_tr', type=float, default=0.0) | |
parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5) | |
parser.add_argument('--temp_sampling_rot', type=float, default=1.0) | |
parser.add_argument('--temp_psi_rot', type=float, default=0.0) | |
parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5) | |
parser.add_argument('--temp_sampling_tor', type=float, default=1.0) | |
parser.add_argument('--temp_psi_tor', type=float, default=0.0) | |
parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5) | |
parser.add_argument('--gnina_minimize', action='store_true', default=False, help='') | |
parser.add_argument('--gnina_path', type=str, default='gnina', help='') | |
parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', | |
help='') # To redirect gnina subprocesses stdouts from the terminal window | |
parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='') | |
parser.add_argument('--gnina_autobox_add', type=float, default=4.0) | |
parser.add_argument('--gnina_poses_to_optimize', type=int, default=1) | |
# Linker generation settings | |
# parser.add_argument('--fragments', action='store', type=str, required=True, | |
# help='Path to the file with input fragments' | |
# ) | |
# parser.add_argument( | |
# '--protein', action='store', type=str, required=True, | |
# help='Path to the file with the target protein' | |
# ) | |
parser.add_argument( | |
'--backbone_atoms_only', action='store_true', required=False, default=False, | |
help='Flag if to use only protein backbone atoms' | |
) | |
parser.add_argument( | |
'--linker_ckpt', action='store', type=str, | |
help='Path to the DiffLinker model' | |
) | |
parser.add_argument( | |
'--linker_size', action='store', type=str, default='0', | |
help='Linker size (int) or allowed size boundaries (comma-separated) or path to the size prediction model' | |
) | |
parser.add_argument( | |
'--n_linkers', action='store', type=int, required=False, default=5, | |
help='Number of linkers to generate' | |
) | |
parser.add_argument( | |
'--linker_steps', action='store', type=int, required=False, default=1000, | |
help='Number of denoising steps' | |
) | |
parser.add_argument( | |
'--anchors', action='store', type=str, required=False, default=None, | |
help='Comma-separated indices of anchor atoms ' | |
'(according to the order of atoms in the input fragments file, enumeration starts with 1)' | |
) | |
parser.add_argument( | |
'--linker_batch_size', action='store', type=int, required=False, | |
help='Max batch size for linker generation model' | |
) | |
parser.add_argument( | |
'--docking_batch_size', action='store', type=int, required=False, | |
help='Max batch size for fragment docking model' | |
) | |
parser.add_argument( | |
'--random_seed', action='store', type=int, required=False, default=None, | |
help='Random seed' | |
) | |
parser.add_argument( | |
'--robust', action='store_true', required=False, default=False, | |
help='Robust sampling modification' | |
) | |
parser.add_argument( | |
'--dock', action='store_true', default=False, | |
help='Fragment docking with DiffDock' | |
) | |
parser.add_argument( | |
'--link', action='store_true', default=False, | |
help='Linker generation with DiffLinker' | |
) | |
args = parser.parse_args() | |
if args.config: | |
config_dict = yaml.load(args.config, Loader=yaml.FullLoader) | |
arg_dict = args.__dict__ | |
for key, value in config_dict.items(): | |
# if isinstance(value, list): | |
# for v in value: | |
# arg_dict[key].append(v) | |
# else: | |
arg_dict[key] = value | |
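    # Note: the config YAML may define keys with no corresponding CLI flag (e.g. top_pockets, frag_dist_range,
    # confidence_threshold, rmsd_threshold, size_ckpt, linker_condition used further below); they are attached to
    # `args` through the dictionary update above.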
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | |
experiment_name = f"{date_time}_{args.name}" | |
args.out_dir = Path(args.out_dir, experiment_name) | |
args.out_dir.mkdir(exist_ok=True, parents=True) | |
configure_logger(args.loglevel, logfile=args.out_dir / 'inference.log') | |
log = get_logger() | |
log.info(f"DiffFBDD will run on {device}") | |
docking_df = None | |
linking_df = None | |
if args.dock: | |
docking_df = dock_fragments( | |
protein_ligand_csv=args.protein_ligand_csv, | |
fragment_library=args.X1, protein_library=args.X2, out_dir=args.out_dir, | |
score_ckpt=args.score_ckpt, confidence_ckpt=args.confidence_ckpt, | |
inference_steps=args.inference_steps, n_poses=args.n_poses, docking_batch_size=args.docking_batch_size, | |
initial_noise_std_proportion=args.initial_noise_std_proportion, | |
no_final_step_noise=args.no_final_step_noise, | |
temp_sampling_tr=args.temp_sampling_tr, | |
temp_sampling_rot=args.temp_sampling_rot, | |
temp_sampling_tor=args.temp_sampling_tor, | |
temp_psi_tr=args.temp_psi_tr, | |
temp_psi_rot=args.temp_psi_rot, | |
temp_psi_tor=args.temp_psi_tor, | |
temp_sigma_data_tr=args.temp_sigma_data_tr, | |
temp_sigma_data_rot=args.temp_sigma_data_rot, | |
temp_sigma_data_tor=args.temp_sigma_data_tor, | |
save_docking=args.save_docking, device=device, | |
) | |
# linking_df = process_docking_results( | |
# docking_df, | |
# eps=args.eps, min_samples=args.min_samples, | |
# frag_dist_range=args.frag_dist_range, distance_type=args.distance_type | |
# ) | |
else: | |
df = pd.read_csv(args.protein_ligand_csv) | |
if 'ligand_conf_path' in df.columns: | |
docking_df = df | |
else: | |
linking_df = df | |
if args.link: | |
if docking_df is not None and linking_df is None: | |
linking_df = select_fragment_pairs( | |
docking_df, | |
top_pockets=args.top_pockets, | |
frag_dist_range=args.frag_dist_range, | |
confidence_threshold=args.confidence_threshold, | |
rmsd_threshold=args.rmsd_threshold, | |
out_dir=args.out_dir, | |
) | |
if linking_df is None or len(linking_df) == 0: | |
log.error('No eligible fragment pose pairs found for linking.') | |
sys.exit() | |
generate_linkers( | |
linking_df, | |
backbone_atoms_only=args.backbone_atoms_only, | |
output_dir=args.out_dir, | |
n_samples=args.n_linkers, | |
n_steps=args.linker_steps, | |
linker_size=args.linker_size, | |
anchors=args.anchors, | |
max_batch_size=args.linker_batch_size, | |
random_seed=args.random_seed, | |
robust=args.robust, | |
linker_ckpt=args.linker_ckpt, | |
size_ckpt=args.size_ckpt, | |
linker_condition=args.linker_condition, | |
device=device, | |
) |