File size: 3,717 Bytes
36173e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from fcd_torch.utils import SmilesDataset, \
                            calculate_frechet_distance, \
                            todevice, \
                            load_imported_model
import torch
from torch.utils.data import DataLoader
import os
import numpy as np
import warnings


class FCD:
    """
    Computes Frechet ChemNet Distance on PyTorch.
    * You can precalculate mean and sigma for further usage,
      e.g. if you use the statistics from the same dataset
      multiple times.
    * Supports GPU and selection of GPU index
    * Multithread SMILES parsing

    Example 1:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        fcd(smiles_list, smiles_list)

    Example 2:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        pgen = fcd.precalc(smiles_list)
        fcd(smiles_list, pgen=pgen)
    """
    def __init__(self, device='cpu', n_jobs=1,
                 batch_size=512,
                 model_path=None,
                 canonize=True):
        """
        Loads ChemNet; it is moved to `device` while computing predictions.
        params:
            device: cpu for CPU, cuda:0 for GPU 0, etc.
            n_jobs: number of DataLoader workers used to parse SMILES;
                    1 means "parse in the main process" (num_workers=0)
            batch_size: batch size for processing SMILES
            model_path: path to ChemNet_v0.13_pretrained.pt; defaults to
                        the copy shipped next to this module
            canonize: forwarded to SmilesDataset as its `canonize` flag
        """
        if model_path is None:
            model_dir = os.path.split(__file__)[0]
            model_path = os.path.join(model_dir, 'ChemNet_v0.13_pretrained.pt')

        self.device = device
        # DataLoader treats num_workers=0 as "load in the main process";
        # a single worker process would only add IPC overhead.
        self.n_jobs = n_jobs if n_jobs != 1 else 0
        self.batch_size = batch_size
        # Load on CPU so a checkpoint saved from a GPU session still opens
        # on CPU-only hosts; get_predictions moves the model to
        # self.device via todevice.
        keras_config = torch.load(model_path, map_location='cpu')
        self.model = load_imported_model(keras_config)
        self.model.eval()
        self.canonize = canonize

    def get_predictions(self, smiles_list):
        """
        Runs ChemNet over `smiles_list` and returns the stacked activations.

        returns:
            numpy array stacking the per-batch model outputs row-wise
            (presumably one row per SMILES). An empty input yields a
            (0, 512) array; 512 is assumed to be ChemNet's output width.
        """
        if len(smiles_list) == 0:
            return np.zeros((0, 512))
        dataloader = DataLoader(
            SmilesDataset(smiles_list, canonize=self.canonize),
            batch_size=self.batch_size,
            num_workers=self.n_jobs
        )
        with todevice(self.model, self.device), torch.no_grad():
            chemnet_activations = []
            for batch in dataloader:
                # NOTE(review): transpose(1, 2) presumably converts
                # (batch, seq, alphabet) into the channel-first layout the
                # model expects; confirm against SmilesDataset.
                chemnet_activations.append(
                    self.model(
                        batch.transpose(1, 2).float().to(self.device)
                    ).to('cpu').detach().numpy()
                )
        # np.vstack rather than np.row_stack, which is a deprecated alias.
        return np.vstack(chemnet_activations)

    def precalc(self, smiles_list):
        """
        Computes the activation statistics used by `metric`.

        returns:
            {'mu': per-column mean of the activations,
             'sigma': covariance matrix of the activation columns},
            or {} (with a warning) when fewer than 2 SMILES are given,
            since a covariance needs at least two samples.
        """
        if len(smiles_list) < 2:
            warnings.warn("Can't compute FCD for less than 2 molecules "
                          "({} given)".format(len(smiles_list)))
            return {}
        chemnet_activations = self.get_predictions(smiles_list)
        mu = chemnet_activations.mean(0)
        # Transpose so each activation dimension is a variable (row) and
        # each molecule an observation, as np.cov expects.
        sigma = np.cov(chemnet_activations.T)
        return {'mu': mu, 'sigma': sigma}

    def metric(self, pref, pgen):
        """
        Computes FCD between two precomputed statistic dicts.

        params:
            pref, pgen: dicts as returned by `precalc`
        returns:
            the Frechet distance, or np.nan (with a warning) if either
            dict lacks 'mu' or 'sigma'.
        """
        if 'mu' not in pref or 'sigma' not in pref:
            warnings.warn("Failed to compute FCD (check ref)")
            return np.nan
        if 'mu' not in pgen or 'sigma' not in pgen:
            warnings.warn("Failed to compute FCD (check gen)")
            return np.nan
        return calculate_frechet_distance(
            pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
        )

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Computes FCD from SMILES lists and/or precomputed statistics.

        Exactly one of `ref`/`pref` and exactly one of `gen`/`pgen` must
        be given; SMILES lists are converted with `precalc`.
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)