molgen_metric / fcd.py
saicharan2804
Added manual implementation of metrics
36173e1
raw
history blame contribute delete
No virus
3.72 kB
from fcd_torch.utils import SmilesDataset, \
calculate_frechet_distance, \
todevice, \
load_imported_model
import torch
from torch.utils.data import DataLoader
import os
import numpy as np
import warnings
class FCD:
    """
    Computes Frechet ChemNet Distance on PyTorch.

    * You can precalculate mean and sigma for further usage,
      e.g. if you use the statistics from the same dataset
      multiple times.
    * Supports GPU and selection of GPU index
    * Multithread SMILES parsing

    Example 1:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        fcd(smiles_list, smiles_list)

    Example 2:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        pgen = fcd.precalc(smiles_list)
        fcd(smiles_list, pgen=pgen)
    """

    def __init__(self, device='cpu', n_jobs=1,
                 batch_size=512,
                 model_path=None,
                 canonize=True):
        """
        Loads ChemNet on device

        params:
            device: cpu for CPU, cuda:0 for GPU 0, etc.
            n_jobs: number of DataLoader workers used to parse SMILES.
                A value of 1 is mapped to 0 workers, i.e. parsing happens
                in the main process.
            batch_size: batch size for processing SMILES
            model_path: path to ChemNet_v0.13_pretrained.pt; defaults to a
                file of that name in the directory containing this module
            canonize: forwarded unchanged to SmilesDataset
        """
        if model_path is None:
            model_dir = os.path.split(__file__)[0]
            model_path = os.path.join(model_dir, 'ChemNet_v0.13_pretrained.pt')
        self.device = device
        # DataLoader's num_workers=0 means "load in the main process";
        # presumably a single worker process is avoided because it adds
        # process overhead without any parallelism.
        self.n_jobs = n_jobs if n_jobs != 1 else 0
        self.batch_size = batch_size
        # Load onto CPU so a checkpoint saved from a GPU still opens on a
        # CPU-only machine; get_predictions wraps inference in
        # todevice(self.model, self.device) to run on the requested device.
        keras_config = torch.load(model_path, map_location='cpu')
        self.model = load_imported_model(keras_config)
        self.model.eval()
        self.canonize = canonize

    def get_predictions(self, smiles_list):
        """
        Run ChemNet over ``smiles_list`` and return its activations.

        params:
            smiles_list: sequence of SMILES strings
        returns:
            numpy array of the per-batch model outputs stacked row-wise;
            an empty (0, 512) array when ``smiles_list`` is empty
            (512 presumably matches the ChemNet output width).
        """
        if len(smiles_list) == 0:
            return np.zeros((0, 512))
        dataloader = DataLoader(
            SmilesDataset(smiles_list, canonize=self.canonize),
            batch_size=self.batch_size,
            num_workers=self.n_jobs
        )
        chemnet_activations = []
        with todevice(self.model, self.device), torch.no_grad():
            for batch in dataloader:
                # NOTE(review): assumes batches are (batch, seq, features) and
                # the model wants features on axis 1, hence the transpose.
                chemnet_activations.append(
                    self.model(
                        batch.transpose(1, 2).float().to(self.device)
                    ).to('cpu').detach().numpy()
                )
        # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
        return np.vstack(chemnet_activations)

    def precalc(self, smiles_list):
        """
        Precompute the activation statistics used by ``metric``.

        params:
            smiles_list: sequence of SMILES strings
        returns:
            {'mu': mean activation vector, 'sigma': activation covariance},
            or {} (after emitting a warning) when fewer than 2 molecules are
            given, since a sample covariance needs at least two samples.
        """
        if len(smiles_list) < 2:
            warnings.warn("Can't compute FCD for less than 2 molecules "
                          "({} given)".format(len(smiles_list)))
            return {}
        chemnet_activations = self.get_predictions(smiles_list)
        mu = chemnet_activations.mean(0)
        # np.cov treats rows as variables, so transpose to put the
        # activation dimensions on the rows.
        sigma = np.cov(chemnet_activations.T)
        return {'mu': mu, 'sigma': sigma}

    def metric(self, pref, pgen):
        """
        Frechet distance between two precomputed statistics dicts.

        params:
            pref: reference statistics as returned by ``precalc``
            pgen: generated-set statistics as returned by ``precalc``
        returns:
            the value of ``calculate_frechet_distance``, or ``np.nan`` (after
            emitting a warning) if either dict lacks 'mu' or 'sigma'.
        """
        if 'mu' not in pref or 'sigma' not in pref:
            warnings.warn("Failed to compute FCD (check ref)")
            return np.nan
        if 'mu' not in pgen or 'sigma' not in pgen:
            warnings.warn("Failed to compute FCD (check gen)")
            return np.nan
        return calculate_frechet_distance(
            pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
        )

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Compute FCD between a reference set and a generated set.

        Exactly one of ``ref``/``pref`` and exactly one of ``gen``/``pgen``
        must be given. Raw SMILES lists are turned into statistics with
        ``precalc``; precomputed dicts are used as-is.

        raises:
            AssertionError: if both or neither of a ref/pref or gen/pgen
                pair is given (note: asserts are skipped under ``python -O``).
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)