File size: 5,174 Bytes
95ba5bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import numpy as np

from rdkit import Chem
from rdkit.Chem import AllChem
from src import const
from src.molecule_builder import get_bond_order
from scipy.stats import wasserstein_distance

from pdb import set_trace


def is_valid(mol):
    """Return True if ``mol`` passes RDKit sanitization.

    ``Chem.SanitizeMol`` works in place, so a molecule that passes is left
    sanitized as a side effect.

    Args:
        mol: An RDKit molecule, or ``None`` (e.g. the result of a failed
            parse/build step), which is reported as invalid.

    Returns:
        ``True`` if sanitization succeeds, ``False`` otherwise.
    """
    if mol is None:
        # SanitizeMol(None) raises Boost.Python.ArgumentError (a TypeError),
        # which the ValueError handler below would not catch.
        return False
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        # RDKit reports sanitization failures as ValueError subclasses.
        return False
    return True


def is_connected(mol):
    """Return True if ``mol`` consists of exactly one connected fragment.

    Fragments are obtained with ``Chem.GetMolFrags(mol, asMols=True)``, which
    also sanitizes each fragment; a sanitization failure there is treated as
    "not connected".

    Args:
        mol: An RDKit molecule, or ``None`` (reported as not connected).

    Returns:
        ``True`` if the molecule has a single fragment, ``False`` otherwise.
    """
    if mol is None:
        # GetMolFrags(None) would raise a Boost ArgumentError (TypeError).
        return False
    try:
        mol_frags = Chem.GetMolFrags(mol, asMols=True)
    except Chem.rdchem.MolSanitizeException:
        # Parent of AtomValenceException *and* KekulizeException; catching
        # only AtomValenceException let kekulization failures escape.
        return False
    return len(mol_frags) == 1


def get_valid_molecules(molecules):
    """Return the molecules from ``molecules`` for which ``is_valid`` is True.

    Input order is preserved. Because ``is_valid`` sanitizes in place, the
    returned molecules have been sanitized.
    """
    return [candidate for candidate in molecules if is_valid(candidate)]


def get_connected_molecules(molecules):
    """Return the molecules from ``molecules`` that ``is_connected`` accepts.

    Input order is preserved.
    """
    return [candidate for candidate in molecules if is_connected(candidate)]


def get_unique_smiles(valid_molecules):
    """Return the distinct SMILES strings produced by ``Chem.MolToSmiles``.

    The list is built from a set, so its order is arbitrary.
    """
    distinct = {Chem.MolToSmiles(molecule) for molecule in valid_molecules}
    return list(distinct)


def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
    """Return predicted SMILES that do not occur among the reference SMILES.

    The result is built from a set, so duplicates are removed and the order
    is arbitrary.
    """
    reference = set(unique_true_smiles)
    novel = set(unique_pred_smiles) - reference
    return list(novel)


def compute_energy(mol):
    """Return the MMFF force-field energy of conformer 0 of ``mol``.

    Any error raised by RDKit while setting up or evaluating the force field
    propagates to the caller.
    # NOTE(review): per RDKit docs, MMFFGetMoleculeProperties returns None for
    # molecules MMFF cannot type, which surfaces here as an AttributeError —
    # confirm if callers need a clearer error.
    """
    force_field = AllChem.MMFFGetMoleculeForceField(
        mol, AllChem.MMFFGetMoleculeProperties(mol), confId=0
    )
    return force_field.CalcEnergy()


def wasserstein_distance_between_energies(true_molecules, pred_molecules):
    true_energy_dist = []
    for mol in true_molecules:
        try:
            energy = compute_energy(mol)
            true_energy_dist.append(energy)
        except:
            continue

    pred_energy_dist = []
    for mol in pred_molecules:
        try:
            energy = compute_energy(mol)
            pred_energy_dist.append(energy)
        except:
            continue

    if len(true_energy_dist) > 0 and len(pred_energy_dist) > 0:
        return wasserstein_distance(true_energy_dist, pred_energy_dist)
    else:
        return 0


