File size: 1,515 Bytes
6bd991c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import functools
import io
import os

import gradio as gr
from mhnreact.inspect import list_models, load_clf
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D
from PIL import Image, ImageDraw
from ssretro_template import ssretro


def get_output(p):
    """Render a reaction string to PNG image bytes.

    Args:
        p: reaction string, parsed by ``ReactionFromSmarts`` with
            ``useSmiles=False`` (i.e. interpreted as SMARTS).

    Returns:
        The drawing text of an 800x200 Cairo drawer, i.e. the encoded
        image bytes (the caller writes them to a ``.png`` file).
    """
    reaction = Reaction.ReactionFromSmarts(p, useSmiles=False)

    drawer = rdMolDraw2D.MolDraw2DCairo(800, 200)
    drawer.DrawReaction(reaction, highlightByReactant=False)
    drawer.FinishDrawing()

    png_bytes = drawer.GetDrawingText()
    return png_bytes

@functools.lru_cache(maxsize=1)
def _load_retro_clf():
    """Load the first model reported by ``list_models()``, once per process.

    Loading the classifier does not depend on the request, so it is memoized
    instead of being repeated on every call. A failed load (e.g. no models
    available) raises and is not cached, so it is retried on the next call.
    """
    return load_clf(list_models()[0])


def ssretro_prediction(molecule):
    """Run single-step retrosynthesis on *molecule* and render each result.

    Args:
        molecule: molecule string forwarded unchanged to ``ssretro``
            (presumably SMILES -- TODO confirm).

    Returns:
        A ``(predict, txt)`` tuple of equal-length lists. ``predict`` holds
        the image bytes from ``get_output`` for each predicted reaction, and
        ``txt`` holds the matching caption
        ``"predicted top-<rank>, prob: <prob>%; <reaction>"``.
    """
    retro_clf = _load_retro_clf()

    outputs = ssretro(molecule, retro_clf)
    predict, txt = [], []
    for pred in outputs:
        # NOTE(review): the caption shows template_rank - 1; confirm whether
        # template_rank is 1-based and whether "prob" is already a percentage.
        txt.append(f'predicted top-{pred["template_rank"]-1}, prob: {pred["prob"]:2.1f}%; {pred["reaction"]}')
        predict.append(get_output(pred["reaction"]))

    return predict, txt


def mhn_react_backend(mol):
    """Gradio handler: predict reactions for *mol* and return captioned images.

    Each prediction image is captioned with its text description, saved as
    ``outputs/NNN.png`` (overwriting files from earlier calls) and returned.

    Args:
        mol: molecule string from the UI text box, forwarded to
            ``ssretro_prediction``.

    Returns:
        A list of annotated PIL images, one per prediction, for the gallery.
    """
    output_dir = "outputs"
    # The directory is not guaranteed to exist (e.g. on a fresh deployment);
    # without this the first save raises FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    predictions, comments = ssretro_prediction(mol)

    images = []
    for i, (png_bytes, comment) in enumerate(zip(predictions, comments)):
        # Decode straight from memory rather than writing the raw bytes to
        # disk and re-reading them; only the annotated image is written.
        img = Image.open(io.BytesIO(png_bytes))
        ImageDraw.Draw(img).text((20, 10), comment, fill=(30, 0, 44))

        img.save(os.path.join(output_dir, f"{i:03d}.png"))
        images.append(img)

    return images


# Web UI: one text input (the molecule string handed to mhn_react_backend)
# and a gallery showing the annotated prediction images it returns.
demo = gr.Interface(
    fn=mhn_react_backend,
    inputs="text",
    outputs="gallery",
)
# Starts the server as soon as this module runs.
demo.launch()