File size: 3,257 Bytes
2956799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
        ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
        Johannes Kepler University Linz
Contact: seidl@ml.jku.at

File contains functions that
"""

from . import model
import torch
import os

MODEL_PATH = 'data/model/'

def smarts2svg(smarts, useSmiles=True, highlightByReactant=True, save_to=''):
    """
    draws smiles of smarts to an SVG and displays it in the Notebook,
    or optinally can be saved to a file `save_to`
    adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5
    """
    # adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5
    from rdkit import RDConfig
    from rdkit import Chem
    from rdkit.Chem import Draw, AllChem
    from rdkit.Chem.Draw import rdMolDraw2D
    from rdkit import Geometry
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib
    from IPython.display import SVG, display

    rxn = AllChem.ReactionFromSmarts(smarts,useSmiles=useSmiles)
    d = Draw.MolDraw2DSVG(900, 100)

    # rxn = AllChem.ReactionFromSmarts('[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][NH2:6]>CC(O)C.[Pt]>[CH3:1][C:2](=[O:3])[NH:6][CH3:5].[OH2:4]',useSmiles=True)
    colors=[(0.3, 0.7, 0.9),(0.9, 0.7, 0.9),(0.6,0.9,0.3),(0.9,0.9,0.1)]
    try:
        d.DrawReaction(rxn,highlightByReactant=highlightByReactant)
        d.FinishDrawing()

        txt = d.GetDrawingText()
        # self.assertTrue(txt.find("<svg") != -1)
        # self.assertTrue(txt.find("</svg>") != -1)

        svg = d.GetDrawingText()
        svg2 = svg.replace('svg:','')
        svg3 = SVG(svg2)
        display(svg3)

        if save_to!='':
            with open(save_to, 'w') as f_handle:
                f_handle.write(svg3.data)
    except:
        print('Error drawing')

    return svg2

def list_models(model_path=MODEL_PATH):
    """returns a list of loadable models"""
    return dict(enumerate(list(filter(lambda k: str(k)[-3:]=='.pt', os.listdir(model_path)))))

def load_clf(model_fn='', model_path=MODEL_PATH, device='cpu', model_type='mhn'):
    """ returns the model with loaded weights given a filename"""
    import json
    config_fn = '_'.join(model_fn.split('_')[-2:]).split('.pt')[0]
    conf_dict = json.load( open( f"{model_path}{config_fn}_config.json" ) )
    train_conf_dict = json.load( open( f"{model_path}{config_fn}_config.json" ) )

    # specify the config the saved model had
    conf = model.ModelConfig(**conf_dict)
    conf.device = device
    print(conf.__dict__)

    if model_type == 'staticQK':
        clf = model.StaticQK(conf)
    elif model_type == 'mhn':
        clf = model.MHN(conf)
    elif model_type == 'segler':
        clf = model.SeglerBaseline(conf)
    elif model_type == 'fortunato':
        clf = model.SeglerBaseline(conf)
    else:
        raise NotImplementedError('model_type',model_type,'not found')

    # load the model
    PATH = model_path+model_fn
    params = torch.load(PATH, map_location=torch.device('cpu')) #!!!
    clf.load_state_dict(params, strict=False)
    if 'templates+noise' in params.keys():
        print('loading templates+noise')
        clf.templates = params['templates+noise']
        #clf.templates.to(clf.config.device)
    return clf