|
import functools
import io
from pathlib import Path

import gradio as gr
from mhnreact.inspect import list_models, load_clf
from PIL import Image, ImageDraw
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D

from ssretro_template import ssretro
|
|
|
|
|
def get_output(p):
    """Render a reaction given as a SMARTS string to an 800x200 image.

    Args:
        p: Reaction SMARTS, parsed with ``useSmiles=False``.

    Returns:
        The drawing produced by RDKit's Cairo drawer, i.e. the raw output of
        ``MolDraw2DCairo.GetDrawingText()`` (PNG data; the caller writes it
        to a ``.png`` file in binary mode).
    """
    reaction = Reaction.ReactionFromSmarts(p, useSmiles=False)

    canvas = rdMolDraw2D.MolDraw2DCairo(800, 200)
    # Plain rendering: no per-reactant colour highlighting.
    canvas.DrawReaction(reaction, highlightByReactant=False)
    canvas.FinishDrawing()

    return canvas.GetDrawingText()
|
|
|
@functools.lru_cache(maxsize=1)
def _load_retro_clf():
    """Load the first model returned by ``list_models()`` and cache it for reuse."""
    return load_clf(list_models()[0])


def ssretro_prediction(molecule):
    """Run single-step retrosynthesis on *molecule* and render each prediction.

    The classifier is loaded once (on first call) and reused afterwards,
    instead of being reloaded from disk on every request.

    Args:
        molecule: Input passed unchanged to ``ssretro`` (presumably a product
            SMILES string from the UI textbox — confirm against ssretro).

    Returns:
        A pair ``(predict, txt)`` of equal-length lists in prediction order:
        ``predict`` holds the rendered reaction images from ``get_output`` and
        ``txt`` holds a one-line summary per prediction.

    Raises:
        IndexError: if ``list_models()`` returns no models.
    """
    retro_clf = _load_retro_clf()

    outputs = ssretro(molecule, retro_clf)

    predict, txt = [], []
    for pred in outputs:
        # NOTE(review): the label shows template_rank - 1 and formats prob
        # with a literal '%'; confirm whether template_rank is 1-based and
        # whether prob is already a percentage in ssretro's output.
        txt.append(f'predicted top-{pred["template_rank"]-1}, prob: {pred["prob"]:2.1f}%; {pred["reaction"]}')
        predict.append(get_output(pred["reaction"]))

    return predict, txt
|
|
|
|
|
def mhn_react_backend(mol):
    """Gradio callback: predict retrosynthesis for *mol* and return captioned images.

    Each predicted reaction image is captioned with its summary line. It is
    saved as ``outputs/000.png``, ``outputs/001.png``, ... and also returned
    for display in the gallery.

    Args:
        mol: Text from the Gradio textbox, forwarded to ``ssretro_prediction``.

    Returns:
        list of ``PIL.Image.Image``, one per prediction, in prediction order.
    """
    output_dir = Path("outputs")
    # Create the target directory so the first request on a fresh checkout
    # does not fail with FileNotFoundError.
    output_dir.mkdir(parents=True, exist_ok=True)

    predictions, comments = ssretro_prediction(mol)

    images = []
    for i, (png_bytes, comment) in enumerate(zip(predictions, comments)):
        # Decode the rendered PNG straight from memory instead of writing it
        # to disk and reading it back before annotating.
        img = Image.open(io.BytesIO(png_bytes))
        ImageDraw.Draw(img).text((20, 10), comment, fill=(30, 0, 44))
        img.save(output_dir / f"{i:03d}.png")
        images.append(img)

    return images
|
|
|
|
|
# Text in (molecule), gallery of captioned reaction images out.
demo = gr.Interface(fn=mhn_react_backend, inputs="text", outputs="gallery")

# Launch only when run as a script, so importing this module (e.g. by
# tooling that looks up `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|