|
|
from modules import *
|
|
|
import os, sys
|
|
|
import numpy as np
|
|
|
from tqdm import tqdm
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
from config import CFG
|
|
|
import utils
|
|
|
import json
|
|
|
import pandas as pd
|
|
|
import pickle
|
|
|
from rdkit import Chem
|
|
|
from rdkit.Chem import inchi
|
|
|
|
|
|
def smiles_to_inchikey(smiles, nostereo=True):
    """Convert a SMILES string to an InChIKey.

    Args:
        smiles: molecule as a SMILES string.
        nostereo: when True, strip stereochemistry via the "-SNon"
            InChI option before generating the key.

    Returns:
        The InChIKey string, or None when parsing or conversion fails.
    """
    try:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return None

        if nostereo:
            inchi_str = inchi.MolToInchi(molecule, options="-SNon")
        else:
            inchi_str = inchi.MolToInchi(molecule)

        # MolToInchi may return an empty string on failure.
        return inchi.InchiToInchiKey(inchi_str) if inchi_str else None
    except Exception as e:
        print(f"转换失败: {e}")  # (message text preserved: "conversion failed")
        return None
|
|
|
|
|
|
def calc_mol_embeddings(model, smis, cfg):
    """Encode a batch of candidate molecules into embedding vectors.

    Args:
        model: trained model exposing ``mol_gnn_encoder`` and
            ``mol_projection``.
        smis: iterable of candidate records where ``item[1]`` is the
            SMILES string; the full record is kept for every molecule
            that featurizes successfully.
        cfg: config with ``mol_encoder`` (contains 'gnn'/'fp'/'fm'),
            ``fptype``, ``mol_embedding_dim`` and ``device``.

    Returns:
        (mol_embeddings, valid_smis): an (N, D) tensor of embeddings and
        the list of input records that could be featurized, in order.
    """
    model.eval()
    fp_featsl, gnn_featsl, fm_featsl = [], [], []
    valid_smis = []

    use_gnn = 'gnn' in cfg.mol_encoder
    use_fp = 'fp' in cfg.mol_encoder
    use_fm = 'fm' in cfg.mol_encoder

    for smil in smis:
        smi = smil[1]
        try:
            # Compute every requested featurization first, then append
            # atomically: in the original, a failure in a later encoder
            # left earlier per-encoder lists out of sync with valid_smis.
            gnn_feats = utils.mol_graph_featurizer(smi) if use_gnn else None
            fp_feats = (utils.mol_fp_encoder(smi, tp=cfg.fptype,
                                             nbits=cfg.mol_embedding_dim).to(cfg.device)
                        if use_fp else None)
            fm_feats = utils.smi2fmvec(smi).to(cfg.device) if use_fm else None
        except Exception as e:
            # Best-effort: skip molecules that fail to featurize.
            print(smi, e)
            continue
        if use_gnn:
            gnn_featsl.append(gnn_feats)
        if use_fp:
            fp_featsl.append(fp_feats)
        if use_fm:
            fm_featsl.append(fm_feats)
        valid_smis.append(smil)

    mol_feat_list = []

    # Pure inference: run the encoders under no_grad too (the original
    # wrapped only the final projection, so the encoder passes built an
    # unused autograd graph and wasted memory).
    with torch.no_grad():
        if use_gnn:
            vl = [b['V'] for b in gnn_featsl if 'V' in b]
            al = [b['A'] for b in gnn_featsl if 'A' in b]
            msl = [b['mol_size'] for b in gnn_featsl if 'mol_size' in b]

            bat = {}
            if vl and al and msl:
                # Pad node features / adjacency up to the largest molecule
                # in the batch so they can be stacked.
                max_n = max(v.shape[0] for v in vl)
                bat['V'] = torch.stack([utils.pad_V(v, max_n) for v in vl]).to(cfg.device)
                bat['A'] = torch.stack([utils.pad_A(a, max_n) for a in al]).to(cfg.device)
                bat['mol_size'] = torch.cat(msl, dim=0).to(cfg.device)

            mol_feat_list.append(model.mol_gnn_encoder(bat))
            del bat

        if use_fp:
            mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))

        if use_fm:
            mol_feat_list.append(torch.stack(fm_featsl).to(cfg.device))

        # Concatenate per-encoder features along the feature dimension.
        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1).to(cfg.device)
        else:
            mol_features = mol_feat_list[0].to(cfg.device)

        mol_embeddings = model.mol_projection(mol_features)

    del mol_features, mol_feat_list
    return mol_embeddings, valid_smis
