lchemme / app.py
Eachan Johnson
Don't use non-default model cache
cd9c1a6
"""Gradio demo for schemist."""
from typing import Iterable, List, Union
from io import TextIOWrapper
import os
os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue"
from carabiner import cast, print_err
from carabiner.pd import read_table
import gradio as gr
import nemony as nm
import numpy as np
import pandas as pd
from rdkit.Chem import Draw, Mol
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedModel
import schemist as sch
from schemist.converting import (
_TO_FUNCTIONS,
_FROM_FUNCTIONS,
convert_string_representation,
_x2mol,
)
from schemist.tables import converter
# Hugging Face Hub checkpoints offered in the UI; MODELS[0] is the default
# selection used by ``convert_one`` and the checkbox group.
MODELS = (
    "scbirlab/lchemme-base-zinc22-lteq300",
    "scbirlab/lchemme-base-dosedo-lteq300",
    "facebook/bart-base",
)
# Map each checkpoint name to its ``(tokenizer, model)`` pair. Everything is
# loaded eagerly, once, when this module is imported.
models = dict(
    (
        checkpoint,
        (
            AutoTokenizer.from_pretrained(checkpoint),
            AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        ),
    )
    for checkpoint in MODELS
)
def load_input_data(file: TextIOWrapper) -> tuple:
    """Load an uploaded table and offer its non-numeric columns for selection.

    Parameters
    ----------
    file : TextIOWrapper
        Uploaded file object. Only its ``.name`` attribute is used, as the
        path handed to ``carabiner.pd.read_table``.
        NOTE(review): Gradio upload objects may not literally be
        ``TextIOWrapper`` instances; anything with a readable ``.name`` path
        works here.

    Returns
    -------
    tuple
        ``(gr.Dataframe, gr.Dropdown)``. The first element is a visible
        dataframe component holding the loaded table. The second is an
        interactive dropdown whose choices are the names of the table's
        non-numeric columns.
    """
    data = read_table(file.name)
    # Columns with numeric dtypes are excluded; presumably only text columns
    # can hold molecule strings.
    string_cols = list(data.select_dtypes(exclude=[np.number]))
    table = gr.Dataframe(value=data, visible=True)
    return table, gr.Dropdown(choices=string_cols, interactive=True)
