| | import os |
| | import sys |
| | import re |
| |
|
| | from pathlib import Path |
| | from typing import Collection, List, Dict, Type |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | from tqdm import tqdm |
| |
|
| | from .metrics import FullEvaluator, FullCollectionEvaluator |
| |
|
| | AUXILIARY_COLUMNS = ['sample', 'sdf_file', 'pdb_file', 'subdir'] |
| | VALIDITY_METRIC_NAME = 'medchem.valid' |
| |
|
| |
|
def get_data_type(key: str, data_types: Dict[str, Type], default=float) -> Type:
    """
    Resolve the data type of a metric column by regex-matching `key`
    against the patterns in `data_types`.

    Args:
        key: metric/column name to look up
        data_types: mapping of regex pattern -> data type
        default: type returned when no pattern matches; if None, a
            KeyError is raised instead

    Returns:
        The matched data type (or `default`).

    Raises:
        ValueError: if more than one pattern matches `key`.
        KeyError: if no pattern matches and `default` is None.
    """
    matched_pattern = None
    matched_type = None
    for pattern, dtype in data_types.items():
        if re.match(pattern, key) is None:
            continue
        # A second match is ambiguous — refuse to guess.
        if matched_pattern is not None:
            raise ValueError(f'Multiple data type keys match [{key}]: {matched_pattern}, {pattern}')
        matched_type = dtype
        matched_pattern = pattern

    if matched_pattern is not None:
        return matched_type
    if default is None:
        raise KeyError(key)
    return default
| |
|
| |
|
def convert_data_to_table(data: List[Dict], data_types: Dict[str, Type]) -> pd.DataFrame:
    """
    Converts data from `evaluate_drugflow` to a detailed table.

    Auxiliary columns are always kept; all other keys are kept unless
    their resolved data type is `list` (per-atom/per-bond values that do
    not fit a flat table).
    """
    rows = []
    for record in data:
        # Short-circuit: auxiliary keys skip the (regex-based) type lookup.
        row = {
            key: value
            for key, value in record.items()
            if key in AUXILIARY_COLUMNS or get_data_type(key, data_types) != list
        }
        rows.append(row)

    return pd.DataFrame(rows)
| |
|
def aggregated_metrics(table: pd.DataFrame, data_types: Dict[str, Type], validity_metric_name: str = None):
    """
    Aggregates per-sample metrics into one mean (and std) row per metric.

    Args:
        table (pd.DataFrame): table with metrics computed for each sample
        data_types (Dict[str, Type]): dictionary with data types for each column
        validity_metric_name (str): name of the column that has validity metric

    Returns:
        agg_table (pd.DataFrame): table with columns ['metric', 'value', 'std']
    """
    aggregated_results = []

    if validity_metric_name is not None:
        # Fraction of valid samples; missing entries count as invalid.
        aggregated_results.append({
            'metric': validity_metric_name,
            'value': table[validity_metric_name].fillna(False).astype(float).mean(),
            'std': None,
        })
        # BUG FIX: fill NaN before masking — pandas raises on a boolean mask
        # that contains missing values (the fillna on the line above shows
        # this column can hold NaN).
        table = table[table[validity_metric_name].fillna(False).astype(bool)]

    for column in table.columns:
        # Skip bookkeeping columns and free-text (str) metrics.
        if column in AUXILIARY_COLUMNS + [validity_metric_name] or get_data_type(column, data_types) == str:
            continue
        with pd.option_context("future.no_silent_downcasting", True):
            if get_data_type(column, data_types) == bool:
                # Booleans: missing values count as False; std is meaningless.
                values = table[column].fillna(0).values.astype(float).mean()
                std = None
            else:
                values = table[column].dropna().values.astype(float).mean()
                std = table[column].dropna().values.astype(float).std()

        aggregated_results.append({
            'metric': column,
            'value': values,
            'std': std,
        })

    agg_table = pd.DataFrame(aggregated_results)
    return agg_table
| |
|
| |
|
def collection_metrics(
    table: pd.DataFrame,
    reference_smiles: Collection[str],
    validity_metric_name: str = None,
    exclude_evaluators: Collection[str] = (),
):
    """
    Computes collection-level metrics over all valid molecules in `table`.

    Args:
        table (pd.DataFrame): table with metrics computed for each sample
        reference_smiles (Collection[str]): list of reference SMILES (e.g. training set)
        validity_metric_name (str): name of the column that has validity metric
        exclude_evaluators (Collection[str]): Evaluator IDs to exclude

    Returns:
        col_table (pd.DataFrame): table with columns ['metric', 'value']
    """
    # NOTE: default for `exclude_evaluators` changed from a mutable `[]` to `()`
    # (never mutated here; avoids the shared-mutable-default pitfall).

    if validity_metric_name is not None:
        # BUG FIX: fill NaN before masking — pandas raises on a boolean mask
        # containing missing values.
        table = table[table[validity_metric_name].fillna(False).astype(bool)]

    evaluator = FullCollectionEvaluator(reference_smiles, exclude_evaluators=exclude_evaluators)
    smiles = table['representation.smiles'].values
    if len(smiles) == 0:
        print('No valid input molecules')
        return pd.DataFrame(columns=['metric', 'value'])

    # Renamed local (was `collection_metrics`, shadowing this function's name).
    metric_values = evaluator(smiles)
    results = [
        {'metric': key, 'value': value}
        for key, value in metric_values.items()
    ]

    col_table = pd.DataFrame(results)
    return col_table