|
|
|
|
|
|
def find_matches(model, ms, smis, cfg, n=10, batch_size=64):
    """Rank candidate molecules by embedding similarity to one spectrum.

    Args:
        model: trained model with ``ms_projection``; molecule encoders
            are applied through calc_mol_embeddings.
        ms: raw spectrum, passed to ``utils.ms_binner``.
        smis: candidate records (``item[1]`` is the SMILES string).
        cfg: config providing binning parameters and ``device``.
        n: number of matches to return; -1 (or more than the number of
            valid candidates) means return all of them.
        batch_size: molecules encoded per batch.

    Returns:
        (matchsmis, scores, topk_indices_list): top-n candidate records,
        their cosine similarities scaled to percentages (numpy array),
        and their indices into the valid-candidate list.
    """
    model.eval()
    with torch.no_grad():
        ms_features = utils.ms_binner(ms, min_mz=cfg.min_mz, max_mz=cfg.max_mz,
                                      bin_size=cfg.bin_size, add_nl=cfg.add_nl,
                                      binary_intn=cfg.binary_intn).to(cfg.device)
        # Add a batch dimension before projecting.
        ms_embeddings = model.ms_projection(ms_features.unsqueeze(0))

    all_valid_smis = []
    all_embeddings = []

    for i in tqdm(range(0, len(smis), batch_size)):
        batch_embeddings, valid_smis = calc_mol_embeddings(model, smis[i:i + batch_size], cfg)
        all_embeddings.append(batch_embeddings)
        all_valid_smis.extend(valid_smis)
        del batch_embeddings

    all_embeddings = torch.cat(all_embeddings, dim=0)

    # F.cosine_similarity L2-normalizes both inputs internally, so the
    # explicit F.normalize passes in the original were redundant.
    similarities = F.cosine_similarity(all_embeddings, ms_embeddings, dim=1)

    if n == -1 or n > len(all_valid_smis):
        n = len(all_valid_smis)

    values, topk_indices = torch.topk(similarities, n)
    topk_indices_list = topk_indices.cpu().tolist()
    matchsmis = [all_valid_smis[idx] for idx in topk_indices_list]

    # Similarities are reported as percentages.
    return matchsmis, values.cpu().numpy() * 100, topk_indices_list
|
|
|
|
|
|
def calc(models, datal, cfg):
    """Score every spectrum in ``datal`` against its candidate molecules.

    Args:
        models: list of trained models; results for the same InChIKey are
            merged (scores of shared candidates are averaged).
        datal: list of dicts with 'ms', 'candidates', 'smiles', 'ikey'.
        cfg: runtime configuration passed through to find_matches.

    Returns:
        (sumtop3, dc, dicall): number of spectra whose correct molecule
        ranked in the top 3; ``dc`` mapping 'Hit ###' -> [count], padded
        for ranks 1..100; ``dicall`` mapping InChIKey ->
        {smiles: {'score', 'iscor', 'idx'}}.
    """
    dicall = {}
    coridxd = {}

    for model in models:
        # NOTE: original loop var was `nn`, shadowing `from torch import nn`.
        for ms_idx, data in enumerate(datal):
            print(f'Calculating {ms_idx}-th MS...')

            try:
                smis, scores, _ = find_matches(model, data['ms'], data['candidates'], cfg, 50)
            except Exception as e:
                print(131, e)
                continue

            # Candidate record layout (from the data): smil[0] = index,
            # smil[1] = SMILES, smil[-1] = is-correct flag.  Duplicate
            # SMILES keep the last score seen — the original's if/else
            # assigned identically in both branches.
            dic = {}
            for rank, smil in enumerate(smis):
                dic[smil[1]] = {'score': scores[rank], 'iscor': smil[-1], 'idx': smil[0]}

            # Prefer a stereo-stripped InChIKey recomputed from the ground
            # truth SMILES; fall back to the key stored in the record.
            ikey = smiles_to_inchikey(data['smiles'], True)
            if ikey is None:
                ikey = data['ikey']

            if ikey in dicall:
                # Merge with earlier results for the same key: average the
                # scores of shared candidates, add new ones as-is.
                for k, v in dic.items():
                    if k in dicall[ikey]:
                        dicall[ikey][k]['score'] += v['score']
                        dicall[ikey][k]['score'] /= 2
                    else:
                        dicall[ikey][k] = v
            else:
                dicall[ikey] = dic

    # For each key, find the rank of the correct molecule among the
    # top-100 scored candidates and tally it in coridxd.
    for ikey, dic in dicall.items():
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        n = min(100, len(scorel))
        _, indices = torch.topk(torch.tensor(scorel), n)
        indices_list = indices.cpu().tolist()

        iscorl = [iscorl[i] for i in indices_list]

        try:
            i = iscorl.index(True)  # raises ValueError when absent
            k = 'Hit %.3d' % (i + 1)
            coridxd[k] = coridxd.get(k, 0) + 1
        except ValueError:
            # Correct molecule not among the top-n candidates.
            pass

    dc = {}
    sumtop3 = 0
    for k in sorted(coridxd.keys()):
        dc[k] = [coridxd[k]]
        if k in ('Hit 001', 'Hit 002', 'Hit 003'):
            sumtop3 += coridxd[k]

    # Pad so downstream consumers can rely on all 100 rank keys existing.
    for i in range(100):
        k = 'Hit %.3d' % (i + 1)
        if k not in dc:
            dc[k] = [0]

    return sumtop3, dc, dicall
