Spaces:
Runtime error
Runtime error
from fcd_torch.utils import SmilesDataset, \ | |
calculate_frechet_distance, \ | |
todevice, \ | |
load_imported_model | |
import torch | |
from torch.utils.data import DataLoader | |
import os | |
import numpy as np | |
import warnings | |
class FCD:
    """
    Computes Frechet ChemNet Distance on PyTorch.

    * You can precalculate mean and sigma for further usage,
      e.g. if you use the statistics from the same dataset
      multiple times.
    * Supports GPU and selection of GPU index
    * Multithread SMILES parsing

    Example 1:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        fcd(smiles_list, smiles_list)

    Example 2:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        pgen = fcd.precalc(smiles_list)
        fcd(smiles_list, pgen=pgen)
    """

    def __init__(self, device='cpu', n_jobs=1,
                 batch_size=512,
                 model_path=None,
                 canonize=True):
        """
        Loads ChemNet on device

        params:
            device: cpu for CPU, cuda:0 for GPU 0, etc.
            n_jobs: number of workers to parse SMILES
                (1 means parsing happens in the main process)
            batch_size: batch size for processing SMILES
            model_path: path to ChemNet_v0.13_pretrained.pt; defaults to
                that file in the same directory as this module
            canonize: passed through to SmilesDataset (presumably
                controls SMILES canonicalization -- confirm in fcd_torch.utils)
        """
        if model_path is None:
            model_dir = os.path.split(__file__)[0]
            model_path = os.path.join(model_dir, 'ChemNet_v0.13_pretrained.pt')
        self.device = device
        # DataLoader treats num_workers=0 as "load in the main process", so a
        # single job is mapped to 0 rather than spawning one worker process.
        self.n_jobs = n_jobs if n_jobs != 1 else 0
        self.batch_size = batch_size
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True; if the checkpoint holds non-tensor objects this
        # call may fail there -- confirm against the installed torch version.
        keras_config = torch.load(model_path)
        self.model = load_imported_model(keras_config)
        self.model.eval()
        self.canonize = canonize

    def get_predictions(self, smiles_list):
        """
        Run ChemNet over ``smiles_list`` and return stacked activations.

        Returns a numpy array with one row per input SMILES. For an empty
        list a (0, 512) zero-size array is returned without touching the
        model (512 presumably matches ChemNet's activation width).
        """
        if len(smiles_list) == 0:
            return np.zeros((0, 512))
        dataloader = DataLoader(
            SmilesDataset(smiles_list, canonize=self.canonize),
            batch_size=self.batch_size,
            num_workers=self.n_jobs
        )
        with todevice(self.model, self.device), torch.no_grad():
            chemnet_activations = []
            for batch in dataloader:
                # Swap axes 1 and 2 before feeding the model; presumably the
                # dataset yields (batch, seq, features) and ChemNet wants
                # features on axis 1 -- TODO confirm against SmilesDataset.
                chemnet_activations.append(
                    self.model(
                        batch.transpose(1, 2).float().to(self.device)
                    ).to('cpu').detach().numpy()
                )
        # np.vstack is the non-deprecated name for np.row_stack.
        return np.vstack(chemnet_activations)

    def precalc(self, smiles_list):
        """
        Compute the activation statistics needed by :meth:`metric`.

        Returns ``{'mu': mean, 'sigma': covariance}`` of the ChemNet
        activations, or ``{}`` (with a warning) when fewer than 2 SMILES are
        given, since a covariance cannot be estimated from one sample.
        """
        if len(smiles_list) < 2:
            warnings.warn("Can't compute FCD for less than 2 molecules "
                          "({} given)".format(len(smiles_list)))
            return {}
        chemnet_activations = self.get_predictions(smiles_list)
        mu = chemnet_activations.mean(0)
        # np.cov treats rows as variables, so transpose to get a
        # feature-by-feature covariance matrix.
        sigma = np.cov(chemnet_activations.T)
        return {'mu': mu, 'sigma': sigma}

    def metric(self, pref, pgen):
        """
        Frechet distance between two precomputed statistics dicts.

        ``pref`` and ``pgen`` are dicts as returned by :meth:`precalc`.
        If either lacks ``'mu'`` or ``'sigma'`` (e.g. precalc returned
        ``{}``), a warning is issued and ``np.nan`` is returned.
        """
        if 'mu' not in pref or 'sigma' not in pref:
            warnings.warn("Failed to compute FCD (check ref)")
            return np.nan
        if 'mu' not in pgen or 'sigma' not in pgen:
            warnings.warn("Failed to compute FCD (check gen)")
            return np.nan
        return calculate_frechet_distance(
            pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
        )

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Compute FCD between a reference and a generated set.

        Exactly one of ``ref`` (SMILES list) / ``pref`` (precomputed
        statistics) must be given, and likewise exactly one of
        ``gen`` / ``pgen``. Missing statistics are computed via
        :meth:`precalc`.
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)