def compute_metrics(pred_molecules, true_molecules):
    """Compute generation-quality metrics for predicted molecules.

    Both inputs are sanitized in place as a side effect (via ``is_valid``).

    Args:
        pred_molecules: Generated RDKit molecules.
        true_molecules: Reference RDKit molecules, used for novelty and for
            the energy-distribution comparison.

    Returns:
        A dict with the same keys on every path:
        - ``validity``: fraction of predictions that pass sanitization.
        - ``validity_and_connectivity``: fraction of predictions that pass
          sanitization and form a single fragment.
        - ``uniqueness``: distinct SMILES / valid-and-connected predictions.
        - ``novelty``: fraction of distinct predicted SMILES that are absent
          from the valid-and-connected reference SMILES.
        - ``energies``: Wasserstein distance between MMFF energies of the
          valid-and-connected reference and predicted molecules.
    """
    if len(pred_molecules) == 0:
        # Keep these keys identical to the regular return below so consumers
        # (e.g. loggers iterating over the dict) always see the same metric
        # set; the stray 'validity_as_in_delinker' key was never produced on
        # the non-empty path.
        return {
            'validity': 0,
            'validity_and_connectivity': 0,
            'uniqueness': 0,
            'novelty': 0,
            'energies': 0,
        }

    # Passing rdkit.Chem.Sanitize filter
    true_valid = get_valid_molecules(true_molecules)
    pred_valid = get_valid_molecules(pred_molecules)
    validity = len(pred_valid) / len(pred_molecules)

    # Checking if molecule consists of a single connected part
    true_valid_and_connected = get_connected_molecules(true_valid)
    pred_valid_and_connected = get_connected_molecules(pred_valid)
    validity_and_connectivity = len(pred_valid_and_connected) / len(pred_molecules)

    # Unique molecules
    true_unique = get_unique_smiles(true_valid_and_connected)
    pred_unique = get_unique_smiles(pred_valid_and_connected)
    uniqueness = len(pred_unique) / len(pred_valid_and_connected) if pred_valid_and_connected else 0

    # Novel molecules
    pred_novel = get_novel_smiles(true_unique, pred_unique)
    novelty = len(pred_novel) / len(pred_unique) if pred_unique else 0

    # Difference between Energy distributions
    energies = wasserstein_distance_between_energies(true_valid_and_connected, pred_valid_and_connected)

    return {
        'validity': validity,
        'validity_and_connectivity': validity_and_connectivity,
        'uniqueness': uniqueness,
        'novelty': novelty,
        'energies': energies,
    }


# def check_stability(positions, atom_types):
#     assert len(positions.shape) == 2
#     assert positions.shape[1] == 3
#     x = positions[:, 0]
#     y = positions[:, 1]
#     z = positions[:, 2]
#
#     nr_bonds = np.zeros(len(x), dtype='int')
#     for i in range(len(x)):
#         for j in range(i + 1, len(x)):
#             p1 = np.array([x[i], y[i], z[i]])
#             p2 = np.array([x[j], y[j], z[j]])
#             dist = np.sqrt(np.sum((p1 - p2) ** 2))
#             atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()]
#             order = get_bond_order(atom1, atom2, dist)
#             nr_bonds[i] += order
#             nr_bonds[j] += order
#     nr_stable_bonds = 0
#     for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds):
#         possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]]
#         if type(possible_bonds) == int:
#             is_stable = possible_bonds == nr_bonds_i
#         else:
#             is_stable = nr_bonds_i in possible_bonds
#         nr_stable_bonds += int(is_stable)
#
#     molecule_stable = nr_stable_bonds == len(x)
#     return molecule_stable, nr_stable_bonds, len(x)
#
#
# def count_stable_molecules(one_hot, x, node_mask):
#     stable_molecules = 0
#     for i in range(len(one_hot)):
#         mol_size = node_mask[i].sum()
#         atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu()
#         positions = x[i][:mol_size, :].detach().cpu()
#         stable, _, _ = check_stability(positions, atom_types)
#         stable_molecules += int(stable)
#
#     return stable_molecules