import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple

import mols2grid
import pandas as pd
from rdkit import Chem
from terminator.selfies import decoder

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def draw_grid_generate(
    seeds: List[str],
    scaffolds: List[str],
    samples: List[str],
    n_cols: int = 5,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw an HTML grid of seed, scaffold and generated molecules.

    Molecules appear in the order seeds, scaffolds, samples. Each one is labelled
    ``Seed_i``, ``Scaffold_i`` or ``Generated_i``, where ``i`` is its index
    within its own group.

    Args:
        seeds: SMILES strings of the seed molecules.
        scaffolds: SMILES strings of the scaffold molecules.
        samples: SMILES strings of the generated samples.
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).
        height: Height of the grid passed to ``mols2grid.display``.
            Defaults to 1100.

    Returns:
        HTML to display (the ``data`` attribute of the object returned by
        ``mols2grid.display``).
    """
    # One label per molecule, aligned with the concatenated SMILES list below.
    names = (
        [f"Seed_{i}" for i in range(len(seeds))]
        + [f"Scaffold_{i}" for i in range(len(scaffolds))]
        + [f"Generated_{i}" for i in range(len(samples))]
    )
    result_df = pd.DataFrame(
        {
            "SMILES": seeds + scaffolds + samples,
            "Name": names,
        }
    )

    obj = mols2grid.display(
        result_df,
        tooltip=["SMILES", "Name"],
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data