| |
|
| |
|
def evaluate_drugflow_subdir(
    in_dir: Path,
    evaluator: FullEvaluator,
    desc: str = None,
    n_samples: int = None,
) -> List[Dict]:
    """
    Computes per-molecule metrics for a single directory of samples for one target.

    Expects files named `{i}_ligand.sdf` / `{i}_pocket.pdb` in `in_dir`.

    Args:
        in_dir (Path): directory with per-sample ligand/pocket files
        evaluator (FullEvaluator): per-molecule evaluator, called as evaluator(sdf, pdb)
        desc (str): progress-bar description
        n_samples (int): if given, evaluate at most this many sample indices

    Returns:
        results (List[Dict]): one metrics dict per sample index, augmented with
        'sample', 'sdf_file' and 'pdb_file' entries
    """
    results = []
    sample_indices = [
        int(fname.split('_')[0])
        for fname in os.listdir(in_dir)
        if fname.endswith('_ligand.sdf') and not fname.startswith('.')
    ]
    if len(sample_indices) == 0:
        # BUG FIX: previously returned an empty pd.DataFrame, contradicting the
        # declared List[Dict] return type that callers iterate over.
        return []

    upper_bound = max(sample_indices) + 1
    if n_samples is not None:
        upper_bound = min(upper_bound, n_samples)

    # NOTE(review): indices are assumed contiguous from 0; a gap in the file
    # numbering would make `evaluator` see a missing path — TODO confirm.
    for i in tqdm(range(upper_bound), desc=desc, file=sys.stdout):
        in_mol = Path(in_dir, f'{i}_ligand.sdf')
        in_prot = Path(in_dir, f'{i}_pocket.pdb')
        res = evaluator(in_mol, in_prot)

        res['sample'] = i
        res['sdf_file'] = str(in_mol)
        res['pdb_file'] = str(in_prot)
        results.append(res)

    return results
| |
|
| |
|
def evaluate_drugflow(
    in_dir: Path,
    evaluator: FullEvaluator,
    n_samples: int = None,
    job_id: int = 0,
    n_jobs: int = 1,
) -> List[Dict]:
    """
    1. Computes per-molecule metrics for all single directories of samples
    2. Aggregates these metrics
    3. Computes additional collection metrics (if `reference_smiles_path` is provided)

    Subdirectories are partitioned round-robin across `n_jobs` workers;
    this call handles the slice belonging to `job_id`.
    """
    # Non-hidden subdirectories only (the "[!.]*" glob skips dotfiles).
    subdirs = [path for path in in_dir.glob("[!.]*") if os.path.isdir(path)]
    total_number_of_subdirs = len(subdirs)

    data = []
    for idx, subdir in enumerate(subdirs, start=1):
        # Round-robin job partitioning.
        if (idx - 1) % n_jobs != job_id:
            continue

        curr_data = evaluate_drugflow_subdir(
            in_dir=subdir,
            evaluator=evaluator,
            desc=f'[{idx}/{total_number_of_subdirs}] {str(subdir.name)}',
            n_samples=n_samples,
        )
        for entry in curr_data:
            entry['subdir'] = str(subdir)
            data.append(entry)

    return data
| |
|
| |
|
def compute_all_metrics_drugflow(
    in_dir: Path,
    gnina_path: Path,
    reduce_path: Path = None,
    reference_smiles_path: Path = None,
    n_samples: int = None,
    validity_metric_name: str = VALIDITY_METRIC_NAME,
    exclude_evaluators: Collection[str] = (),
    job_id: int = 0,
    n_jobs: int = 1,
):
    """
    End-to-end evaluation:
    1. Computes per-molecule metrics for every sample subdirectory of `in_dir`.
    2. Aggregates them into per-metric means/stds.
    3. Appends collection metrics if `reference_smiles_path` is provided.

    Args:
        in_dir (Path): root directory with one subdirectory per target
        gnina_path (Path): path to the gnina binary, forwarded to FullEvaluator
        reduce_path (Path): path to the reduce binary, forwarded to FullEvaluator
        reference_smiles_path (Path): .npy file with reference SMILES for
            collection metrics (skipped if None)
        n_samples (int): per-subdirectory cap on evaluated samples
        validity_metric_name (str): column holding the validity flag
        exclude_evaluators (Collection[str]): Evaluator IDs to exclude
            (default changed from mutable `[]` to `()`; it is never mutated)
        job_id (int): index of this worker for round-robin partitioning
        n_jobs (int): total number of workers

    Returns:
        (data, table_detailed, table_aggregated): raw per-sample dicts,
        the detailed per-sample table, and the aggregated metrics table.
    """
    evaluator = FullEvaluator(gnina=gnina_path, reduce=reduce_path, exclude_evaluators=exclude_evaluators)
    data = evaluate_drugflow(in_dir=in_dir, evaluator=evaluator, n_samples=n_samples, job_id=job_id, n_jobs=n_jobs)
    table_detailed = convert_data_to_table(data, evaluator.dtypes)
    table_aggregated = aggregated_metrics(
        table_detailed,
        data_types=evaluator.dtypes,
        validity_metric_name=validity_metric_name
    )

    if reference_smiles_path is not None:
        reference_smiles = np.load(reference_smiles_path)
        col_metrics = collection_metrics(
            table=table_detailed,
            reference_smiles=reference_smiles,
            validity_metric_name=validity_metric_name,
            exclude_evaluators=exclude_evaluators
        )
        table_aggregated = pd.concat([table_aggregated, col_metrics])

    return data, table_detailed, table_aggregated
| |
|