import logging
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

from gt4sd.properties import PropertyPredictorRegistry
from gt4sd.algorithms.prediction.paccmann.core import PaccMann, AffinityPredictor
import torch
import mols2grid
import pandas as pd

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def get_affinity_function(target: str) -> Callable:
    """Build a callable that scores molecules against a protein target.

    Args:
        target: Protein target handed to ``AffinityPredictor``; it is repeated
            once per molecule so targets and ligands line up pairwise.

    Returns:
        A function that takes a list of molecules (used as ligands), draws
        ``len(mols)`` samples from ``PaccMann`` and returns them stacked and
        converted to a plain Python list. An empty input yields ``[]``.
    """

    def predict_affinity(mols: List[str]) -> list:
        if not mols:
            # torch.stack raises on an empty sequence; there is nothing to score.
            return []
        configuration = AffinityPredictor(
            protein_targets=[target] * len(mols), ligands=mols
        )
        predictions = PaccMann(configuration).sample(len(mols))
        return torch.stack(list(predictions)).tolist()

    return predict_affinity


# Per-molecule property predictors, keyed by the property name callers pass in.
# Affinity is deliberately absent: it depends on a per-call protein target and
# is built on demand inside ``draw_grid_generate``.
EVAL_DICT: Dict[str, Callable] = {
    "qed": PropertyPredictorRegistry.get_property_predictor("qed"),
    "sa": PropertyPredictorRegistry.get_property_predictor("sas"),
}


def draw_grid_generate(
    samples: List[str],
    properties: List[str],
    protein_target: str,
    n_cols: int = 3,
    size: Tuple[int, int] = (140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules

    Args:
        samples: The generated samples.
        properties: Names of the properties to compute for each sample. Each
            name must be a key of ``EVAL_DICT`` or the special value
            ``"affinity"``. The list is not modified.
        protein_target: Protein target used for the ``"affinity"`` property.
            May be empty when affinity is not requested.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If ``"affinity"`` is requested without a protein target.
        KeyError: If a requested property has no predictor in ``EVAL_DICT``.
    """
    wants_affinity = "affinity" in properties
    if wants_affinity and not protein_target:
        raise ValueError(
            "The 'affinity' property requires a non-empty protein_target."
        )

    # Work on a filtered copy so the caller's list is left untouched.
    per_sample_properties = [prop for prop in properties if prop != "affinity"]

    result = defaultdict(list)
    result.update(
        {"SMILES": samples, "Name": [f"Generated_{i}" for i in range(len(samples))]},
    )

    if wants_affinity:
        # Affinity is predicted for the whole batch at once. The predictor is
        # kept local instead of being stored in EVAL_DICT so that a target from
        # one call can never leak into a later call.
        result["affinity"] = get_affinity_function(protein_target)(samples)

    # Fill properties
    for sample in samples:
        for prop in per_sample_properties:
            value = EVAL_DICT[prop](sample)
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data