|
|
|
|
|
|
def calc_rank(dicall):
    """Build a per-InChIKey ranking report from calc()'s ``dicall``.

    Args:
        dicall: mapping InChIKey -> {smiles: {'score', 'iscor', 'idx'}}.

    Returns:
        dict mapping InChIKey -> {'Hit': 1-based rank of the correct
        molecule, 'Rank': ['score:smiles:inchikey', ...] for the top
        candidates}.  Keys whose correct molecule is absent from the
        top-100 are omitted.
    """
    rankd = {}

    for ikey, dic in dicall.items():
        smis = list(dic.keys())
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        n = min(100, len(scorel))
        values, indices = torch.topk(torch.tensor(scorel), n)
        # Plain int indices, consistent with the same ranking loop in calc().
        indices_list = indices.tolist()

        smis = [smis[i] for i in indices_list]
        iscorl = [iscorl[i] for i in indices_list]

        # One "score:smiles:inchikey" entry per ranked candidate.
        sl = [f'{values[m]}:{smi}:{smiles_to_inchikey(smi)}'
              for m, smi in enumerate(smis)]

        try:
            i = iscorl.index(True)  # raises ValueError when no correct hit
            rankd[ikey] = {'Hit': i + 1, 'Rank': sl}
        except ValueError:
            # No correct candidate in the top-n for this key; skip it.
            pass

    return rankd
|
|
|
|
|
|
def predict(modelfnl, datal, datafn=''):
    """Evaluate each checkpoint on ``datal`` and write per-model reports.

    For every checkpoint file in ``modelfnl``: load it, run calc(), write
    a rank JSON and a Top-k summary CSV next to the checkpoint, and track
    the best run by top-3 count.

    Args:
        modelfnl: list of checkpoint paths (.pth).
        datal: evaluation records (see calc()).
        datafn: dataset filename, used to build output file names.

    Returns:
        (maxoutt, maxtop3): summary line and top-3 count of the best run.
    """
    maxtop3 = 0
    maxoutt = ''

    for fn in modelfnl:
        # NOTE(review): may need map_location on CPU-only hosts — confirm.
        d = torch.load(fn)
        CFG.load(d['config'])
        print(d['config'])
        CFG.save('', True)

        model = FragSimiModel(CFG).to(CFG.device)
        model.load_state_dict(d['state_dict'])

        sumtop3, dc, dicall = calc([model], datal, CFG)

        def _hit_count(i):
            # Number of spectra whose correct answer ranked at position i+1.
            # .get avoids the original setdefault side effect of inserting
            # missing keys into dc while merely reading it.
            return dc.get('Hit %.3d' % (i + 1), [0])[0]

        sumtop10 = sum(_hit_count(i) for i in range(10))
        sumtop50 = sum(_hit_count(i) for i in range(50))

        # Cumulative hits: Top-k = spectra ranked within the first k
        # positions.  calc() pads dc with all 100 'Hit ###' keys, so this
        # running sum is equivalent to the original O(k^2) nested loop.
        tops = {}
        running = 0
        for i in range(100):
            running += _hit_count(i)
            tops['Top %.3d' % (i + 1)] = [running]

        outt = f'Top1: {_hit_count(0)}, top3: {sumtop3}, top10: {sumtop10}, top50: {sumtop50} of {len(datal)}'

        if sumtop3 > maxtop3:
            maxtop3 = sumtop3
            maxoutt = outt

        basefn = fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}')
        rank = calc_rank(dicall)
        # Close the output file deterministically (original leaked the handle).
        with open(basefn + '-predict-rank.json', 'w') as fh:
            json.dump(rank, fh, indent=2)

        df = pd.DataFrame(tops)
        df.to_csv(basefn + '-predict-summary.csv', index=False)

    return maxoutt, maxtop3
|
|
|
|
|
|
def main(datafn, fnl):
    """Run predict() on each checkpoint in ``fnl`` against ``datafn``.

    Args:
        datafn: path to a JSON file holding the evaluation records.
        fnl: list of checkpoint paths, evaluated one at a time.
    """
    outl = []

    # Close the dataset file deterministically (original json.load(open(...))
    # leaked the handle until GC).
    with open(datafn) as fh:
        datal = json.load(fh)

    for fn in fnl:
        out, _ = predict([fn], datal, datafn)
        print(out, os.path.basename(fn))
        outl.append(out)

    print(outl)
|
|
|
|
|
|
if __name__ == '__main__':
    import time

    # Time the whole run: argv[1] is the dataset JSON, the rest are
    # checkpoint paths.
    start = time.time()
    main(sys.argv[1], sys.argv[2:])
    print(300, time.time() - start)
|
|
|
|