# CMSSP / code / dataset.py
# Uploaded by OliXio ("Upload 5 files"), revision 5946936.
import os, json
import torch
import utils
def calc_feats(smi, ms, nls, cfg):
    """Compute spectrum and molecule features for a single sample.

    The spectrum is always binned with ``utils.ms_binner`` (settings taken from
    ``cfg``). Molecule features are selected by ``in`` membership tests on
    ``cfg.mol_encoder``:

    * ``'fp'``  -> ``'mol_fps'``; when ``'fm'`` is also requested, the joint
      ``utils.mol_fp_fm_encoder`` fills both ``'mol_fps'`` and ``'mol_fmvec'``.
    * ``'gnn'`` -> the mapping from ``utils.mol_graph_featurizer`` is merged in.
    * ``'fm'``  -> ``'mol_fmvec'`` from ``utils.smi2fmvec``, unless the joint
      encoder above already produced it.

    Args:
        smi: SMILES string of the molecule.
        ms: spectrum passed through to ``utils.ms_binner``.
        nls: neutral-loss list passed through to ``utils.ms_binner``.
        cfg: configuration object providing the attributes read below.

    Returns:
        The feature dict, or None when the graph featurizer returns a falsy
        value.
    """
    feats = {
        'ms_bins': utils.ms_binner(
            ms,
            nls,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ),
    }

    encoder = cfg.mol_encoder
    want_fp = 'fp' in encoder
    want_fm = 'fm' in encoder

    if want_fp and want_fm:
        # A single call yields both the fingerprint and the FM vector.
        fps, fmvec = utils.mol_fp_fm_encoder(
            smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
        )
        feats['mol_fps'] = fps
        feats['mol_fmvec'] = fmvec
    elif want_fp:
        feats['mol_fps'] = utils.mol_fp_encoder(
            smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
        )

    if 'gnn' in encoder:
        graph = utils.mol_graph_featurizer(smi)
        if not graph:
            return None
        feats.update(graph)

    # The FM vector still has to be computed on its own when no fingerprint
    # was requested (otherwise the joint encoder already produced it).
    if want_fm and not want_fp:
        feats['mol_fmvec'] = utils.smi2fmvec(smi)

    return feats
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over spectrum/molecule records, featurized lazily.

    Each record is a mapping that either already holds ``'ms_bins'`` (and is
    returned unchanged) or holds raw ``'ms'`` and ``'smiles'`` entries plus an
    optional ``'nls'`` entry, which are turned into features by ``calc_feats``.
    """

    def __init__(self, inp, cfg):
        """Store the records and the feature configuration.

        Args:
            inp: path (``str`` or ``os.PathLike``) of a JSON file holding the
                records, or an already-loaded indexable collection of records.
            cfg: configuration object forwarded to ``calc_feats``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager closes the handle instead of leaking it until GC.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp
        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the features for record ``idx``.

        Returns the record itself when it already contains ``'ms_bins'``,
        otherwise the result of ``calc_feats`` (which may be None). If looking
        up or featurizing the record raises, the error is printed and None is
        returned.
        """
        try:
            record = self.data[idx]
            if 'ms_bins' in record:
                # Already featurized; hand back the stored record as-is.
                return record
            nls = record['nls'] if 'nls' in record else []
            ms = record['ms']
            smi = record['smiles']
            item = calc_feats(smi, ms, nls, self.cfg)
        except Exception as e:
            print('='*50, idx, str(e))
            return None
        return item

    def __len__(self):
        return len(self.data)
class DatasetGNNFP(torch.utils.data.Dataset):
    """Molecule-only dataset yielding a fingerprint plus graph features.

    Each record must provide a ``'smiles'`` entry; no spectrum data is read.
    """

    def __init__(self, inp, cfg):
        """Store the records and the feature configuration.

        Args:
            inp: path (``str`` or ``os.PathLike``) of a JSON file holding the
                records, or an already-loaded indexable collection of records.
            cfg: configuration object providing ``fptype`` and
                ``mol_embedding_dim``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager closes the handle instead of leaking it until GC.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp
        self.cfg = cfg

    def __getitem__(self, idx):
        """Return ``{'mol_fps': ..., **graph_features}`` for record ``idx``.

        Returns None when the graph featurizer yields a falsy result (the same
        convention ``calc_feats`` uses), or when any step raises (the error is
        printed first).
        """
        try:
            smi = self.data[idx]['smiles']
            item = {
                'mol_fps': utils.mol_fp_encoder(
                    smi, tp=self.cfg.fptype, nbits=self.cfg.mol_embedding_dim
                ),
            }
            graph = utils.mol_graph_featurizer(smi)
            if not graph:
                return None
            item.update(graph)
        except Exception as e:
            print('='*50, idx, str(e))
            return None
        return item

    def __len__(self):
        return len(self.data)
class PathDataset(torch.utils.data.Dataset):
    """Dataset that parses one spectrum file per index on first access.

    Each path in ``pathlist`` is parsed by :meth:`proc_data`, selecting the
    block tagged with ``cfg.energy``. Successfully parsed records are memoized
    in ``self.data`` keyed by index, and features are built with
    ``calc_feats``.
    """

    def __init__(self, pathlist, cfg):
        self.fns = pathlist  # indexable sequence of spectrum file paths
        self.cfg = cfg  # must provide .energy plus the fields calc_feats reads
        self.data = {}  # idx -> parsed {'ms': [...], 'smiles': str}

    def __getitem__(self, idx):
        """Return the features for the file at ``idx``, or None on any failure."""
        try:
            if idx not in self.data:
                out = self.proc_data(self.fns[idx], self.cfg.energy)
                if out is None:
                    return None
                self.data[idx] = out
            record = self.data[idx]
            # proc_data does not extract neutral losses, so none are passed.
            return calc_feats(record['smiles'], record['ms'], [], self.cfg)
        except Exception:
            # Deliberately silent: unreadable/unparseable files and
            # featurization errors are reported as None instead of raising.
            return None

    def proc_data(self, fn, energy='Energy1'):
        """Parse the peak list for one collision energy from a spectrum file.

        A line containing ``energy`` starts the wanted block; the
        second-to-last ``';'``-separated field of that line is taken as the
        SMILES string. Following lines are parsed as whitespace-separated
        ``<mz> <intensity>`` pairs until a line containing ``END IONS``.

        Args:
            fn: path of the file to read.
            energy: substring identifying the header line of the wanted block.

        Returns:
            ``{'ms': [(mz, intensity), ...], 'smiles': smi}``, or None when no
            line contains ``energy`` or a line of the block cannot be parsed.

        Raises:
            OSError: if ``fn`` cannot be opened.
        """
        peaks = []
        smi = None
        in_block = False
        with open(fn) as fh:
            try:
                for line in fh:
                    if energy in line:
                        smi = line.split(';')[-2]
                        in_block = True
                        continue
                    if in_block and 'END IONS' in line:
                        break
                    if in_block:
                        # split() tolerates tabs, repeated spaces and the
                        # trailing newline; split(' ') rejected those lines.
                        mz, intn = line.split()
                        peaks.append((float(mz), float(intn)))
            except (ValueError, IndexError):
                # Malformed header (too few ';' fields) or peak line.
                return None
        if not in_block:
            # No header matched: previously this raised UnboundLocalError.
            return None
        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        return len(self.fns)