Spaces:
Runtime error
Runtime error
| from fcd_torch.utils import SmilesDataset, \ | |
| calculate_frechet_distance, \ | |
| todevice, \ | |
| load_imported_model | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import os | |
| import numpy as np | |
| import warnings | |
class FCD:
    """
    Computes Frechet ChemNet Distance on PyTorch.

    * You can precalculate mean and sigma for further usage,
      e.g. if you use the statistics from the same dataset
      multiple times.
    * Supports GPU and selection of GPU index
    * Multithread SMILES parsing

    Example 1:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        fcd(smiles_list, smiles_list)

    Example 2:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        pgen = fcd.precalc(smiles_list)
        fcd(smiles_list, pgen=pgen)
    """

    def __init__(self, device='cpu', n_jobs=1,
                 batch_size=512,
                 model_path=None,
                 canonize=True):
        """
        Loads ChemNet on device

        params:
            device: cpu for CPU, cuda:0 for GPU 0, etc.
            n_jobs: number of workers to parse SMILES
                (1 is mapped to 0 DataLoader workers)
            batch_size: batch size for processing SMILES
            model_path: path to ChemNet_v0.13_pretrained.pt; when None,
                the file is looked up next to this module
            canonize: forwarded unchanged to SmilesDataset(canonize=...)
        """
        if model_path is None:
            model_dir = os.path.split(__file__)[0]
            model_path = os.path.join(model_dir,
                                      'ChemNet_v0.13_pretrained.pt')
        self.device = device
        # n_jobs == 1 becomes num_workers=0 for the DataLoader.
        self.n_jobs = n_jobs if n_jobs != 1 else 0
        self.batch_size = batch_size
        # NOTE(review): torch.load unpickles the file, so model_path must
        # come from a trusted source. Recent PyTorch releases may need an
        # explicit weights_only=... argument here -- confirm against the
        # PyTorch version in use.
        keras_config = torch.load(model_path)
        self.model = load_imported_model(keras_config)
        self.model.eval()
        self.canonize = canonize

    def get_predictions(self, smiles_list):
        """
        Runs the model over smiles_list in batches.

        Returns a numpy array with the per-batch model outputs stacked
        row-wise. An empty smiles_list yields an array of shape (0, 512)
        without touching the model.
        """
        if len(smiles_list) == 0:
            return np.zeros((0, 512))
        dataloader = DataLoader(
            SmilesDataset(smiles_list, canonize=self.canonize),
            batch_size=self.batch_size,
            num_workers=self.n_jobs
        )
        with todevice(self.model, self.device), torch.no_grad():
            chemnet_activations = []
            for batch in dataloader:
                chemnet_activations.append(
                    self.model(
                        batch.transpose(1, 2).float().to(self.device)
                    ).to('cpu').detach().numpy()
                )
        # np.vstack replaces np.row_stack, a deprecated alias that was
        # removed in NumPy 2.0; the result is identical.
        return np.vstack(chemnet_activations)

    def precalc(self, smiles_list):
        """
        Precalculates the statistics used by metric().

        Returns {'mu': ..., 'sigma': ...}: the mean over axis 0 and the
        covariance (np.cov of the transposed activations) of the model
        outputs. With fewer than 2 molecules a warning is issued and an
        empty dict is returned.
        """
        if len(smiles_list) < 2:
            warnings.warn("Can't compute FCD for less than 2 molecules"
                          "({} given)".format(len(smiles_list)))
            return {}

        chemnet_activations = self.get_predictions(smiles_list)
        mu = chemnet_activations.mean(0)
        sigma = np.cov(chemnet_activations.T)
        return {'mu': mu, 'sigma': sigma}

    def metric(self, pref, pgen):
        """
        Computes the Frechet distance between two precalculated
        statistics dicts (as returned by precalc()).

        If either dict lacks 'mu' or 'sigma', a warning is issued and
        np.nan is returned instead of raising.
        """
        # Both keys of the *reference* statistics must be present. The
        # previous code tested 'sigma' against pgen here, so a pref
        # without 'sigma' slipped through and raised KeyError below.
        if 'mu' not in pref or 'sigma' not in pref:
            warnings.warn("Failed to compute FCD (check ref)")
            return np.nan
        if 'mu' not in pgen or 'sigma' not in pgen:
            warnings.warn("Failed to compute FCD (check gen)")
            return np.nan
        return calculate_frechet_distance(
            pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
        )

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Computes FCD between a reference set and a generated set.

        Exactly one of ref / pref and exactly one of gen / pgen must be
        given: ref and gen are SMILES lists, pref and pgen are
        statistics dicts from precalc().
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)