"""Gradio demo for schemist."""

from typing import Dict, Iterable, List, Tuple, Union
from io import TextIOWrapper
import os

os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue"

from carabiner import cast, print_err
from carabiner.pd import read_table
import gradio as gr
import nemony as nm
import numpy as np
import pandas as pd
from rdkit.Chem import Draw, Mol
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedModel

import schemist as sch
from schemist.converting import (
    _TO_FUNCTIONS,
    _FROM_FUNCTIONS,
    convert_string_representation,
    _x2mol,
)
from schemist.tables import converter

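# Checkpoints to compare. Each is loaded up front as a (tokenizer, model)
# pair keyed by checkpoint name, so inference is ready when the demo starts.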
MODELS = (
    "scbirlab/lchemme-base-zinc22-lteq300",
    "scbirlab/lchemme-base-dosedo-lteq300",
    "facebook/bart-base",
)

models = {
    model_name: (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSeq2SeqLM.from_pretrained(model_name),
    )
    for model_name in MODELS
}

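# Read an uploaded table, reveal it in the UI, and offer its non-numeric
# columns as candidates for the structure column.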
def load_input_data(file: TextIOWrapper) -> Tuple[gr.Dataframe, gr.Dropdown]:
    df = read_table(file.name)
    string_cols = list(df.select_dtypes(exclude=[np.number]))
    df = gr.Dataframe(value=df, visible=True)
    return df, gr.Dropdown(choices=string_cols, interactive=True)

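# Split pasted text on newlines and commas, trimming surrounding whitespace.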
def _clean_split_input(strings: str) -> List[str]:
    return [s2.strip() for s in strings.split("\n") for s2 in s.split(",")]

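# Clean the pasted input and convert it to one or more representations,
# returning a dict mapping each representation to a list of strings.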
def _convert_input(
    strings: str,
    input_representation: str = 'smiles',
    output_representation: Union[Iterable[str], str] = 'smiles',
) -> Dict[str, List[str]]:
    strings = _clean_split_input(strings)
    converted = convert_string_representation(
        strings=strings,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    return {
        key: list(map(str, cast(val, to=list)))
        for key, val in converted.items()
    }

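# Run one seq2seq model over the "inputs" column, decode the greedy (argmax)
# output as SMILES, and derive InChIKeys from it. Output columns are prefixed
# with the model's basename so different models can be compared.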
def model_convert(
    df: pd.DataFrame,
    name: str,
    tokenizer,
    model: PreTrainedModel,
) -> pd.DataFrame:
    model_basename = name.split("/")[-1]
    # Pad so that inputs of different lengths can be batched into one tensor.
    inputs = tokenizer(
        df["inputs"].tolist(),
        return_tensors="pt",
        padding=True,
    )
    model.eval()
    model_args = {key: inputs[key] for key in ['input_ids', 'attention_mask']}
    with torch.no_grad():
        outputs = model(**model_args)
    output_smiles = tokenizer.batch_decode(
        outputs.logits.argmax(dim=-1),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    output_inchikey = convert_string_representation(
        strings=output_smiles,
        output_representation="inchikey",
    )
    return pd.DataFrame({
        f"{model_basename}_smiles": output_smiles,
        f"{model_basename}_inchikey": output_inchikey,
    })

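# Handle a textbox submission: compute reference SMILES/InChIKey with schemist,
# then append the SMILES/InChIKey predicted by each selected model for
# side-by-side comparison.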
def convert_one(
    strings: str,
    model_names: Union[Iterable[str], str] = MODELS[0],
):
    input_representation: str = 'smiles'
    df = pd.DataFrame({
        "inputs": _clean_split_input(strings),
    })

    true_canonical_df = convert_file(
        df=df,
        column="inputs",
        input_representation=input_representation,
        output_representation=["smiles", "inchikey"],
    )

    model_names = cast(model_names, to=list)
    model_canonical_dfs = {
        model_name: model_convert(df, model_name, *models[model_name])
        for model_name in model_names
    }

    return gr.DataFrame(
        pd.concat([true_canonical_df] + list(model_canonical_dfs.values()), axis=1),
        visible=True,
    )

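# Convert a table column between string representations using schemist's
# `converter`, reporting progress and the error count in the UI.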
def convert_file(
    df: pd.DataFrame,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    output_representation: Union[str, Iterable[str]] = 'smiles',
):
    message = f"Converting from {input_representation} to {output_representation}..."
    print_err(message)
    gr.Info(message, duration=3)
    errors, df = converter(
        df=df,
        column=column,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    # Put the newly generated columns first; compare against the column list
    # (not the raw argument) so the membership check is correct when a single
    # representation is passed as a string.
    output_cols = cast(output_representation, to=list)
    df = df[
        output_cols +
        [col for col in df if col not in output_cols]
    ]
    all_err = sum(errors.values())
    message = (
        f"Converted {df.shape[0]} molecules from "
        f"{input_representation} to {output_representation} "
        f"with {all_err} errors!"
    )
    print_err(message)
    gr.Info(message, duration=5)
    return df

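# Render the pasted structures as an image grid, legended with their
# InChIKeys and IDs.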
def draw_one(
    strings: Union[Iterable[str], str],
):
    input_representation: str = 'smiles'
    _ids = _convert_input(
        strings,
        input_representation,
        ["inchikey", "id"],
    )
    mols = cast(_x2mol(_clean_split_input(strings), input_representation), to=list)
    if isinstance(mols, Mol):
        mols = [mols]
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=min(3, len(mols)),
        subImgSize=(300, 300),
        legends=["\n".join(items) for items in zip(*_ids.values())],
    )

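# Write the converted table to a CSV named after a hash of its contents and
# expose it via the download button.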
def download_table(
    df: pd.DataFrame,
) -> gr.DownloadButton:
    df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
    filename = f"converted-{df_hash}.csv"
    df.to_csv(filename, index=False)
    return gr.DownloadButton(value=filename, visible=True)

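# Build the Gradio interface: a textbox for pasting structures, a checkbox
# group for selecting models, example inputs, and outputs for the converted
# table, the structure drawing, and a CSV download.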
with gr.Blocks() as demo:

    gr.Markdown(
        """
        # SMILES canonicalization with LChemME

        Demonstration of SMILES canonicalization using Large Chemical Models
        pre-trained with [LChemME](https://github.com/scbirlab/lchemme).
        """
    )

    input_line = gr.Textbox(
        label="Input",
        placeholder="Paste your molecule(s) here, one per line",
        lines=2,
        interactive=True,
        submit_btn=True,
    )
    output_format_single = gr.CheckboxGroup(
        label="Use model(s):",
        choices=list(MODELS),
        value=MODELS[:1],
        interactive=True,
    )
    examples = gr.Examples(
        examples=[
            ["CC(Oc1c(cccc1)C(=O)N)=O", MODELS[0]],
            ["O=S1(N([C@H](C)COC(NC[3H])=O)C[C@H]([C@@H](Oc2cc(-c3cnc(c(c3)C)OC)ccc21)CN(C)C(c1c(C)c(sc1Cl)C)=O)C)=O", MODELS[1]],
            ["CC(Oc1ccccc1C(O)=O)=O", MODELS[0]],
            ["CC(Oc1ccccc1C(O)=O)=O", MODELS[2]],
        ],
        inputs=[input_line, output_format_single],
    )
    download_single = gr.DownloadButton(
        label="Download converted data",
        visible=False,
    )

    output_line = gr.DataFrame(
        label="Converted",
        interactive=False,
        visible=False,
    )
    drawing = gr.Image(label="Chemical structures")

    # On submit: convert the input, then draw the structures, then prepare
    # the CSV download.
    gr.on(
        [
            input_line.submit,
        ],
        fn=convert_one,
        inputs=[
            input_line,
            output_format_single,
        ],
        outputs=[
            output_line,
        ],
    ).then(
        draw_one,
        inputs=[
            input_line,
        ],
        outputs=drawing,
    ).then(
        download_table,
        inputs=output_line,
        outputs=download_single,
    )

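# Enable request queueing and launch with a public share link when run
# as a script.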
if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)