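"""
DiffFragDock inference pipeline: dock molecular fragments into a protein with DiffDock (score + confidence
models), cluster and pair the docked poses with DBSCAN, and generate linkers between nearby fragment pairs
with DiffLinker.

Illustrative invocation (a sketch only; the script name, checkpoint paths, and input files below are
placeholders, and the score/confidence checkpoints are expected to sit next to their `model_parameters.yml`):

    python inference.py --dock --link \
        --protein_ligand_csv data/fragments.csv \
        --score_ckpt workdir/score_model/best_ema_inference_epoch_model.pt \
        --confidence_ckpt workdir/confidence_model/best_model.pt \
        --linker_ckpt checkpoints/difflinker.ckpt \
        --linker_size checkpoints/size_classifier.ckpt \
        --out_dir results/
"""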
import logging | |
import subprocess | |
import sys | |
from argparse import ArgumentParser, Namespace, FileType | |
import copy | |
import itertools | |
import os | |
from datetime import datetime | |
from pathlib import Path | |
from functools import partial, cache | |
import warnings | |
import yaml | |
from Bio.PDB import PDBParser | |
from sklearn.cluster import DBSCAN | |
from src import const | |
from src.datasets import ( | |
collate_with_fragment_without_pocket_edges, get_dataloader, get_one_hot, parse_molecule, ProteinConditionedDataset | |
) | |
from src.lightning import DDPM | |
from src.linker_size_lightning import SizeClassifier | |
from src.utils import set_deterministic, FoundNaNException | |
from src.visualizer import save_sdf | |
warnings.filterwarnings("ignore", category=DeprecationWarning, | |
message="(?s).*Pyarrow will become a required dependency of pandas.*") | |
import numpy as np | |
import pandas as pd | |
from pandarallel import pandarallel | |
import torch | |
from torch_geometric.loader import DataLoader | |
from Bio import SeqIO | |
from rdkit import RDLogger, Chem | |
from rdkit.Chem import RemoveAllHs | |
from utils.logging_utils import configure_logger, get_logger | |
from datasets.process_mols import create_mol_with_coords, read_molecule | |
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule | |
from utils.inference_utils import InferenceDataset | |
from utils.sampling import randomize_position, sampling | |
from utils.utils import get_model | |
from utils.visualise import PDBFile | |
from tqdm import tqdm | |
RDLogger.DisableLog('rdApp.*') | |
warnings.filterwarnings("ignore", category=UserWarning, | |
message="The TorchScript type system doesn't support instance-level annotations on" | |
" empty non-base types in `__init__`") | |
prody_logger = logging.getLogger(".prody") | |
prody_logger.setLevel(logging.ERROR) | |
nb_workers = os.cpu_count() | |
progress_bar = False | |
if hasattr(sys, 'gettrace') and sys.gettrace() is not None: # Debug mode | |
nb_workers = 1 | |
progress_bar = True | |
pandarallel.initialize(nb_workers=nb_workers, progress_bar=progress_bar) | |
def read_compound_library(file_path): | |
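    """Load a compound library into a DataFrame with SMILES (X1) and compound IDs (ID1)
    from a .csv file (expected to already provide X1/ID1 columns) or an .sdf file."""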
df = None | |
if file_path.suffix == '.csv': | |
df = pd.read_csv(file_path) | |
elif file_path.suffix == '.sdf': | |
        supplier = Chem.SDMolSupplier(str(file_path), sanitize=False, removeHs=False)
        # Convert to a dataframe, skipping entries RDKit could not parse
        df = pd.DataFrame([
            {'X1': Chem.MolToSmiles(mol),
             'ID1': (mol.GetProp('_Name') or None) if mol.HasProp('_Name') else None}
            for mol in supplier if mol is not None
        ])
        # Use InChIKey as ID1 if missing
        df.loc[df['ID1'].isna(), 'ID1'] = df.loc[
            df['ID1'].isna(), 'X1'
        ].apply(Chem.MolFromSmiles).apply(Chem.MolToInchiKey)
return df | |
def read_protein_library(file_path): | |
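    """Load a protein library into a DataFrame with sequences (X2) and protein IDs (ID2)
    from a .csv or .fasta file."""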
df = None | |
if file_path.suffix == '.csv': | |
df = pd.read_csv(file_path) | |
elif file_path.suffix == '.fasta': | |
records = list(SeqIO.parse(file_path, 'fasta')) | |
df = pd.DataFrame([{'X2': str(record.seq), 'ID2': record.id} for record in records]) | |
return df | |
def process_fragment_library(df): | |
""" | |
SMILES strings with separators (e.g., .) represent distinct molecular entities, such as ligands, ions, or | |
co-crystallized molecules. Splitting them ensures that each entity is treated individually, allowing focused | |
analysis of their roles in binding. Single atom fragments (e.g., counterions like [I-] or [Cl-] are irrelevant in | |
docking and are to be removed. This filtering focuses on structurally relevant fragments. | |
""" | |
# Get subset of rows with SMILES containing separators | |
fragmented_rows = df['X1'].str.contains('.', regex=False) | |
df_fragmented = df[fragmented_rows].copy() | |
# Split SMILES into lists and expand | |
df_fragmented['X1'] = df_fragmented['X1'].str.split('.') | |
df_fragmented = df_fragmented.explode('X1').reset_index(drop=True) | |
# Append fragment index as alphabet (A, B, C... AA, AB...) to ID1 for rows with fragmented SMILES | |
df_fragmented['ID1'] = df_fragmented.groupby('ID1').cumcount().apply(num_to_letter_code).radd( | |
df_fragmented['ID1'] + '_') | |
df = pd.concat([df[~fragmented_rows], df_fragmented]).sort_index().reset_index(drop=True) | |
df['mol'] = df['X1'].apply(read_molecule, remove_confs=True) | |
df = df.dropna(subset=['mol']) | |
# # Remove fragments with no carbon atoms | |
# df = df[df['mol'].swifter.apply(lambda mol: any(atom.GetSymbol() == 'C' for atom in mol.GetAtoms()))] | |
# Remove single-atom fragments | |
df = df[df['mol'].apply(lambda mol: mol.GetNumAtoms() > 1)] | |
# Canonicalize SMILES | |
df['X1'] = df['mol'].apply(lambda x: Chem.MolToSmiles(x)) | |
return df | |
def check_one_to_one(df, ID_column, X_column): | |
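    """Report conflicting mappings and return True only if ID_column and X_column map one-to-one in both directions."""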
# Check for multiple X values for the same ID | |
id_to_x_conflicts = df.groupby(ID_column)[X_column].nunique() | |
conflicting_ids = id_to_x_conflicts[id_to_x_conflicts > 1] | |
# Check for multiple ID values for the same X | |
x_to_id_conflicts = df.groupby(X_column)[ID_column].nunique() | |
conflicting_xs = x_to_id_conflicts[x_to_id_conflicts > 1] | |
# Print conflicting mappings | |
if not conflicting_ids.empty: | |
print(f"Conflicting {ID_column} -> multiple {X_column}:") | |
for idx in conflicting_ids.index: | |
print(f"{ID_column}: {idx}, {X_column} values: {df[df[ID_column] == idx][X_column].unique()}") | |
if not conflicting_xs.empty: | |
print(f"Conflicting {X_column} -> multiple {ID_column}:") | |
for x in conflicting_xs.index: | |
print(f"{X_column}: {x}, {ID_column} values: {df[df[X_column] == x][ID_column].unique()}") | |
# Return whether the mappings are one-to-one | |
return conflicting_ids.empty and conflicting_xs.empty | |
def num_to_letter_code(n): | |
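    # Excel-style letter codes: 0 -> 'A', 25 -> 'Z', 26 -> 'AA', 27 -> 'AB', ...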
result = '' | |
while n >= 0: | |
result = chr(65 + (n % 26)) + result | |
n = n // 26 - 1 | |
return result | |
def dock_fragments(args): | |
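    """
    Dock every fragment in the compound library to every protein with DiffDock and return a DataFrame with one
    row per generated pose, including its confidence score, the docked RDKit mol and (if saving is enabled)
    the path of the written SDF file.
    """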
with open(Path(args.score_ckpt).parent / 'model_parameters.yml') as f: | |
score_model_args = Namespace(**yaml.full_load(f)) | |
if args.confidence_ckpt is not None: | |
with open(Path(args.confidence_ckpt).parent / 'model_parameters.yml') as f: | |
confidence_args = Namespace(**yaml.full_load(f)) | |
log.info(f"DiffFragDock will run on {device}") | |
docking_out_dir = Path(args.out_dir, 'docking') | |
docking_out_dir.mkdir(parents=True, exist_ok=True) | |
if args.protein_ligand_csv is not None: | |
csv_path = Path(args.protein_ligand_csv) | |
assert csv_path.is_file(), f"File {args.protein_ligand_csv} does not exist" | |
df = pd.read_csv(csv_path) | |
df = process_fragment_library(df) | |
else: | |
assert args.X1 is not None and args.X2 is not None, "Either a .csv file or `X1` and `X2` must be provided." | |
compound_df = pd.DataFrame(columns=['X1', 'ID1']) | |
if Path(args.X1).is_file(): | |
compound_path = Path(args.X1) | |
if compound_path.suffix in ['.csv', '.sdf']: | |
compound_df[['X1', 'ID1']] = read_compound_library(compound_path)[['X1', 'ID1']] | |
else: | |
compound_df['X1'] = [compound_path] | |
compound_df['ID1'] = [compound_path.stem] | |
else: | |
compound_df['X1'] = [args.X1] | |
compound_df['ID1'] = 'compound_0' | |
compound_df.dropna(subset=['X1'], inplace=True) | |
compound_df.loc[compound_df['ID1'].isna(), 'ID1'] = compound_df.loc[compound_df['ID1'].isna(), 'X1'].apply( | |
lambda x: Chem.MolToInchiKey(Chem.MolFromSmiles(x)) | |
) | |
protein_df = pd.DataFrame(columns=['X2', 'ID2']) | |
if Path(args.X2).is_file(): | |
protein_path = Path(args.X2) | |
if protein_path.suffix in ['.csv', '.fasta']: | |
protein_df[['X2', 'ID2']] = read_protein_library(protein_path)[['X2', 'ID2']] | |
else: | |
protein_df['protein_path'] = [protein_path] | |
protein_df['ID2'] = [protein_path.stem] | |
else: | |
protein_df['X2'] = [args.X2] | |
protein_df['ID2'] = 'protein_0' | |
        # Keep rows that provide either a sequence (X2) or a structure file (protein_path)
        if 'protein_path' in protein_df.columns:
            protein_df = protein_df.dropna(subset=['X2', 'protein_path'], how='all')
        else:
            protein_df = protein_df.dropna(subset=['X2'])
protein_df.loc[protein_df['ID2'].isna(), 'ID2'] = [ | |
f"protein_{i}" for i in range(protein_df['ID2'].isna().sum()) | |
] | |
compound_df = process_fragment_library(compound_df) | |
df = compound_df.merge(protein_df, how='cross') | |
# Identify duplicates based on 'X1' and 'X2' | |
duplicates = df[df.duplicated(subset=['X1', 'X2'], keep=False)] | |
if not duplicates.empty: | |
print("Duplicate rows based on columns 'X1' and 'X2':\n", duplicates[['ID1', 'X1', 'ID2', 'X2']]) | |
print("Keeping the first occurrence of each duplicate.") | |
df = df.drop_duplicates(subset=['X1', 'X2']) | |
df['name'] = df['ID2'] + '-' + df['ID1'] | |
df = df.replace({pd.NA: None}) | |
# Check unique mappings between IDn and Xn | |
assert check_one_to_one(df, 'ID1', 'X1'), "ID1-X1 mapping is not one-to-one." | |
assert check_one_to_one(df, 'ID2', 'X2'), "ID2-X2 mapping is not one-to-one." | |
""" | |
Docking phase | |
""" | |
# preprocessing of complexes into geometric graphs | |
test_dataset = InferenceDataset( | |
df=df, out_dir=args.out_dir, | |
lm_embeddings=True, | |
receptor_radius=score_model_args.receptor_radius, | |
remove_hs=True, # score_model_args.remove_hs, | |
c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors, | |
all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius, | |
atom_max_neighbors=score_model_args.atom_max_neighbors, | |
knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') | |
else not score_model_args.not_knn_only_graph | |
) | |
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) | |
if args.confidence_ckpt is not None and not confidence_args.use_original_model_cache: | |
log.info('Confidence model uses different type of graphs than the score model. ' | |
'Loading (or creating if not existing) the data for the confidence model now.') | |
confidence_test_dataset = InferenceDataset( | |
df=df, out_dir=args.out_dir, | |
lm_embeddings=True, | |
receptor_radius=confidence_args.receptor_radius, | |
remove_hs=True, # confidence_args.remove_hs, | |
c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors, | |
all_atoms=confidence_args.all_atoms, | |
atom_radius=confidence_args.atom_radius, | |
atom_max_neighbors=confidence_args.atom_max_neighbors, | |
precomputed_lm_embeddings=test_dataset.lm_embeddings, | |
knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') | |
else not score_model_args.not_knn_only_graph | |
) | |
else: | |
confidence_test_dataset = None | |
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args) | |
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model) | |
state_dict = torch.load(Path(args.score_ckpt), map_location='cpu', weights_only=True) | |
model.load_state_dict(state_dict, strict=True) | |
model = model.to(device) | |
model.eval() | |
if args.confidence_ckpt is not None: | |
confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True, | |
confidence_mode=True, old=args.old_confidence_model) | |
state_dict = torch.load(Path(args.confidence_ckpt), map_location='cpu', weights_only=True) | |
confidence_model.load_state_dict(state_dict, strict=True) | |
confidence_model = confidence_model.to(device) | |
confidence_model.eval() | |
else: | |
confidence_model = None | |
confidence_args = None | |
tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta') | |
failures, skipped = 0, 0 | |
samples_per_complex = args.samples_per_complex | |
test_ds_size = len(test_dataset) | |
df = test_loader.dataset.df | |
docking_dfs = [] | |
log.info(f'Size of fragment dataset: {test_ds_size}') | |
for idx, orig_complex_graph in tqdm(enumerate(test_loader), total=test_ds_size): | |
if not orig_complex_graph.success[0]: | |
skipped += 1 | |
log.warning( | |
f"The test dataset did not contain {df['name'].iloc[idx]}" | |
f" for {df['X1'].iloc[idx]} and {df['X2'].iloc[idx]}. We are skipping this complex.") | |
continue | |
try: | |
if confidence_test_dataset is not None: | |
confidence_complex_graph = confidence_test_dataset[idx] | |
if not confidence_complex_graph.success: | |
skipped += 1 | |
log.warning( | |
f"The confidence dataset did not contain {orig_complex_graph.name}. We are skipping this complex.") | |
continue | |
confidence_data_list = [copy.deepcopy(confidence_complex_graph) for _ in range(samples_per_complex)] | |
else: | |
confidence_data_list = None | |
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(samples_per_complex)] | |
randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max, | |
initial_noise_std_proportion=args.initial_noise_std_proportion, | |
choose_residue=args.choose_residue) | |
lig = orig_complex_graph.mol[0] | |
# initialize visualisation | |
if args.save_visualisation: | |
visualization_list = [] | |
for graph in data_list: | |
pdb = PDBFile(lig) | |
pdb.add(lig, 0, 0) | |
pdb.add((orig_complex_graph['ligand'].pos + orig_complex_graph.original_center).detach().cpu(), 1, | |
0) | |
pdb.add((graph['ligand'].pos + graph.original_center).detach().cpu(), part=1, order=1) | |
visualization_list.append(pdb) | |
else: | |
visualization_list = None | |
# run reverse diffusion | |
data_list, confidence = sampling(data_list=data_list, model=model, | |
inference_steps=args.actual_steps if args.actual_steps is not None | |
else args.inference_steps, | |
tr_schedule=tr_schedule, rot_schedule=tr_schedule, | |
tor_schedule=tr_schedule, | |
device=device, t_to_sigma=t_to_sigma, model_args=score_model_args, | |
visualization_list=visualization_list, confidence_model=confidence_model, | |
confidence_data_list=confidence_data_list, | |
confidence_model_args=confidence_args, | |
batch_size=args.n_poses, no_final_step_noise=args.no_final_step_noise, | |
temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot, | |
args.temp_sampling_tor], | |
temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor], | |
temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot, | |
args.temp_sigma_data_tor]) | |
ligand_pos = np.asarray( | |
[complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for | |
complex_graph in data_list] | |
) | |
# save predictions | |
n_samples = len(confidence) | |
sample_df = pd.DataFrame([df.iloc[idx]] * n_samples) | |
confidence = confidence[:, 0].cpu().numpy() | |
sample_df['confidence'] = confidence | |
if args.save_docking: | |
sample_df['path'] = [ | |
Path( | |
docking_out_dir, f"{df['name'].iloc[idx]}-confidence{confidence[i]:.2f}.sdf" | |
) for i in range(n_samples) | |
] | |
sample_df['ligand_mol']= [ | |
create_mol_with_coords( | |
mol=RemoveAllHs(copy.deepcopy(lig)), | |
new_coords=pos, | |
path=sample_df['path'].iloc[i] if args.save_docking else None | |
) for i, pos in enumerate(ligand_pos) | |
] | |
# sample_df['ligand_pos'] = list(ligand_pos) | |
docking_dfs.append(sample_df) | |
# write_dir = f"{args.out_dir}/{df['name'].iloc[idx]}" | |
# for rank, pos in enumerate(ligand_pos): | |
# mol_pred = copy.deepcopy(lig) | |
# if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred) | |
# if rank == 0: write_mol_with_coords(mol_pred, pos, Path(write_dir, f'rank{rank + 1}.sdf')) | |
# write_mol_with_coords(mol_pred, pos, | |
# Path(write_dir, f'rank{rank + 1}_confidence{confidence[rank]:.2f}.sdf')) | |
# save visualisation frames | |
# if args.save_visualisation: | |
# if confidence is not None: | |
# for rank, batch_idx in enumerate(re_order): | |
# visualization_list[batch_idx].write( | |
# Path(write_dir, f'rank{rank + 1}_reverseprocess.pdb')) | |
# else: | |
# for rank, batch_idx in enumerate(ligand_pos): | |
# visualization_list[batch_idx].write( | |
# Path(write_dir, f'rank{rank + 1}_reverseprocess.pdb')) | |
except Exception as e: | |
            log.warning(f"Failed on {orig_complex_graph['name']}: {e}")
failures += 1 | |
# Tear down DiffDock models and datasets | |
model.cpu() | |
del model | |
if confidence_model is not None: | |
confidence_model.cpu() | |
del confidence_model | |
del test_dataset | |
if confidence_test_dataset is not None: | |
del confidence_test_dataset | |
del test_loader | |
docking_df = pd.concat(docking_dfs, ignore_index=True) | |
result_msg = f""" | |
Failed for {failures} / {test_ds_size} complexes. | |
Skipped {skipped} / {test_ds_size} complexes. | |
""" | |
if failures or skipped: | |
log.warning(result_msg) | |
else: | |
log.info(result_msg) | |
log.info(f"Results saved in {docking_out_dir}") | |
return docking_df | |
def calculate_mol_atomic_distances(mol1, mol2, distance_type='min'): | |
mol1_coords = [ | |
mol1.GetConformer().GetAtomPosition(i) for i in range(mol1.GetNumAtoms()) | |
] | |
mol2_coords = [ | |
mol2.GetConformer().GetAtomPosition(i) for i in range(mol2.GetNumAtoms()) | |
] | |
# Ensure numpy arrays | |
mol1_coords = np.array(mol1_coords) | |
mol2_coords = np.array(mol2_coords) | |
    # Compute pairwise distances between all atoms of the two molecules
atom_pairwise_distances = np.linalg.norm(mol1_coords[:, None, :] - mol2_coords[None, :, :], axis=-1) | |
# if np.any(np.isnan(atom_pairwise_distances)): | |
# import pdb | |
# pdb.set_trace() # Trigger a breakpoint if NaN is found | |
if distance_type == 'min': | |
return atom_pairwise_distances.min() | |
elif distance_type == 'mean': | |
return atom_pairwise_distances.mean() | |
elif distance_type is None: | |
return atom_pairwise_distances | |
else: | |
raise ValueError(f"Unsupported distance_type: {distance_type}") | |
def process_docking_results( | |
df, | |
eps=5, # Distance threshold for DBSCAN clustering | |
min_samples=5, # Minimum number of samples for a cluster (enrichment) | |
frag_dist_range=(2, 5), # Distance range for fragment linking | |
distance_type='min', # Type of distance to compute between fragments | |
): | |
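    """
    Cluster the docked poses of each protein with DBSCAN on pairwise inter-molecular atomic distances, keep the
    highest-confidence pose per fragment per cluster, and pair up fragments within a cluster whose distance
    falls inside frag_dist_range. Returns a DataFrame of fragment pairs for the linking step.
    """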
assert len(frag_dist_range) == 2, 'Distance range must be a tuple of two values in Angstroms (Å).' | |
frag_dist_range = sorted(frag_dist_range) | |
# The mols in df should have been processed to have no explicit hydrogens, except heavy hydrogen isotopes. | |
docking_summaries = [] # For saving intermediate docking results | |
fragment_combos = [] # Fragment pairs for the linking step | |
# 1. Cluster docking poses | |
# Compute pairwise distances of molecules defined by the closest non-heavy atoms | |
for protein, protein_df in df.groupby('X2'): | |
        protein_id = protein_df['ID2'].iloc[0]
        protein_path = protein_df['protein_path'].iloc[0]
        # Work on a copy with positional indices so that row indices line up with the rows/columns of dist_matrix
        protein_df = protein_df.reset_index(drop=True)
        protein_df['index'] = protein_df.index
log.info(f'Processing docking results for {protein_id}...') | |
protein_fragment_combos = [] | |
dist_matrix = np.stack( | |
protein_df['ligand_mol'].parallel_apply( | |
lambda mol1: [ | |
calculate_mol_atomic_distances(mol1, mol2, distance_type=distance_type) | |
for mol2 in protein_df['ligand_mol'] | |
] | |
) | |
) | |
# Perform DBSCAN clustering | |
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed') | |
protein_df['cluster'] = dbscan.fit_predict(dist_matrix) | |
protein_df = protein_df.sort_values( | |
by=['X1', 'cluster', 'confidence'], ascending=[True, True, False] | |
) | |
# Add conformer number to ID1 | |
        protein_df['ID1'] = protein_df.groupby('ID1').cumcount().astype(str).radd(protein_df['ID1'] + '_')
if args.save_docking: | |
docking_summaries.append( | |
protein_df[['name', 'ID2', 'X2', 'ID1', 'X1', 'cluster', 'confidence', 'path']] | |
) | |
# Filter out outlier poses | |
protein_df = protein_df[protein_df['cluster'] != -1] | |
# Keep only the highest confidence pose per protein per ligand per cluster | |
protein_df = protein_df.groupby(['X1', 'cluster']).first().reset_index() | |
# 2. Create fragment-linking pairs | |
for cluster, cluster_df in protein_df.groupby('cluster'): | |
if len(cluster_df) > 1: # Skip clusters with only one pose | |
pairs = list(itertools.combinations(cluster_df['index'], 2)) | |
for i, j in pairs: | |
row1 = cluster_df[cluster_df['index'] == i].iloc[0] | |
row2 = cluster_df[cluster_df['index'] == j].iloc[0] | |
dist = dist_matrix[i, j] | |
# Check if intermolecular distance is within the range | |
if frag_dist_range[0] < dist < frag_dist_range[1]: | |
combined_smiles = f"{row1['X1']}.{row2['X1']}" | |
combined_mol = Chem.CombineMols(row1['ligand_mol'], row2['ligand_mol']) | |
complex_name = f"{protein_id}-{row1['ID1']}-{row2['ID1']}" | |
ligand_path = f"{row1['path']},{row2['path']}" | |
protein_fragment_combos.append( | |
(complex_name, protein, protein_path, combined_smiles, ligand_path, combined_mol, dist) | |
) | |
log.info(f'Number of fragment pairs for {protein_id}: {len(protein_fragment_combos)}.') | |
fragment_combos.extend(protein_fragment_combos) | |
# Save intermediate docking results | |
if args.save_docking: | |
docking_summary_df = pd.concat(docking_summaries, ignore_index=True) | |
docking_summary_df.to_csv(Path(args.out_dir, 'docking_summary.csv'), index=False) | |
log.info(f'Saved intermediate docking results to {args.out_dir}') | |
# Convert fragment pair results to DataFrame | |
if fragment_combos: | |
linking_df = pd.DataFrame( | |
fragment_combos, columns=['name', 'X2', 'protein_path', 'X1', 'ligand_path', 'ligand_mol', 'distance'] | |
) | |
linking_df[ | |
['name', 'X2', 'protein_path', 'X1', 'ligand_path', 'distance'] | |
].to_csv(Path(args.out_dir, 'linking_summary.csv'), index=False) | |
return linking_df | |
else: | |
raise ValueError('No eligible fragment pairs found for linking.') | |
def get_pocket(mol, pdb_path, backbone_atoms_only=False): | |
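    """
    Extract the pocket around the docked fragments: all atoms (optionally backbone only) of residues with at
    least one atom within 6 A of the ligand, returned as coordinates, one-hot atom types and atomic charges.
    """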
struct = PDBParser().get_structure('', pdb_path) | |
residue_ids = [] | |
atom_coords = [] | |
for residue in struct.get_residues(): | |
resid = residue.get_id()[1] | |
for atom in residue.get_atoms(): | |
atom_coords.append(atom.get_coord()) | |
residue_ids.append(resid) | |
residue_ids = np.array(residue_ids) | |
atom_coords = np.array(atom_coords) | |
mol_atom_coords = mol.GetConformer().GetPositions() | |
distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1) | |
contact_residues = np.unique(residue_ids[np.where(distances.min(1) <= 6)[0]]) | |
pocket_coords = [] | |
pocket_types = [] | |
for residue in struct.get_residues(): | |
resid = residue.get_id()[1] | |
if resid not in contact_residues: | |
continue | |
for atom in residue.get_atoms(): | |
atom_name = atom.get_name() | |
atom_type = atom.element.upper() | |
atom_coord = atom.get_coord() | |
if backbone_atoms_only and atom_name not in {'N', 'CA', 'C', 'O'}: | |
continue | |
pocket_coords.append(atom_coord.tolist()) | |
pocket_types.append(atom_type) | |
pocket_pos = [] | |
pocket_one_hot = [] | |
pocket_charges = [] | |
for coord, atom_type in zip(pocket_coords, pocket_types): | |
if atom_type not in const.GEOM_ATOM2IDX.keys(): | |
continue | |
pocket_pos.append(coord) | |
pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX)) | |
pocket_charges.append(const.GEOM_CHARGES[atom_type]) | |
pocket_pos = np.array(pocket_pos) | |
pocket_one_hot = np.array(pocket_one_hot) | |
pocket_charges = np.array(pocket_charges) | |
return pocket_pos, pocket_one_hot, pocket_charges | |
def generate_linker( | |
df, backbone_atoms_only, model, | |
output_dir, n_samples, n_steps, linker_size, anchors, max_batch_size, random_seed | |
): | |
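    """
    Generate linkers with DiffLinker for each docked fragment pair in df, conditioned on the protein pocket.
    Writes one SDF per generated molecule and a linking_summary.csv with the generated SMILES in column 'X1^'.
    """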
# Setup | |
if random_seed is not None: | |
set_deterministic(random_seed) | |
output_dir = Path(output_dir, 'linking') | |
output_dir.mkdir(exist_ok=True, parents=True) | |
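    # linker_size may be a fixed atom count, a 'min,max' range sampled uniformly,
    # or a path to a SizeClassifier checkpoint that predicts the linker size distribution.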
if linker_size.isdigit(): | |
print(f'Will generate linkers with {linker_size} atoms') | |
linker_size = int(linker_size) | |
def sample_fn(_data): | |
return torch.ones(_data['positions'].shape[0], device=device, dtype=const.TORCH_INT) * linker_size | |
else: | |
boundaries = [x.strip() for x in linker_size.split(',')] | |
if len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit(): | |
left = int(boundaries[0]) | |
right = int(boundaries[1]) | |
print(f'Will generate linkers with numbers of atoms sampled from U({left}, {right})') | |
def sample_fn(_data): | |
shape = len(_data['positions']), | |
return torch.randint(left, right + 1, shape, device=device, dtype=const.TORCH_INT) | |
else: | |
print(f'Will generate linkers with sampled numbers of atoms') | |
size_classifier = SizeClassifier.load_from_checkpoint(linker_size, map_location=device).eval().to(device) | |
def sample_fn(_data): | |
out, _ = size_classifier.forward(_data, return_loss=False, with_pocket=True, adjust_shape=True) | |
probabilities = torch.softmax(out, dim=1) | |
distribution = torch.distributions.Categorical(probs=probabilities) | |
samples = distribution.sample() | |
sizes = [] | |
for label in samples.detach().cpu().numpy(): | |
sizes.append(size_classifier.linker_id2size[label]) | |
sizes = torch.tensor(sizes, device=samples.device, dtype=const.TORCH_INT) | |
return sizes | |
if n_steps is not None: | |
model.edm.T = n_steps | |
if model.center_of_mass == 'anchors' and anchors is None: | |
print( | |
'Please pass anchor atoms indices ' | |
'or use another DiffLinker model that does not require information about anchors' | |
) | |
return | |
cached_parse_molecule = cache(parse_molecule) | |
dataset = [] | |
for i, row in df.iterrows(): | |
mol = row['ligand_mol'] # Hs already removed | |
# Parsing fragments data | |
frag_pos, frag_one_hot, frag_charges = cached_parse_molecule(mol, is_geom=ddpm.is_geom) | |
# Parsing pocket data | |
pocket_pos, pocket_one_hot, pocket_charges = get_pocket(mol, row['protein_path'], backbone_atoms_only) | |
positions = np.concatenate([frag_pos, pocket_pos], axis=0) | |
one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0) | |
charges = np.concatenate([frag_charges, pocket_charges], axis=0) | |
anchor_flags = np.zeros_like(charges) | |
if anchors is not None: | |
for anchor in anchors.split(','): | |
anchor_flags[int(anchor.strip()) - 1] = 1 | |
fragment_only_mask = np.concatenate([ | |
np.ones_like(frag_charges), | |
np.zeros_like(pocket_charges), | |
]) | |
pocket_mask = np.concatenate([ | |
np.zeros_like(frag_charges), | |
np.ones_like(pocket_charges), | |
]) | |
linker_mask = np.concatenate([ | |
np.zeros_like(frag_charges), | |
np.zeros_like(pocket_charges), | |
]) | |
fragment_mask = np.concatenate([ | |
np.ones_like(frag_charges), | |
np.ones_like(pocket_charges), | |
]) | |
dataset.extend([{ | |
'name': row['name'], | |
'X1': row['X1'], | |
'X2': row['X2'], | |
'protein_path': row['protein_path'], | |
'ligand_path': row['ligand_path'], | |
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device), | |
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device), | |
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device), | |
'anchors': torch.tensor(anchor_flags, dtype=const.TORCH_FLOAT, device=device), | |
'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device), | |
'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device), | |
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device), | |
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device), | |
'num_atoms': len(positions) | |
}] * n_samples) | |
dataset = ProteinConditionedDataset(data=dataset) | |
ddpm.val_dataset = dataset | |
global_batch_size = min(n_samples, max_batch_size) | |
dataloader = get_dataloader( | |
dataset, batch_size=global_batch_size, collate_fn=collate_with_fragment_without_pocket_edges | |
) | |
# df.drop(columns=['ligand_mol', 'protein_path'], inplace=True) | |
linking_dfs = [] | |
# Sampling | |
print('Sampling...') | |
# TODO: update linking_summary.csv per batch | |
for batch_i, data in tqdm(enumerate(dataloader), total=len(dataloader)): | |
effective_batch_size = len(data['positions']) | |
complex_name = data['name'][0] | |
batch_df = pd.DataFrame({ | |
'name': data['name'], | |
'X1': data['X1'], | |
'X2': data['X2'], | |
'protein_path': data['protein_path'], | |
'ligand_path': data['ligand_path'], | |
}) | |
chain = None | |
node_mask = None | |
for i in range(5): | |
try: | |
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1) | |
break | |
except FoundNaNException: | |
continue | |
if chain is None: | |
log.warning(f'Could not generate linker for {complex_name} in 5 attempts') | |
continue | |
x = chain[0][:, :, :ddpm.n_dims] | |
h = chain[0][:, :, ddpm.n_dims:] | |
        # Translate the generated coordinates back to the original frame (undo the centre-of-mass centering)
com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors'] | |
pos_masked = data['positions'] * com_mask | |
N = com_mask.sum(1, keepdims=True) | |
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N | |
x = x + mean * node_mask | |
node_mask[torch.where(data['pocket_mask'])] = 0 | |
        # Name outputs by each row's own complex and a run-wide sample counter, so that batches spanning
        # several complexes (n_samples > max_batch_size) neither mislabel nor overwrite earlier files
        batch_df['out_path'] = [
            Path(output_dir, f'{data["name"][i]}_{batch_i * global_batch_size + i}.sdf')
            for i in range(effective_batch_size)
        ]
batch_df['one_hot'] = list(h.cpu()) | |
batch_df['positions'] = list(x.cpu()) | |
batch_df['node_mask'] = list(node_mask.cpu()) | |
batch_df['X1^'] = batch_df.parallel_apply( | |
lambda row: save_sdf( | |
row['out_path'], row['one_hot'], row['positions'], row['node_mask'], is_geom=ddpm.is_geom | |
), axis=1 | |
) | |
linking_dfs.append(batch_df[['name', 'protein_path', 'X2', 'ligand_path', 'X1', 'X1^', 'out_path']]) | |
# for i in range(effective_batch_size): | |
# # # Save XYZ file and generate SMILES | |
# # out_xyz = Path(output_dir, f'{name}_{offset_idx+i}.xyz') | |
# # smiles = save_xyz_files(out_xyz, h[i], x[i], node_mask[i], is_geom=ddpm.is_geom) | |
# # # Convert XYZ to SDF | |
# # out_sdf = Path(output_dir, name, f'output_{offset_idx+i}.sdf') | |
# # with open(os.devnull, 'w') as devnull: | |
# # subprocess.run(f'obabel {out_xyz} -O {out_sdf} -q', shell=True, stdout=devnull) | |
# # Save SDF file and generate SMILES | |
# out_sdf = Path(output_dir, f'{data["name"][i]}.sdf') | |
# smiles = save_sdf(out_sdf, h[i], x[i], node_mask[i], is_geom=ddpm.is_geom) | |
# | |
# # Add experiment summary info | |
# batch_df['X1^'] = smiles | |
# batch_df['out_path'] = str(out_sdf) | |
# linking_dfs.append(batch_df) | |
if linking_dfs: | |
linking_summary_df = pd.concat(linking_dfs, ignore_index=True) | |
linking_summary_df.to_csv(Path(output_dir.parent, 'linking_summary.csv'), index=False) | |
print(f'Saved experiment summary and generated molecules to {output_dir}') | |
else: | |
raise ValueError('No linkers generated.') | |
if __name__ == "__main__": | |
parser = ArgumentParser() | |
# Fragment docking settings | |
parser.add_argument('--config', type=FileType(mode='r'), default='default_inference_args.yaml') | |
parser.add_argument('--protein_ligand_csv', type=str, default=None, | |
help='Path to a .csv file specifying the input as described in the README. ' | |
'If this is not None, it will be used instead of the `X1` and `X2` parameters') | |
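    # The CSV is expected to provide the columns used downstream: X1 (fragment SMILES), ID1, X2 (protein
    # sequence or structure path) and ID2; in link-only mode it should instead describe docked fragment pairs.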
parser.add_argument('-n', '--name', type=str, default=None, | |
help='Name that the experiment will be saved with') | |
parser.add_argument('--X1', type=str, | |
help='Either a SMILES string or the path of a molecule file that rdkit can read') | |
parser.add_argument('--X2', type=str, | |
help='Either a FASTA sequence or the path of a protein for ESMFold') | |
parser.add_argument('-l', '--log', '--loglevel', type=str, default='INFO', dest="loglevel", | |
help='Log level. Default %(default)s') | |
parser.add_argument('--out_dir', type=str, default='results/', | |
help='Directory where the outputs will be written to') | |
parser.add_argument('--save_docking', action='store_true', default=True, | |
help='Save the intermediate docking results including SDF files and a summary CSV.') | |
parser.add_argument('--save_visualisation', action='store_true', default=False, | |
help='Save a pdb file with all of the steps of the reverse diffusion') | |
parser.add_argument('--samples_per_complex', type=int, default=10, | |
help='Number of samples to generate') | |
# parser.add_argument('--model_dir', type=str, default=None, | |
# help='Path to folder with trained score model and hyperparameters') | |
parser.add_argument('--score_ckpt', type=str, default='best_ema_inference_epoch_model.pt', | |
help='Checkpoint to use for the score model') | |
# parser.add_argument('--confidence_model_dir', type=str, default=None, | |
# help='Path to folder with trained confidence model and hyperparameters') | |
parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', | |
help='Checkpoint to use for the confidence model') | |
parser.add_argument('--n_poses', type=int, default=10, help='') | |
parser.add_argument('--no_final_step_noise', action='store_true', default=True, | |
help='Use no noise in the final step of the reverse diffusion') | |
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps') | |
parser.add_argument('--actual_steps', type=int, default=None, | |
help='Number of denoising steps that are actually performed') | |
parser.add_argument('--old_score_model', action='store_true', default=False, help='') | |
parser.add_argument('--old_confidence_model', action='store_true', default=True, help='') | |
parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, | |
help='Initial noise std proportion') | |
parser.add_argument('--choose_residue', action='store_true', default=False, help='') | |
parser.add_argument('--temp_sampling_tr', type=float, default=1.0) | |
parser.add_argument('--temp_psi_tr', type=float, default=0.0) | |
parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5) | |
parser.add_argument('--temp_sampling_rot', type=float, default=1.0) | |
parser.add_argument('--temp_psi_rot', type=float, default=0.0) | |
parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5) | |
parser.add_argument('--temp_sampling_tor', type=float, default=1.0) | |
parser.add_argument('--temp_psi_tor', type=float, default=0.0) | |
parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5) | |
parser.add_argument('--gnina_minimize', action='store_true', default=False, help='') | |
parser.add_argument('--gnina_path', type=str, default='gnina', help='') | |
parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', | |
help='') # To redirect gnina subprocesses stdouts from the terminal window | |
parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='') | |
parser.add_argument('--gnina_autobox_add', type=float, default=4.0) | |
parser.add_argument('--gnina_poses_to_optimize', type=int, default=1) | |
# Linker generation settings | |
# parser.add_argument('--fragments', action='store', type=str, required=True, | |
# help='Path to the file with input fragments' | |
# ) | |
# parser.add_argument( | |
# '--protein', action='store', type=str, required=True, | |
# help='Path to the file with the target protein' | |
# ) | |
parser.add_argument( | |
'--backbone_atoms_only', action='store_true', required=False, default=False, | |
        help='Use only protein backbone atoms to represent the pocket'
) | |
parser.add_argument( | |
'--linker_ckpt', action='store', type=str, | |
help='Path to the DiffLinker model' | |
) | |
parser.add_argument( | |
'--linker_size', action='store', type=str, | |
help='Linker size (int) or allowed size boundaries (comma-separated) or path to the size prediction model' | |
) | |
parser.add_argument( | |
'--n_linkers', action='store', type=int, required=False, default=5, | |
help='Number of linkers to generate' | |
) | |
parser.add_argument( | |
'--n_steps', action='store', type=int, required=False, default=1000, | |
help='Number of denoising steps' | |
) | |
parser.add_argument( | |
'--anchors', action='store', type=str, required=False, default=None, | |
help='Comma-separated indices of anchor atoms ' | |
'(according to the order of atoms in the input fragments file, enumeration starts with 1)' | |
) | |
parser.add_argument( | |
'--max_batch_size', action='store', type=int, required=False, default=16, | |
help='Max batch size' | |
) | |
parser.add_argument( | |
'--random_seed', action='store', type=int, required=False, default=None, | |
help='Random seed' | |
) | |
parser.add_argument( | |
'--robust', action='store_true', required=False, default=False, | |
help='Robust sampling modification' | |
) | |
parser.add_argument( | |
'--dock', action='store_true', default=False, | |
help='Fragment docking with DiffDock' | |
) | |
parser.add_argument( | |
'--link', action='store_true', default=False, | |
help='Linker generation with DiffLinker' | |
) | |
args = parser.parse_args() | |
if args.config: | |
config_dict = yaml.load(args.config, Loader=yaml.FullLoader) | |
arg_dict = args.__dict__ | |
for key, value in config_dict.items(): | |
if isinstance(value, list): | |
for v in value: | |
arg_dict[key].append(v) | |
else: | |
arg_dict[key] = value | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
configure_logger(args.loglevel) | |
log = get_logger() | |
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | |
experiment_name = f"{date_time}_{args.name}" | |
args.out_dir = Path(args.out_dir, experiment_name) | |
    if args.dock:
        docking_df = dock_fragments(args)
        linking_df = process_docking_results(
            docking_df,
            eps=args.eps, min_samples=args.min_samples,
            frag_dist_range=args.frag_dist_range, distance_type=args.distance_type
        )
        if args.link:
            ddpm = DDPM.load_from_checkpoint(args.linker_ckpt, map_location=device, robust=args.robust).eval().to(device)
            generate_linker(
                linking_df,
                backbone_atoms_only=args.backbone_atoms_only,
                model=ddpm,
                output_dir=args.out_dir,
                n_samples=args.n_linkers,
                n_steps=args.n_steps,
                linker_size=args.linker_size,
                anchors=args.anchors,
                max_batch_size=args.max_batch_size,
                random_seed=args.random_seed,
            )
    elif args.link:
        # Linker generation only: the input CSV must already describe docked fragment pairs
        linking_df = pd.read_csv(args.protein_ligand_csv)
        ddpm = DDPM.load_from_checkpoint(args.linker_ckpt, map_location=device, robust=args.robust).eval().to(device)
        generate_linker(
            linking_df,
            backbone_atoms_only=args.backbone_atoms_only,
            model=ddpm,
            output_dir=args.out_dir,
            n_samples=args.n_linkers,
            n_steps=args.n_steps,
            linker_size=args.linker_size,
            anchors=args.anchors,
            max_batch_size=args.max_batch_size,
            random_seed=args.random_seed,
        )