import logging

import mols2grid
import pandas as pd
from rdkit import Chem
from terminator.selfies import decoder

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def _protein_to_smiles(sequence: str) -> str:
    """Convert an amino-acid sequence (FASTA-style string) to SMILES via RDKit.

    Raises:
        ValueError: if RDKit cannot build a molecule from ``sequence``.
    """
    mol = Chem.MolFromFASTA(sequence)
    # MolFromFASTA returns None on failure; fail explicitly instead of letting
    # MolToSmiles raise an opaque error on a None argument.
    if mol is None:
        raise ValueError(f"RDKit could not parse sequence {sequence!r}")
    return Chem.MolToSmiles(mol)


def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
    """Use mols2grid to draw an HTML grid showing the target and its predictions.

    Args:
        prediction: Predicted properties, encoded as a run of ``<name>value``
            tokens; each becomes a tooltip entry ``"Name = value"``. Tokens
            without a closing ``>`` are skipped with a warning.
        target: Target string; only the part after the last ``|`` is used as
            the sequence to draw.
        domain: Either ``"Molecules"`` or ``"Proteins"`` (case-sensitive).
            For ``"Molecules"`` the sequence is passed through
            ``terminator.selfies.decoder`` (presumably SELFIES -> SMILES);
            for ``"Proteins"`` it is converted from an amino-acid sequence
            to SMILES with RDKit.

    Returns:
        The HTML string to display.

    Raises:
        ValueError: if ``domain`` is not one of the supported values.
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    seq = target.split("|")[-1]
    converter = decoder if domain == "Molecules" else _protein_to_smiles
    try:
        seq = converter(seq)
    except Exception:
        # Best effort: keep the raw sequence and still render the grid.
        logger.warning("Could not draw sequence %s", seq)

    result = {"SMILES": [seq], "Name": ["Target"]}

    # Add one tooltip column per "<name>value" property token.
    for token in prediction.split("<")[1:]:
        parts = token.split(">")
        if len(parts) < 2:
            logger.warning("Skipping malformed property token %r", token)
            continue
        name, value = parts[0], parts[1]
        result[name] = f"{name.capitalize()} = {value}"

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=900,
        n_cols=1,
        name="Results",
        size=(600, 700),
    )
    return obj.data