"""Gradio demo for schemist.""" from typing import Iterable, List, Union from io import TextIOWrapper import os os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue" from carabiner import cast, print_err from carabiner.pd import read_table import gradio as gr import nemony as nm import numpy as np import pandas as pd from rdkit.Chem import Draw, Mol from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedModel import schemist as sch from schemist.converting import ( _TO_FUNCTIONS, _FROM_FUNCTIONS, convert_string_representation, _x2mol, ) from schemist.tables import converter MODELS = ( "scbirlab/lchemme-base-zinc22-lteq300", "scbirlab/lchemme-base-dosedo-lteq300", "facebook/bart-base", ) models = {model_name: ( AutoTokenizer.from_pretrained(model_name), AutoModelForSeq2SeqLM.from_pretrained(model_name), ) for model_name in MODELS} def load_input_data(file: TextIOWrapper) -> pd.DataFrame: df = read_table(file.name) string_cols = list(df.select_dtypes(exclude=[np.number])) df = gr.Dataframe(value=df, visible=True) return df, gr.Dropdown(choices=string_cols, interactive=True) def _clean_split_input(strings: str) -> List[str]: return [s2.strip() for s in strings.split("\n") for s2 in s.split(",")] def _convert_input( strings: str, input_representation: str = 'smiles', output_representation: Union[Iterable[str], str] = 'smiles' ) -> List[str]: strings = _clean_split_input(strings) converted = convert_string_representation( strings=strings, input_representation=input_representation, output_representation=output_representation, ) return { key: list(map(str, cast(val, to=list))) for key, val in converted.items() } def model_convert( df: pd.DataFrame, name: str, tokenizer, model: PreTrainedModel ) -> pd.DataFrame: model_basename = name.split("/")[-1] inputs = tokenizer(df["inputs"].tolist(), return_tensors="pt") model.eval() model_args = {key: inputs[key] for key in ['input_ids', 'attention_mask']} outputs = model( **model_args, # decoder_input_ids=model_args['input_ids'], ) output_smiles = tokenizer.batch_decode( outputs.logits.argmax(dim=-1), skip_special_tokens=True, clean_up_tokenization_spaces=True, ) output_inchikey = convert_string_representation( strings=output_smiles, output_representation="inchikey", ) return pd.DataFrame({ f"{model_basename}_smiles": output_smiles, f"{model_basename}_inchikey": output_inchikey, }) def convert_one( strings: str, output_representation: Union[Iterable[str], str] = MODELS[0] ): input_representation: str = 'smiles' df = pd.DataFrame({ "inputs": _clean_split_input(strings), }) true_canonical_df = convert_file( df=df, column="inputs", input_representation=input_representation, output_representation=["smiles", "inchikey"] ) output_representation = cast(output_representation, to=list) model_canonical_dfs = { model_name: model_convert(df, model_name, *models[model_name]) for model_name in output_representation } return gr.DataFrame( pd.concat([true_canonical_df] + list(model_canonical_dfs.values()), axis=1), visible=True ) def convert_file( df: pd.DataFrame, column: str = 'smiles', input_representation: str = 'smiles', output_representation: Union[str, Iterable[str]] = 'smiles' ): message = f"Converting from {input_representation} to {output_representation}..." 
    print_err(message)
    gr.Info(message, duration=3)
    errors, df = converter(
        df=df,
        column=column,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    # Put the newly converted columns first.
    df = df[
        cast(output_representation, to=list)
        + [col for col in df if col not in output_representation]
    ]
    all_err = sum(err for key, err in errors.items())
    message = (
        f"Converted {df.shape[0]} molecules from "
        f"{input_representation} to {output_representation} "
        f"with {all_err} errors!"
    )
    print_err(message)
    gr.Info(message, duration=5)
    return df


def draw_one(
    strings: Union[Iterable[str], str]
):
    """Draw pasted structures in a grid, labelled with InChIKey and ID."""
    input_representation: str = 'smiles'
    _ids = _convert_input(
        strings,
        input_representation,
        ["inchikey", "id"],
    )
    mols = cast(_x2mol(_clean_split_input(strings), input_representation), to=list)
    if isinstance(mols, Mol):
        mols = [mols]
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=min(3, len(mols)),
        subImgSize=(300, 300),
        legends=["\n".join(items) for items in zip(*_ids.values())],
    )


def download_table(
    df: pd.DataFrame
):
    """Write the converted table to CSV and reveal the download button."""
    df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
    filename = f"converted-{df_hash}.csv"
    df.to_csv(filename, index=False)
    return gr.DownloadButton(value=filename, visible=True)


with gr.Blocks() as demo:

    gr.Markdown(
        """
        # SMILES canonicalization with LChemME

        Interface to demonstrate SMILES canonicalization using Large Chemical Models
        pre-trained using [LChemME](https://github.com/scbirlab/lchemme).
        """
    )
    input_line = gr.Textbox(
        label="Input",
        placeholder="Paste your molecule(s) here, one per line",
        lines=2,
        interactive=True,
        submit_btn=True,
    )
    output_format_single = gr.CheckboxGroup(
        label="Use model(s):",
        choices=list(MODELS),
        value=MODELS[:1],
        interactive=True,
    )
    examples = gr.Examples(
        examples=[
            ["CC(Oc1c(cccc1)C(=O)N)=O", MODELS[0]],
            ["O=S1(N([C@H](C)COC(NC[3H])=O)C[C@H]([C@@H](Oc2cc(-c3cnc(c(c3)C)OC)ccc21)CN(C)C(c1c(C)c(sc1Cl)C)=O)C)=O", MODELS[1]],
            ["CC(Oc1ccccc1C(O)=O)=O", MODELS[0]],
            ["CC(Oc1ccccc1C(O)=O)=O", MODELS[2]],
        ],
        inputs=[input_line, output_format_single],
    )
    download_single = gr.DownloadButton(
        label="Download converted data",
        visible=False,
    )
    output_line = gr.DataFrame(
        label="Converted",
        interactive=False,
        visible=False,
    )
    drawing = gr.Image(label="Chemical structures")

    # On submit: convert, then draw the structures, then prepare the CSV download.
    gr.on(
        [
            input_line.submit,
        ],
        fn=convert_one,
        inputs=[
            input_line,
            output_format_single,
        ],
        outputs=[
            output_line,
        ],
    ).then(
        draw_one,
        inputs=[
            input_line,
        ],
        outputs=drawing,
    ).then(
        download_table,
        inputs=output_line,
        outputs=download_single,
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)
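
# Usage sketch (assumes this module is saved as app.py; the filename is illustrative):
#     python app.py
# Gradio prints a local URL and, because share=True, a temporary public link.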