import json | |
import logging | |
import os | |
from collections import defaultdict | |
from typing import Dict, List, Tuple | |
import mols2grid | |
import pandas as pd | |
from rdkit import Chem | |
from terminator.selfies import decoder | |
# Per-module logger. Attaching a NullHandler (the pattern recommended for
# libraries) keeps logging's last-resort stderr handler from printing this
# module's records when the host application has not configured logging.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def draw_grid_generate(
    seeds: List[str],
    samples: List[str],
    n_cols: int = 3,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw an HTML grid of seed and generated molecules.

    Seeds come first in the grid, labelled ``Seed_0``, ``Seed_1``, ...,
    followed by the generated samples labelled ``Generated_0``,
    ``Generated_1``, ....

    Args:
        seeds: SMILES strings of the seed molecules. Any sequence is accepted.
        samples: SMILES strings of the generated molecules. Any sequence is
            accepted.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of each molecule in grid. Defaults to (140, 200).
        height: Height forwarded to ``mols2grid.display``. Defaults to 1100.

    Returns:
        HTML to display (the ``.data`` of the object returned by
        ``mols2grid.display``).
    """
    # Materialize both inputs as lists. ``seeds + samples`` raises for mixed
    # sequence types (tuple + list) and silently does element-wise string
    # concatenation when either side is a pandas Series.
    seed_smiles = list(seeds)
    sample_smiles = list(samples)

    result_df = pd.DataFrame(
        {
            # Column name matches mols2grid's default SMILES column.
            "SMILES": seed_smiles + sample_smiles,
            "Name": [f"Seed_{i}" for i in range(len(seed_smiles))]
            + [f"Generated_{i}" for i in range(len(sample_smiles))],
        }
    )
    obj = mols2grid.display(
        result_df,
        # Show every column ("SMILES", "Name") in the hover tooltip.
        tooltip=list(result_df.columns),
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data