# Spaces:
# Sleeping
# Sleeping
import functools
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator
from scipy.stats import skew
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from tqdm import tqdm

from protac_splitter.chemoinformatics import remove_dummy_atoms
from protac_splitter.graphs.utils import get_fp, numpy_to_rdkit_fp
def get_umap_clusters_fp(fp_list: List[np.ndarray], n_clusters: int = 7) -> np.ndarray:
    """
    Cluster fingerprint vectors with agglomerative (hierarchical) clustering.

    Despite the name, no UMAP embedding is computed here: scikit-learn's
    ``AgglomerativeClustering`` (default settings) runs directly on the stacked
    fingerprint vectors. The name presumably refers to the clustering split
    discussed in the paper below.

    From Scaffold Splits Overestimate Virtual Screening Performance
    https://arxiv.org/abs/2406.00873

    Args:
        fp_list (List[np.ndarray]): Equal-length fingerprint vectors. They are
            stacked into a 2D ``(n_samples, n_bits)`` array before clustering.
        n_clusters (int): The number of clusters to find.

    Returns:
        np.ndarray: Cluster label of each fingerprint, in input order.
    """
    ac = AgglomerativeClustering(n_clusters=n_clusters)
    # fit_predict returns the same array that ends up in ac.labels_.
    return ac.fit_predict(np.stack(fp_list))