def _clean_split_input(strings: str) -> List[str]:
return [s2.strip() for s in strings.split("\n") for s2 in s.split(",")]
def _convert_input(
    strings: str,
    input_representation: str = 'smiles',
    output_representation: Union[Iterable[str], str] = 'smiles'
) -> dict:
    """Convert textbox input into one or more string representations.

    The input is split with ``_clean_split_input`` and passed to
    ``schemist.converting.convert_string_representation``. Every value of the
    resulting mapping is coerced to a list of ``str``.

    Parameters
    ----------
    strings : str
        Raw text with molecules separated by newlines and/or commas.
    input_representation : str, optional
        Representation of ``strings``. Defaults to ``'smiles'``.
    output_representation : str or iterable of str, optional
        Representation(s) to produce. Defaults to ``'smiles'``.
        NOTE(review): the body calls ``.items()`` on the converter's result,
        so it assumes a mapping is returned. That is certain for a list of
        outputs (as ``draw_one`` passes) but unconfirmed for a single str.

    Returns
    -------
    dict
        Same keys as the converter's result; each value is the converted
        output as a list of ``str``. Keys are presumably the requested
        representation names; confirm against ``schemist``.
    """
    strings = _clean_split_input(strings)
    converted = convert_string_representation(
        strings=strings,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    return {
        key: list(map(str, cast(val, to=list)))
        for key, val in converted.items()
    }
def model_convert(
    df: pd.DataFrame,
    name: str,
    tokenizer,
    model: PreTrainedModel
) -> pd.DataFrame:
    """Reconstruct the strings in ``df["inputs"]`` with a seq2seq model.

    Each input is tokenized and sent through a single forward pass of
    ``model``. The arg-max token at every output position is then decoded
    back to a string, so there is no autoregressive generation.
    NOTE(review): no ``decoder_input_ids`` are passed, so the model is
    presumably deriving them from ``input_ids`` (as BART does); confirm for
    other architectures.

    Molecules are processed one at a time. Inputs with different token
    lengths therefore never have to be stacked into one unpadded tensor,
    which fails for ragged batches. The decoded strings are also converted
    to InChIKeys.

    Parameters
    ----------
    df : pd.DataFrame
        Table with an ``"inputs"`` column of input strings.
    name : str
        Checkpoint name. Its last ``/``-separated part prefixes the output
        column names.
    tokenizer
        Tokenizer paired with ``model``.
    model : PreTrainedModel
        Seq2seq model; it is switched to eval mode.

    Returns
    -------
    pd.DataFrame
        Columns ``"<basename>_smiles"`` (decoded model output) and
        ``"<basename>_inchikey"``, with one row per input in input order.
    """
    import torch

    model_basename = name.split("/")[-1]
    model.eval()
    output_smiles = []
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        for smiles in df["inputs"].tolist():
            inputs = tokenizer([smiles], return_tensors="pt")
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
            output_smiles.extend(tokenizer.batch_decode(
                outputs.logits.argmax(dim=-1),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ))
    output_inchikey = convert_string_representation(
        strings=output_smiles,
        output_representation="inchikey",
    )
    return pd.DataFrame({
        f"{model_basename}_smiles": output_smiles,
        f"{model_basename}_inchikey": output_inchikey,
    })
def convert_one(
    strings: str,
    output_representation: Union[Iterable[str], str] = MODELS[0]
):
    """Canonicalize pasted SMILES and compare against the selected models.

    ``convert_file`` first builds a reference table with ``smiles`` and
    ``inchikey`` columns. Then ``model_convert`` runs each selected
    checkpoint on the same inputs. All tables are placed side by side.

    Parameters
    ----------
    strings : str
        Raw SMILES text, separated by newlines and/or commas.
    output_representation : str or iterable of str, optional
        Checkpoint name(s), each a key of ``models``. Defaults to
        ``MODELS[0]``. A repeated name contributes only one set of columns.

    Returns
    -------
    gr.DataFrame
        A visible dataframe component holding the concatenated columns.
    """
    input_representation: str = 'smiles'
    inputs_df = pd.DataFrame({"inputs": _clean_split_input(strings)})
    reference_df = convert_file(
        df=inputs_df,
        column="inputs",
        input_representation=input_representation,
        output_representation=["smiles", "inchikey"],
    )
    # Keyed by checkpoint name so each model contributes one block of columns.
    per_model = {}
    for checkpoint in cast(output_representation, to=list):
        tokenizer, model = models[checkpoint]
        per_model[checkpoint] = model_convert(inputs_df, checkpoint, tokenizer, model)
    frames = [reference_df, *per_model.values()]
    return gr.DataFrame(pd.concat(frames, axis=1), visible=True)
def convert_file(
    df: pd.DataFrame,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    output_representation: Union[str, Iterable[str]] = 'smiles'
) -> pd.DataFrame:
    """Convert one column of a table into other chemical representations.

    Progress and a summary are reported both to stderr (``print_err``) and
    to the UI (``gr.Info``).

    Parameters
    ----------
    df : pd.DataFrame
        Input table.
    column : str, optional
        Column of ``df`` holding the strings to convert. Defaults to
        ``'smiles'``.
    input_representation : str, optional
        Representation of ``column``. Defaults to ``'smiles'``.
    output_representation : str or iterable of str, optional
        Representation name(s) to produce. These are also the names of the
        new columns. Defaults to ``'smiles'``.

    Returns
    -------
    pd.DataFrame
        The table returned by ``schemist.tables.converter``. The requested
        output columns come first, followed by all other columns in their
        existing order.
    """
    message = f"Converting from {input_representation} to {output_representation}..."
    print_err(message)
    gr.Info(message, duration=3)
    errors, df = converter(
        df=df,
        column=column,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    # Normalize to a list before the membership test. With a bare string,
    # ``col not in "smiles"`` would be a substring check and would wrongly
    # drop columns such as "s" or "mile".
    output_columns = cast(output_representation, to=list)
    df = df[output_columns + [col for col in df if col not in output_columns]]
    all_err = sum(errors.values())
    message = (
        f"Converted {df.shape[0]} molecules from "
        f"{input_representation} to {output_representation} "
        f"with {all_err} errors!"
    )
    print_err(message)
    gr.Info(message, duration=5)
    return df
def draw_one(
    strings: str
):
    """Draw the input molecules as a labelled grid image.

    Parameters
    ----------
    strings : str
        Raw SMILES text, separated by newlines and/or commas. The text is
        split with ``_clean_split_input``, so it must be a ``str``.

    Returns
    -------
    image or None
        The image returned by ``rdkit.Chem.Draw.MolsToGridImage``: up to
        three molecules per row, 300x300 px each. Each legend joins that
        molecule's converted ``"inchikey"`` and ``"id"`` values with a
        newline; the order follows the converter's output mapping. Returns
        ``None`` (clearing the image) when there are no molecules to draw.
    """
    input_representation: str = 'smiles'
    _ids = _convert_input(
        strings,
        input_representation,
        ["inchikey", "id"],
    )
    mols = cast(_x2mol(_clean_split_input(strings), input_representation), to=list)
    # Defensive: guarantee a list even if a single Mol comes back uncast.
    if isinstance(mols, Mol):
        mols = [mols]
    if not mols:
        # Nothing to draw. Also, ``molsPerRow`` would otherwise be 0.
        return None
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=min(3, len(mols)),
        subImgSize=(300, 300),
        # zip(*...) transposes {representation: [per-molecule values]}
        # into one tuple of values per molecule.
        legends=["\n".join(items) for items in zip(*_ids.values())],
    )
def download_table(
    df: pd.DataFrame
) -> gr.DownloadButton:
    """Save a table as CSV and expose it through a download button.

    The file is written without the index, as ``converted-<hash>.csv`` under
    a relative path (i.e. the current working directory). The hash is
    computed by ``nemony.hash`` over the per-row hashes from
    ``pd.util.hash_pandas_object``. Presumably, identical tables therefore
    map to the same filename.
    NOTE(review): these files are never deleted here, so they accumulate.

    Parameters
    ----------
    df : pd.DataFrame
        Table to export.

    Returns
    -------
    gr.DownloadButton
        A visible download button whose value is the CSV's path.
    """
    df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
    filename = f"converted-{df_hash}.csv"
    df.to_csv(filename, index=False)
    return gr.DownloadButton(value=filename, visible=True)
# Build the Gradio UI. Gradio Blocks lays components out in creation order,
# so statement order inside this ``with`` block is the page layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# SMILES canonicalization with LChemME
Interface to demonstrate SMILES canonicalization using Large Chemical Models pre-trained using
[LChemME](https://github.com/scbirlab/lchemme).
"""
    )
    # Free-text molecule input. Submitting it starts the event chain below.
    # The input is split on commas as well as newlines (see
    # ``_clean_split_input``), although the placeholder mentions only lines.
    input_line = gr.Textbox(
        label="Input",
        placeholder="Paste your molecule(s) here, one per line",
        lines=2,
        interactive=True,
        submit_btn=True,
    )
    # Checkpoints to run. The value is a list of MODELS entries, defaulting
    # to the first one.
    output_format_single = gr.CheckboxGroup(
        label="Use model(s):",
        choices=list(MODELS),
        value=MODELS[:1],
        interactive=True,
    )
    # Example rows fill both the input text and the model selection.
    examples = gr.Examples(
        examples=[
            ["CC(Oc1c(cccc1)C(=O)N)=O", MODELS[0]],
            ["O=S1(N([C@H](C)COC(NC[3H])=O)C[C@H]([C@@H](Oc2cc(-c3cnc(c(c3)C)OC)ccc21)CN(C)C(c1c(C)c(sc1Cl)C)=O)C)=O", MODELS[1]],
            ["CC(Oc1ccccc1C(O)=O)=O", MODELS[0]],
            ["CC(Oc1ccccc1C(O)=O)=O", MODELS[2]],
        ],
        inputs=[input_line, output_format_single],
    )
    # Hidden until ``download_table`` returns a visible button.
    download_single = gr.DownloadButton(
        label="Download converted data",
        visible=False,
    )
    # Results table. Hidden until ``convert_one`` returns a visible DataFrame.
    output_line = gr.DataFrame(
        label="Converted",
        interactive=False,
        visible=False,
    )
    drawing = gr.Image(label="Chemical structures")
    # On submit: convert -> draw structures -> offer the table for download.
    # NOTE(review): ``.then`` steps run even if the previous step raised;
    # ``.success`` would skip drawing/download after a failed conversion.
    gr.on(
        [
            input_line.submit,
        ],
        fn=convert_one,
        inputs=[
            input_line,
            output_format_single,
        ],
        # A one-element set; ``convert_one`` returns the component directly.
        outputs={
            output_line,
        }
    ).then(
        draw_one,
        inputs=[
            input_line,
        ],
        outputs=drawing,
    ).then(
        # Reads the converted table back from ``output_line``.
        download_table,
        inputs=output_line,
        outputs=download_single
    )
if __name__ == "__main__":
    # Enable request queuing, then serve the app with a public share link.
    demo.queue().launch(share=True)