Spaces:
Runtime error
Runtime error
File size: 3,717 Bytes
36173e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from fcd_torch.utils import SmilesDataset, \
calculate_frechet_distance, \
todevice, \
load_imported_model
import torch
from torch.utils.data import DataLoader
import os
import numpy as np
import warnings
class FCD:
    """
    Computes Frechet ChemNet Distance on PyTorch.

    * You can precalculate mean and sigma for further usage,
      e.g. if you use the statistics from the same dataset
      multiple times.
    * Supports GPU and selection of GPU index
    * Multithread SMILES parsing

    Example 1:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        fcd(smiles_list, smiles_list)

    Example 2:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        pgen = fcd.precalc(smiles_list)
        fcd(smiles_list, pgen=pgen)
    """

    def __init__(self, device='cpu', n_jobs=1,
                 batch_size=512,
                 model_path=None,
                 canonize=True):
        """
        Loads ChemNet.

        params:
            device: cpu for CPU, cuda:0 for GPU 0, etc.
            n_jobs: number of DataLoader workers used to parse SMILES;
                n_jobs=1 is mapped to 0 workers, i.e. parsing runs in
                the main process
            batch_size: batch size for processing SMILES
            model_path: path to ChemNet_v0.13_pretrained.pt; defaults to
                the file located next to this module
            canonize: forwarded as-is to SmilesDataset
        """
        if model_path is None:
            model_dir = os.path.dirname(__file__)
            model_path = os.path.join(model_dir, 'ChemNet_v0.13_pretrained.pt')
        self.device = device
        # DataLoader uses num_workers=0 for "load in the main process", so a
        # single job does not spawn an extra worker process.
        self.n_jobs = n_jobs if n_jobs != 1 else 0
        self.batch_size = batch_size
        # Load to CPU so a checkpoint containing CUDA tensors still loads on a
        # CPU-only machine; get_predictions places the model on self.device
        # via the todevice context.
        keras_config = torch.load(model_path, map_location='cpu')
        self.model = load_imported_model(keras_config)
        self.model.eval()
        self.canonize = canonize

    def get_predictions(self, smiles_list):
        """
        Run ChemNet over ``smiles_list`` and return the stacked activations.

        Returns a numpy array with one row per input SMILES. Empty input
        returns an array of shape (0, 512); for non-empty input the column
        count is whatever the model emits (presumably 512 as well — confirm
        against the ChemNet checkpoint).
        """
        if len(smiles_list) == 0:
            return np.zeros((0, 512))
        dataloader = DataLoader(
            SmilesDataset(smiles_list, canonize=self.canonize),
            batch_size=self.batch_size,
            num_workers=self.n_jobs
        )
        with todevice(self.model, self.device), torch.no_grad():
            chemnet_activations = []
            for batch in dataloader:
                # Swap axes 1 and 2 before feeding the model; presumably the
                # dataset yields (batch, seq, features) and the model expects
                # channels first — TODO confirm against SmilesDataset.
                chemnet_activations.append(
                    self.model(
                        batch.transpose(1, 2).float().to(self.device)
                    ).to('cpu').detach().numpy()
                )
        # np.vstack is the non-deprecated equivalent of np.row_stack.
        return np.vstack(chemnet_activations)

    def precalc(self, smiles_list):
        """
        Compute the activation statistics used by ``metric``.

        Returns {'mu': mean activation vector, 'sigma': covariance matrix}.
        With fewer than 2 molecules the covariance is undefined, so a
        warning is emitted and an empty dict is returned instead.
        """
        if len(smiles_list) < 2:
            warnings.warn("Can't compute FCD for less than 2 molecules "
                          "({} given)".format(len(smiles_list)))
            return {}
        chemnet_activations = self.get_predictions(smiles_list)
        mu = chemnet_activations.mean(0)
        # Transpose so each column (one activation unit) is a variable.
        sigma = np.cov(chemnet_activations.T)
        return {'mu': mu, 'sigma': sigma}

    def metric(self, pref, pgen):
        """
        Frechet distance between two precomputed statistics dicts.

        Each dict must contain both 'mu' and 'sigma' (as returned by
        ``precalc``). If either dict is incomplete a warning naming the
        offending side is emitted and NaN is returned.
        """
        if 'mu' not in pref or 'sigma' not in pref:
            warnings.warn("Failed to compute FCD (check ref)")
            return np.nan
        if 'mu' not in pgen or 'sigma' not in pgen:
            warnings.warn("Failed to compute FCD (check gen)")
            return np.nan
        return calculate_frechet_distance(
            pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
        )

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Compute FCD between a reference and a generated set.

        Exactly one of ``ref`` (SMILES list) / ``pref`` (precomputed stats)
        and exactly one of ``gen`` / ``pgen`` must be given.
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)
|