MHN-React / mhnreact /inspect.py
uragankatrrin's picture
Upload 12 files
2956799
raw history blame
No virus
3.26 kB
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at
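File contains functions that help inspect trained models: drawing reaction SMARTS as SVG, and listing and loading saved model checkpoints.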
"""
from . import model
import torch
import os
# Default directory holding saved model checkpoints (*.pt) and their
# *_config.json files; used as the default `model_path` of the functions below.
MODEL_PATH = 'data/model/'
def smarts2svg(smarts, useSmiles=True, highlightByReactant=True, save_to='',
               width=900, height=100):
    """
    Draw a reaction given as SMARTS (or reaction SMILES) to SVG, display it in
    the notebook, and optionally write it to the file `save_to`.

    Adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5

    Args:
        smarts: reaction string passed to RDKit `AllChem.ReactionFromSmarts`.
        useSmiles: forwarded to `ReactionFromSmarts`.
        highlightByReactant: forwarded to `MolDraw2DSVG.DrawReaction`.
        save_to: if non-empty, path of a file the SVG is written to.
        width: canvas width in pixels (default keeps the original 900).
        height: canvas height in pixels (default keeps the original 100).

    Returns:
        The SVG text with all `svg:` prefixes removed, or None if drawing
        failed before the SVG text was produced.

    Errors raised while parsing `smarts` are not caught. Errors while drawing,
    displaying or saving are reported with a print (best effort).
    """
    from rdkit.Chem import AllChem, Draw
    from IPython.display import SVG, display

    rxn = AllChem.ReactionFromSmarts(smarts, useSmiles=useSmiles)
    drawer = Draw.MolDraw2DSVG(width, height)

    # Initialised up front so a failure inside DrawReaction/FinishDrawing
    # returns None instead of raising UnboundLocalError at the return.
    svg_text = None
    try:
        drawer.DrawReaction(rxn, highlightByReactant=highlightByReactant)
        drawer.FinishDrawing()
        # Remove `svg:` prefixes from RDKit's output before wrapping it for IPython.
        svg_text = drawer.GetDrawingText().replace('svg:', '')
        svg_obj = SVG(svg_text)
        display(svg_obj)
        if save_to != '':
            with open(save_to, 'w', encoding='utf-8') as f_handle:
                f_handle.write(svg_obj.data)
    except Exception as err:
        # Best effort: report the failure but do not interrupt the caller.
        # Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
        print(f'Error drawing: {err}')
    return svg_text
def list_models(model_path=MODEL_PATH):
    """Return the loadable model checkpoints found in `model_path`.

    The result is a dict mapping a running index (0, 1, ...) to each filename
    in `model_path` that ends with ``.pt``. Such a filename can be passed as
    `model_fn` to `load_clf`.

    The filenames are sorted first, so the index numbering is reproducible.
    ``os.listdir`` alone returns entries in arbitrary order.
    """
    checkpoints = sorted(
        fname for fname in os.listdir(model_path) if str(fname).endswith('.pt')
    )
    return dict(enumerate(checkpoints))
def load_clf(model_fn='', model_path=MODEL_PATH, device='cpu', model_type='mhn'):
    """Build a classifier of `model_type` and load saved weights into it.

    The config name comes from the last two underscore-separated parts of
    `model_fn`, with the ``.pt`` suffix stripped. For example, for
    ``<prefix>_<a>_<b>.pt`` this reads ``<model_path>/<a>_<b>_config.json``.

    Args:
        model_fn: checkpoint filename inside `model_path`.
        model_path: directory holding the checkpoint and its ``*_config.json``.
        device: stored as ``conf.device`` on the config before the model is built.
        model_type: one of 'staticQK', 'mhn', 'segler', 'fortunato'.

    Returns:
        The constructed model with the checkpoint loaded via
        ``load_state_dict(..., strict=False)``. If the checkpoint contains a
        ``'templates+noise'`` entry, it is also assigned to ``clf.templates``.

    Raises:
        NotImplementedError: if `model_type` is not one of the values above.
    """
    import json

    config_fn = '_'.join(model_fn.split('_')[-2:]).split('.pt')[0]
    config_path = os.path.join(model_path, f"{config_fn}_config.json")
    with open(config_path) as config_file:
        conf_dict = json.load(config_file)

    # Rebuild the config the saved model was trained with.
    conf = model.ModelConfig(**conf_dict)
    conf.device = device
    print(conf.__dict__)

    if model_type == 'staticQK':
        clf = model.StaticQK(conf)
    elif model_type == 'mhn':
        clf = model.MHN(conf)
    elif model_type in ('segler', 'fortunato'):
        # NOTE(review): 'fortunato' builds the same SeglerBaseline as 'segler';
        # presumably the variants differ only in training. Confirm.
        clf = model.SeglerBaseline(conf)
    else:
        raise NotImplementedError(f"model_type {model_type!r} not found")

    checkpoint_path = os.path.join(model_path, model_fn)
    # Tensors are always mapped to CPU here, whatever `device` is.
    # torch.load unpickles the file, so only load trusted checkpoints.
    params = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    clf.load_state_dict(params, strict=False)
    if 'templates+noise' in params:
        print('loading templates+noise')
        # NOTE(review): assigned as loaded (CPU-mapped); not moved to `device`.
        clf.templates = params['templates+noise']
    return clf