"""Gradio demo showing MHNreact ``ssretro`` template predictions for a molecule.

The text entered in the UI is passed to ``ssretro`` together with the first
model returned by ``mhnreact.inspect.list_models``; every predicted reaction is
rendered with RDKit, captioned, written to ``outputs/NNN.png`` and displayed
in a gallery.
"""

import functools
import io
from pathlib import Path

import gradio as gr
from mhnreact.inspect import list_models, load_clf
from PIL import Image, ImageDraw
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D

from ssretro_template import ssretro

# Directory the captioned prediction images are written to.
OUTPUT_DIR = "outputs"


def get_output(p):
    """Render a reaction string to an 800x200 PNG image.

    Args:
        p: Reaction string, parsed by ``ReactionFromSmarts`` with
            ``useSmiles=False`` (i.e. interpreted as SMARTS).

    Returns:
        The PNG bytes produced by RDKit's Cairo drawer.
    """
    rxn = Reaction.ReactionFromSmarts(p, useSmiles=False)
    drawer = rdMolDraw2D.MolDraw2DCairo(800, 200)
    drawer.DrawReaction(rxn, highlightByReactant=False)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


@functools.lru_cache(maxsize=1)
def _load_retro_clf():
    """Load the classifier for the first listed model once and reuse it."""
    # Loading the model is expensive; cache it for the life of the process
    # instead of reloading it on every request.
    # NOTE(review): assumes the classifier is only used for inference and is
    # safe to share between requests -- confirm ssretro does not mutate it.
    return load_clf(list_models()[0])


def ssretro_prediction(molecule):
    """Run ``ssretro`` on *molecule* and render each predicted reaction.

    Args:
        molecule: The text entered in the UI, passed unchanged to ``ssretro``.

    Returns:
        ``(images, captions)``: parallel lists holding, per prediction, the
        rendered PNG bytes and a one-line caption (rank, probability, reaction).
    """
    outputs = ssretro(molecule, _load_retro_clf())

    images, captions = [], []
    # Single pass over ``outputs`` so this also works if ssretro yields lazily.
    for pred in outputs:
        captions.append(f'predicted top-{pred["template_rank"]-1}, prob: {pred["prob"]:2.1f}%; {pred["reaction"]}')
        images.append(get_output(pred["reaction"]))
    return images, captions


def mhn_react_backend(mol):
    """Gradio callback: predict reactions for *mol* and return captioned images.

    Each prediction image gets its caption drawn in the top-left corner, is
    saved as ``outputs/NNN.png`` (zero-padded prediction index) and is returned
    as a PIL image for the gallery component.

    Args:
        mol: The text entered in the UI.

    Returns:
        A list of PIL images, one per prediction, in prediction order.
    """
    predictions, comments = ssretro_prediction(mol)

    output_dir = Path(OUTPUT_DIR)
    # Writing into a missing directory raises FileNotFoundError; create it.
    output_dir.mkdir(parents=True, exist_ok=True)

    images = []
    for i, (png_bytes, comment) in enumerate(zip(predictions, comments)):
        # Decode straight from memory: re-reading a file at a shared path lets
        # a concurrent request overwrite it between our write and our read.
        img = Image.open(io.BytesIO(png_bytes))
        ImageDraw.Draw(img).text((20, 10), comment, fill=(30, 0, 44))
        img.save(output_dir / f"{i:03d}.png")
        images.append(img)
    return images


# Build the web UI and start serving it (runs when this module is executed).
demo = gr.Interface(fn=mhn_react_backend, inputs="text", outputs="gallery")
demo.launch()