"""Gradio demo for MHN-react single-step retrosynthesis template prediction."""
import io
import os
import pickle

import gradio as gr
from mhnreact.inspect import list_models, load_clf
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D
from PIL import Image, ImageDraw, ImageFont
from ssretro_template import ssretro, ssretro_custom

# Where custom templates are pickled. NOTE(review): presumably read back by
# ``ssretro_custom`` -- confirm in ssretro_template.
CUSTOM_TEMPLATE_FILE = "saved_dictionary.pkl"
# Directory that receives one captioned PNG per prediction.
OUTPUT_DIR = "outputs"
# TrueType font used for the caption drawn above each reaction image.
FONT_PATH = "tools/arial.ttf"


def custom_template_file(template: str) -> dict:
    """Parse comma-separated templates, persist them, and return them.

    Each entry of *template* is whitespace-stripped; empty entries (e.g. from
    a trailing comma) are dropped. Surviving entries are keyed by their
    0-based position, pickled to ``CUSTOM_TEMPLATE_FILE`` (overwriting any
    previously saved templates) and returned so the UI can display them.
    """
    entries = (entry.strip() for entry in (template or "").split(","))
    template_dict = dict(enumerate(entry for entry in entries if entry))
    with open(CUSTOM_TEMPLATE_FILE, "wb") as f:
        pickle.dump(template_dict, f)
    return template_dict


def get_output(p: str) -> bytes:
    """Render reaction SMARTS *p* on an 800x200 Cairo canvas; return PNG bytes."""
    rxn = Reaction.ReactionFromSmarts(p, useSmiles=False)
    drawer = rdMolDraw2D.MolDraw2DCairo(800, 200)
    drawer.DrawReaction(rxn, highlightByReactant=False)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


def ssretro_prediction(molecule, custom_template=False):
    """Run single-step retrosynthesis on *molecule* with the first listed model.

    Args:
        molecule: input molecule as a SMILES string.
        custom_template: if truthy, use ``ssretro_custom`` (user templates)
            instead of the model's own templates via ``ssretro``.

    Returns:
        ``(predict, txt)``: per prediction, the rendered reaction image (PNG
        bytes from ``get_output``) and a one-line caption.
    """
    retro_clf = load_clf(list_models()[0])
    run = ssretro_custom if custom_template else ssretro

    predict, txt = [], []
    for pred in run(molecule, retro_clf):
        # NOTE(review): "- 1" assumes template_rank is 1-based offset by one
        # relative to the displayed rank -- confirm against ssretro_template.
        txt.append(
            f'predicted top-{pred["template_rank"] - 1}, '
            f'template index: {pred["template_idx"]}, '
            f'prob: {pred["prob"]: 2.1f}%;'
        )
        predict.append(get_output(pred["reaction"]))
    return predict, txt


def _add_caption(png_bytes, caption, font):
    """Return the PNG in *png_bytes* on a white margin with *caption* drawn above it."""
    left, right, top, bottom = 10, 10, 50, 1
    with Image.open(io.BytesIO(png_bytes)) as img:
        width, height = img.size
        result = Image.new(
            img.mode,
            (width + left + right, height + top + bottom),
            (255, 255, 255),
        )
        result.paste(img, (left, top))
    ImageDraw.Draw(result).text((20, 20), caption, font=font, fill=(0, 0, 0))
    return result


def mhn_react_backend(mol, use_custom: bool):
    """Predict templates for *mol* and return one captioned image per prediction.

    Each image is also saved as ``OUTPUT_DIR/NNN.png`` (NNN = zero-padded
    prediction index). Returns an empty list when nothing was predicted.
    """
    predictions, comments = ssretro_prediction(mol, use_custom)
    if not predictions:
        return []

    # The output directory may not exist on a fresh checkout/deployment.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    font = ImageFont.truetype(FONT_PATH, 20)  # loop-invariant: load once

    images = []
    for i, (png_bytes, caption) in enumerate(zip(predictions, comments)):
        image = _add_caption(png_bytes, caption, font)
        image.save(os.path.join(OUTPUT_DIR, f"{i:03d}.png"))
        images.append(image)
    return images


with gr.Blocks() as demo:
    gr.Markdown(
        """
        [![Github](https://img.shields.io/badge/github-%20mhn--react-blue)](https://github.com/ml-jku/mhn-react)
        [![arXiv](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065)
        [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)
        ### MHN-react
        Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities, molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
        """
    )

    with gr.Accordion("Information"):
        # Explicit newlines: the SMILES examples must render as separate items,
        # and the closing sentence must not be glued onto the last SMILES.
        gr.Markdown(
            "use one of example molecules:\n\n"
            "- `CC(=O)NCCC1=CNc2c1cc(OC)cc2`\n"
            "- `CN1CCC[C@H]1c2cccnc2`\n"
            "- `OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N`\n\n"
            "In case the output is empty, no applicable templates were found"
        )

    with gr.Tab("Generate Templates"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(
                    placeholder="Input molecule in SMILES format",
                    label="input molecule",
                )
                radio = gr.Radio([False, True], label="use custom templates")
                btn = gr.Button(value="Generate")
            with gr.Column(scale=2):
                out = gr.Gallery(label="retro-synthesis")
        btn.click(mhn_react_backend, [inp, radio], out)

    with gr.Tab("Create custom templates"):
        gr.Markdown(
            """
            Input the templates separated by comma.
            Please do not upload templates one-by-one
            """
        )
        with gr.Column():
            inp_t = gr.Textbox(
                placeholder="custom template", label="add custom template(s)"
            )
            btn_t = gr.Button(value="upload")
            out_t = gr.Textbox(label="added templates")
            btn_t.click(custom_template_file, inp_t, out_t)

demo.launch(debug=True)