def get_kmeans_clusters_fp(
    fp_list: List[np.ndarray],
    n_clusters: int = 10,
    return_centroids: bool = False,
    random_state: Optional[int] = 42,
    max_iter: int = 1000,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Cluster fingerprint vectors with KMeans.

    Args:
        fp_list (List[np.ndarray]): Equal-length fingerprint vectors. They are
            stacked into a 2D ``(n_samples, n_bits)`` array before clustering.
        n_clusters (int): The number of clusters to use for clustering.
        return_centroids (bool): If True, also return the fitted cluster centroids.
        random_state (Optional[int]): Seed forwarded to ``KMeans``. The default
            of 42 keeps results reproducible.
        max_iter (int): Maximum number of KMeans iterations per run.

    Returns:
        np.ndarray: Cluster label of each fingerprint, in input order. If
        ``return_centroids`` is True, returns a tuple ``(labels, centroids)``
        instead, where ``centroids`` is ``KMeans.cluster_centers_``.
    """
    X = np.stack(fp_list)
    km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random_state, max_iter=max_iter)
    # A single fit serves both branches; fit_predict returns km.labels_.
    labels = km.fit_predict(X)
    if return_centroids:
        return labels, km.cluster_centers_
    return labels
def evaluate_clusters(X: np.ndarray, clusters: np.ndarray) -> Dict[str, float]:
    """Compute clustering quality metrics and summarize the cluster-size distribution.

    Args:
        X (np.ndarray): The data that was clustered, array-like of shape
            ``(n_samples, n_features)``. A list of equal-length vectors also works.
        clusters (np.ndarray): Cluster label of each row of ``X``. A plain list is
            accepted too; it is converted with ``np.asarray``.

    Returns:
        Dict[str, float]: A dictionary containing:
            - silhouette: Silhouette score of the clustering.
            - davies_bouldin: Davies-Bouldin index of the clustering.
            - calinski_harabasz: Calinski-Harabasz index of the clustering.
            - avg_cluster_size: Average size of clusters.
            - avg_cluster_data_ratio: Ratio of average cluster size to total data size.
            - std_cluster_size: Standard deviation of cluster sizes.
            - min_cluster_size: Minimum size of clusters.
            - median_cluster_size: Median size of clusters.
            - max_cluster_size: Maximum size of clusters.
            - cluster_size_skewness: Skewness of cluster sizes, indicating imbalance.
            - num_clusters: Number of distinct labels found.
        With fewer than two distinct labels, the metrics are undefined. Sentinel
        values are returned instead: silhouette -1, Davies-Bouldin inf,
        Calinski-Harabasz -1, and one cluster of size ``len(X)``.

    Raises:
        ValueError: Propagated from scikit-learn's ``silhouette_score``, which
            (per its documentation) requires between 2 and ``n_samples - 1``
            distinct labels. This happens, e.g., when every sample is its own cluster.
    """
    labels = np.asarray(clusters)
    n_samples = len(X)
    # One pass yields the distinct labels and their member counts. Converting to an
    # array first matters: `list == label` is a plain False, so a list input would
    # otherwise produce all-zero cluster sizes.
    unique_labels, cluster_sizes = np.unique(labels, return_counts=True)

    if len(unique_labels) < 2:  # Avoid single-cluster issues
        return {
            "silhouette": -1,
            "davies_bouldin": float("inf"),
            "calinski_harabasz": -1,
            "avg_cluster_size": n_samples,
            "avg_cluster_data_ratio": 1,
            "std_cluster_size": 0,
            "min_cluster_size": n_samples,
            "median_cluster_size": n_samples,
            "max_cluster_size": n_samples,
            "cluster_size_skewness": 0,
            "num_clusters": 1,
        }

    # Standard clustering quality metrics.
    silhouette = silhouette_score(X, labels)
    davies_bouldin = davies_bouldin_score(X, labels)
    calinski_harabasz = calinski_harabasz_score(X, labels)

    # Cluster size statistics.
    avg_cluster_size = np.mean(cluster_sizes)
    return {
        "silhouette": silhouette,
        "davies_bouldin": davies_bouldin,
        "calinski_harabasz": calinski_harabasz,
        "avg_cluster_size": avg_cluster_size,
        "avg_cluster_data_ratio": avg_cluster_size / n_samples,
        "std_cluster_size": np.std(cluster_sizes),
        "min_cluster_size": np.min(cluster_sizes),
        "median_cluster_size": np.median(cluster_sizes),
        "max_cluster_size": np.max(cluster_sizes),
        # Positive skew means a few large clusters and many small ones.
        "cluster_size_skewness": skew(cluster_sizes, nan_policy="omit"),
        "num_clusters": len(unique_labels),
    }
def get_representative_e3s(
    train_df: pd.DataFrame,
    fp_generator: Optional[Any] = None,
    n_clusters_candidates: Optional[List[int]] = None,
    e3_column: str = 'E3 Binder SMILES with direction',
) -> Tuple[List[str], List[Any], int, pd.DataFrame]:
    """
    Select representative E3 ligands by clustering their fingerprints.

    Procedure:
      1. Compute a fingerprint for every unique SMILES in ``e3_column``, with
         dummy atoms removed first.
      2. For each candidate number of clusters, cluster the fingerprints twice:
         with KMeans, and with agglomerative clustering (labelled ``'umap'`` in
         the outputs).
      3. Score each clustering with silhouette, Davies-Bouldin and
         Calinski-Harabasz.
      4. Within each algorithm, min-max normalize the three scores, negating
         Davies-Bouldin first so that higher is always better. The candidate
         whose normalized scores have the lowest standard deviation is that
         algorithm's "sweet spot".
      5. The KMeans sweet spot defines ``best_n_clusters``. For every centroid
         of that KMeans run, the member fingerprint closest to it (Euclidean
         distance) is taken as a representative.

    Parameters:
        train_df (pd.DataFrame): DataFrame containing training data with E3 ligands.
        fp_generator (Optional[Any]): RDKit fingerprint generator. If None, a
            Morgan generator is used (radius 16, 1024 bits, bond types and
            chirality enabled).
        n_clusters_candidates (Optional[List[int]]): Candidate numbers of
            clusters. Defaults to ``[10, 25, 50, 100, 150]``. Candidates outside
            ``1 .. n_fingerprints - 1`` are dropped with a warning, because the
            clustering and metric functions reject them.
        e3_column (str): Column of ``train_df`` holding the E3 ligand SMILES.

    Returns:
        Tuple[List[str], List[Any], int, pd.DataFrame]:
            - Representative E3 ligand SMILES. There is at most one per centroid;
              centroids without members are skipped.
            - The matching RDKit fingerprints, converted with ``numpy_to_rdkit_fp``.
            - ``best_n_clusters``, the KMeans sweet spot.
            - A metrics DataFrame with one row per (algorithm, candidate). It holds
              the ``evaluate_clusters`` output plus the ``num_clusters``,
              ``cluster_algorithm`` and ``-davies_bouldin`` columns.

    Raises:
        ValueError: If ``e3_column`` is missing, no SMILES could be fingerprinted,
            or no candidate number of clusters is usable.
    """
    if e3_column not in train_df.columns:
        raise ValueError(f"Column '{e3_column}' not found in the DataFrame.")
    if n_clusters_candidates is None:
        n_clusters_candidates = [10, 25, 50, 100, 150]
    if fp_generator is None:
        fp_generator = rdFingerprintGenerator.GetMorganGenerator(
            radius=16,
            fpSize=1024,
            useBondTypes=True,
            includeChirality=True,
        )

    # Fingerprint each unique SMILES; get_fp returning None means it is skipped.
    fp_dict = {}
    for smi in tqdm(train_df[e3_column].unique()):
        fp = get_fp(remove_dummy_atoms(smi), fp_generator)
        if fp is not None:
            fp_dict[smi] = fp
    if not fp_dict:
        raise ValueError(f"No valid E3 ligands found in column '{e3_column}'.")
    fp_list = list(fp_dict.values())
    # Raw fingerprint bytes -> SMILES. If several SMILES share a fingerprint,
    # the last one wins.
    fp2smiles = {fp.tobytes(): smi for smi, fp in fp_dict.items()}
    n_samples = len(fp_list)

    # KMeans/AgglomerativeClustering reject n_clusters > n_samples, and
    # silhouette_score needs fewer labels than samples.
    valid_candidates = [n for n in n_clusters_candidates if 0 < n < n_samples]
    if not valid_candidates:
        raise ValueError(
            f"None of the n_clusters candidates {list(n_clusters_candidates)} is usable "
            f"with {n_samples} E3 fingerprints (each must be between 1 and {n_samples - 1})."
        )
    if len(valid_candidates) < len(n_clusters_candidates):
        dropped = [n for n in n_clusters_candidates if n not in valid_candidates]
        warnings.warn(
            f"Ignoring n_clusters candidates {dropped}: each must be between 1 and "
            f"{n_samples - 1} for {n_samples} E3 fingerprints."
        )

    centroids_dict = {}
    clusters_dict = {}
    metrics_rows = []
    for n_clusters in tqdm(valid_candidates, desc="Clustering and evaluating"):
        clusters, centroids = get_kmeans_clusters_fp(fp_list, n_clusters=n_clusters, return_centroids=True)
        metrics = evaluate_clusters(fp_list, clusters)
        clusters_dict[f'kmeans_n{n_clusters}'] = clusters.copy()
        centroids_dict[n_clusters] = centroids.copy()
        metrics['num_clusters'] = n_clusters
        metrics['cluster_algorithm'] = 'kmeans'
        metrics_rows.append(metrics)

        # 'umap' is the label used for the agglomerative-clustering run.
        clusters = get_umap_clusters_fp(fp_list, n_clusters=n_clusters)
        metrics = evaluate_clusters(fp_list, clusters)
        clusters_dict[f'umap_n{n_clusters}'] = clusters.copy()
        metrics['num_clusters'] = n_clusters
        metrics['cluster_algorithm'] = 'umap'
        metrics_rows.append(metrics)
    metrics_df = pd.DataFrame(metrics_rows)

    # Flip davies_bouldin (lower is better) so that all metrics are maximized.
    metrics_df['-davies_bouldin'] = -metrics_df['davies_bouldin']
    metric_cols = ['silhouette', '-davies_bouldin', 'calinski_harabasz']

    # Min-max normalize each metric per algorithm. Dividing by a span of 1 when
    # a metric is constant maps it to 0 instead of producing NaN (0 / 0).
    df_norm = metrics_df.copy()
    grouped = df_norm.groupby('cluster_algorithm')[metric_cols]
    col_min = grouped.transform('min')
    col_span = grouped.transform('max') - col_min
    df_norm[metric_cols] = (df_norm[metric_cols] - col_min) / col_span.replace(0, 1)

    # The sweet spot is the candidate whose normalized metrics agree most.
    df_norm['metric_divergence'] = df_norm[metric_cols].std(axis=1)
    sweet_spots = df_norm.loc[df_norm.groupby('cluster_algorithm')['metric_divergence'].idxmin()]
    # Centroids exist only for KMeans, so take the KMeans sweet spot explicitly.
    kmeans_spot = sweet_spots.loc[sweet_spots['cluster_algorithm'] == 'kmeans', 'num_clusters']
    best_n_clusters = int(kmeans_spot.iloc[0])

    centroids = centroids_dict[best_n_clusters]
    # The labels must come from the same KMeans run that produced these
    # centroids, i.e. best_n_clusters, not the last loop value.
    labels = np.asarray(clusters_dict[f'kmeans_n{best_n_clusters}'])
    fp_array = np.array(fp_list)

    representative_e3s = []
    representative_e3s_fp = []
    for label, centroid in enumerate(centroids):
        fp_cluster = fp_array[labels == label]
        if len(fp_cluster) == 0:
            # KMeans can leave a centroid without members, e.g. with duplicated fingerprints.
            continue
        # The member closest to the centroid, by Euclidean distance.
        closest_fp = fp_cluster[np.argmin(np.linalg.norm(fp_cluster - centroid, axis=1))]
        representative_e3s.append(fp2smiles[closest_fp.tobytes()])
        representative_e3s_fp.append(closest_fp)

    # Convert the representative numpy fingerprints to RDKit fingerprints.
    representative_e3s_fp = [numpy_to_rdkit_fp(fp) for fp in representative_e3s_fp]
    return representative_e3s, representative_e3s_fp, best_n_clusters, metrics_df
# Default E3-ligand SMILES used by get_representative_e3s_fp when e3_list is None.
# Every entry carries a "[*:2]" dummy atom, which remove_dummy_atoms strips before
# fingerprinting. NOTE(review): presumably produced by an earlier
# get_representative_e3s run -- confirm provenance before editing.
DEFAULT_REPRESENTATIVE_E3S = [
    'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)CN[*:2])cc1',
    'O=C1CCC(N2Cc3c(N=[*:2])cccc3C2=O)C(=O)N1',
    'CC(=O)NC(C(=O)N1CC(O)CC1C(=O)[*:2])C(C)(C)C',
    'CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](Oc2ccccc2[*:2])C[C@H]1C(=O)N[C@@H]1CCCc2ccccc21)C1CCCCC1',
    'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(NC(=O)CCO[*:2])C(C)(C)C)cc1',
    'O=C1CCC(N2Cc3ccc([*:2])cc3C2=O)C(=O)N1',
    'COc1ccc(C2=N[C@@H](c3ccc(Cl)cc3)[C@@H](c3ccc(Cl)cc3)N2C(=O)N2CCN(CC(=O)[*:2])C(=O)C2)c(OC(C)C)c1',
    'CC(NC(=O)C1CC(O)CN1C(=O)C(N[*:2])C(C)(C)C)c1ccc(C2CC2)cc1',
    'CCOc1cc(C(C)(C)C)ccc1C1=NC(c2ccc(Cl)cc2)C(c2ccc(Cl)cc2)N1C(=O)N1CCN(CCCC[*:2])CC1',
    'CNC(C)C(=O)NC(C(=O)N1CCCC1c1cncc(C(=O)c2cccc([*:2])c2)c1)C1CCCCC1',
    'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1nc(C(=O)c2ccc([*:2])cc2)cs1)C1CCCCC1',
    'O=C1CCC(N2C(=O)c3cccc(OC[*:2])c3C2=O)C(=O)N1',
    'CCOc1cc(C(C)(C)C)ccc1C1=NC(c2ccc(Cl)cc2)C(c2ccc(Cl)cc2)N1C(=O)N1CCN([*:2])CC1',
    'Cc1ncsc1-c1ccc(CNC(=O)[C@H]2C[C@H](O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
    'Cc1ncsc1-c1ccc([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)cc1',
    'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1cncc(C(=O)c2cccc([*:2])c2)c1)C1CCCCC1',
    'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)c(OC2CCNCC2)c1',
    'CNC(C)C(=O)NC(C(=O)N1CC(Oc2ccc([*:2])cc2)CC1C(=O)NC1CCCc2ccccc21)C1CCCCC1',
    'C[C@H](NC(=O)[C@@H]1C[C@@H](O)CN1C(=O)[C@@H](N[*:2])C(C)(C)C)c1ccc(C(C)(C)C)cc1',
    'CNC(C)C(=O)NC(C(=O)N1CCCC1c1nc(C(=O)c2ccc([*:2])cc2)cs1)C1CCCCC1',
    'CC(=O)NC(C(=O)N1CC(O)CC1C(=O)NCc1ccc(-c2scnc2C)cc1[*:2])C(C)(C)C',
    'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)C2(F)CC2)C(C)(C)C)c([*:2])c1',
    'CCOc1cc(C(C)(C)C)ccc1C1=NC(C)(c2ccc(Cl)cc2)C(C)(c2ccc(Cl)cc2)N1C(=O)N1CCN(CC(=O)[*:2])CC1',
    'COc1ccc(C(=O)[*:2])cc1N1CCC(=O)NC1=O',
    'CN[C@@H](C)C(=O)N[C@H](C(=O)N[C@H]1C[C@H]2CC[C@@H]1N(CCc1ccc([*:2])cc1)C2)C1CCCCC1',
    'CNC(C)C(=O)NC(C(=O)N1CC(N[*:2])CC1C(=O)NC1CCCc2ccccc21)C1CCCCC1',
    'CN[C@@H](C)C(=O)N[C@@H](CCCCN[*:2])C(=O)N1CCC[C@H]1C(=O)Nc1snnc1-c1ccccc1',
    'CNC(C)C(=O)NC(C(=O)NC1CC2CCC1N(CCc1cccc([*:2])c1)C2)C1CCCCC1',
    'O=C1CCC(N2C(=O)c3ccc(N[*:2])cc3C2=O)C(=O)N1',
    'CNC(C)C(=O)NC(C(=O)N1CC(NC(=O)CC[*:2])CC1C(=O)Nc1c(F)cccc1F)C(C)(C)C',
    'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@H](N[*:2])C(C)(C)C)cc1',
    'Cc1nc[nH]c1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
    'Cc1ncsc1-c1ccc(C(C)NC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
    'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)cc1',
    'O=C1CCC(c2cccc([*:2])c2)C(=O)N1',
    'CC(=O)N[C@H](C(=O)N1C[C@@H](O)C[C@@H]1C(=O)N[C@@H](CC(=O)N1CCC([*:2])CC1)c1ccccc1)C(C)C',
    'O=C(CCl)[*:2]',
    'CC[C@@H](NC(=O)[C@@H]1C[C@H](N[*:2])CN1C(=O)[C@@H](NC(=O)[C@H](C)NC)C(C)(C)C)c1ccccc1',
    'CN[C@H](C)C(=O)N[C@@H]1CCO[C@@H]2CC(C)(C)[C@H](C(=O)N[C@@H]3CCCc4cc([*:2])ccc43)N2C1=O',
    'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1nc(C(=O)c2ccc(F)cc2)cs1)C1CCN(C[*:2])CC1',
    'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
    'CNC(C)C(=O)NC(CCCCN[*:2])C(=O)N1CCCC1C(=O)Nc1snnc1-c1ccccc1',
    'O=C1CCC(N2C(=O)c3cccc([*:2])c3C2=O)C(=O)O1',
    'COc1ccc(C2=N[C@@H](c3ccc(Cl)cc3)[C@@H](c3ccc(Cl)cc3)N2C(=O)N2CCN(CC(=O)[*:2])C(=O)C2)cc1OC(C)C',
    'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)c(OC2CCNCC2)c1',
    'CNC(C)C(=O)NC(C(=O)N1CCCC1c1cncc(-n2ccc3c(C(=O)[*:2])cccc32)c1)C(C)C',
    'CCN1CCN(Cc2ccc(NC(=O)c3cccc(-c4ccc5nc(N[*:2])sc5n4)c3)cc2C(F)(F)F)CC1',
    'CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](NC(=O)CC[*:2])C[C@H]1C(=O)Nc1c(F)cccc1F)C(C)(C)C',
    'CNC(C)C(=O)NC(C(=O)N1CCCC1C(=O)NC(C(=O)[*:2])C(c1ccccc1)c1ccccc1)C1CCCCC1',
    'CC(=O)NCC(C(=O)N1CC(O)CC1C(=O)NC(CC(=O)N1CCC(N2CCC([*:2])CC2)CC1)c1ccccc1)C(C)C',
]
def get_representative_e3s_fp(
    e3_list: Optional[List[str]] = None,
    fp_generator: Optional[Any] = None,
    verbose: int = 0,
) -> "List[DataStructs.ExplicitBitVect]":
    """
    Generate RDKit fingerprints for a list of E3 ligands.

    If ``e3_list`` is None, ``DEFAULT_REPRESENTATIVE_E3S`` is used. Dummy atoms
    are removed from each SMILES before fingerprinting. SMILES for which
    ``get_fp`` returns None are skipped, with a printed warning.

    Parameters:
        e3_list (Optional[List[str]]): SMILES strings of E3 ligands. Only None
            selects the default list. An empty sequence means "no ligands" and
            raises ``ValueError``. Any iterable of SMILES (e.g. a pandas Series)
            is accepted.
        fp_generator (Optional[Any]): RDKit fingerprint generator, forwarded
            unchanged to ``get_fp``. NOTE(review): None is passed through as-is;
            whether a default generator is then used depends on ``get_fp``,
            which is not visible here.
        verbose (int): If > 0, show a tqdm progress bar.

    Returns:
        List[DataStructs.ExplicitBitVect]: One fingerprint per valid SMILES, in
        input order. ``get_fp`` is called with ``return_np=False``, so these are
        presumably RDKit bit vectors.

    Raises:
        ValueError: If no valid fingerprint could be generated.
    """
    # `is None` rather than `or`: an explicit empty list must not silently become
    # the default list, and `or` raises on pandas Series / numpy arrays.
    smiles_list = DEFAULT_REPRESENTATIVE_E3S if e3_list is None else e3_list
    iterable = smiles_list
    if verbose > 0:
        iterable = tqdm(smiles_list, desc="Generating fingerprints for E3 ligands")

    representative_e3s_fp = []
    for smi in iterable:
        # Get the Morgan fingerprint for the SMILES string
        fp = get_fp(remove_dummy_atoms(smi), fp_generator, return_np=False)
        if fp is not None:
            representative_e3s_fp.append(fp)
        else:
            print(f"Warning: Invalid SMILES string '{smi}' encountered, skipping.")
    if not representative_e3s_fp:
        raise ValueError("No valid E3 ligands found in the provided list.")
    return representative_e3s_fp