igashov commited on
Commit
95ba5bc
·
1 Parent(s): e361c7a

DiffLinker code

Browse files
app.py CHANGED
@@ -1,6 +1,14 @@
1
  import gradio as gr
2
  import os
3
- import sys
 
 
 
 
 
 
 
 
4
 
5
 
6
  HTML_TEMPLATE = """<!DOCTYPE html>
@@ -43,20 +51,88 @@ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
43
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
44
 
45
 
46
- def read_molecule(path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with open(path, "r") as f:
48
  return "".join(f.readlines())
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def generate(input_file):
52
  try:
53
  path = input_file.name
54
  molecule = read_molecule(path)
55
- fmt = path.split('.')[-1]
56
- except:
57
- return 'Error: could not open the provided file'
58
-
59
- html = HTML_TEMPLATE.format(molecule=molecule, fmt=fmt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  return IFRAME_TEMPLATE.format(html=html)
61
 
62
 
 
1
  import gradio as gr
2
  import os
3
+ import torch
4
+ import subprocess
5
+
6
+ from rdkit import Chem
7
+ from src import const
8
+ from src.visualizer import save_xyz_file
9
+ from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
10
+ from src.lightning import DDPM
11
+ from src.linker_size_lightning import SizeClassifier
12
 
13
 
14
  HTML_TEMPLATE = """<!DOCTYPE html>
 
51
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
52
 
53
 
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ os.makedirs("results", exist_ok=True)
56
+ print('Created results directory')
57
+
58
+ size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
59
+ print('Loaded SizeGNN model')
60
+
61
+ ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
62
+ print('Loaded diffusion model')
63
+
64
+
65
+ def sample_fn(_data):
66
+ output, _ = size_nn.forward(_data)
67
+ probabilities = torch.softmax(output, dim=1)
68
+ distribution = torch.distributions.Categorical(probs=probabilities)
69
+ samples = distribution.sample()
70
+ sizes = []
71
+ for label in samples.detach().cpu().numpy():
72
+ sizes.append(size_nn.linker_id2size[label])
73
+ sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
74
+ return sizes
75
+
76
+
77
+ def read_molecule_content(path):
78
  with open(path, "r") as f:
79
  return "".join(f.readlines())
80
 
81
 
82
+ def read_molecule(path):
83
+ if path.endswith('.pdb'):
84
+ return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
85
+ elif path.endswith('.mol'):
86
+ return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
87
+ elif path.endswith('.mol2'):
88
+ return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
89
+ elif path.endswith('.sdf'):
90
+ return Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
91
+ raise Exception('Unknown file extension')
92
+
93
+
94
  def generate(input_file):
95
  try:
96
  path = input_file.name
97
  molecule = read_molecule(path)
98
+ name = '.'.join(molecule.split('/')[-1].split('.')[:-1])
99
+ out_sdf = f'results/{name}_generated.sdf'
100
+ print(f'Input path={path}, name={name}')
101
+ except Exception as e:
102
+ return f'Could not read the molecule: {e}'
103
+
104
+ positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
105
+ positions = torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device)
106
+ one_hot = torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device)
107
+ print('Read and parsed molecule')
108
+
109
+ dataset = [{
110
+ 'uuid': '0',
111
+ 'name': '0',
112
+ 'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
113
+ 'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
114
+ 'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
115
+ 'anchors': torch.zeros_like(charges, dtype=const.TORCH_FLOAT, device=device),
116
+ 'fragment_mask': torch.ones_like(charges, dtype=const.TORCH_FLOAT, device=device),
117
+ 'linker_mask': torch.zeros_like(charges, dtype=const.TORCH_FLOAT, device=device),
118
+ 'num_atoms': len(positions),
119
+ }]
120
+ dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
121
+ print('Created dataloader')
122
+
123
+ for data in dataloader:
124
+ chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
125
+ print('Generated linker')
126
+ x = chain[0][:, :, :ddpm.n_dims]
127
+ h = chain[0][:, :, ddpm.n_dims:]
128
+ save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='generated')
129
+ print('Saved XYZ file')
130
+ subprocess.run(f'obabel results/{name}_generated.xyz -O {out_sdf}', shell=True)
131
+ print('Converted to SDF')
132
+ break
133
+
134
+ generated_molecule = read_molecule_content(out_sdf)
135
+ html = HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
136
  return IFRAME_TEMPLATE.format(html=html)
137
 
138
 
src/__init__.py ADDED
File without changes
src/const.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from rdkit import Chem
4
+
5
+
6
+ TORCH_FLOAT = torch.float32
7
+ TORCH_INT = torch.int8
8
+
9
+ # #################################################################################### #
10
+ # ####################################### ZINC ####################################### #
11
+ # #################################################################################### #
12
+
13
+ # Atom idx for one-hot encoding
14
+ ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7}
15
+ IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I'}
16
+
17
+ # Atomic numbers (Z)
18
+ CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53}
19
+
20
+ # One-hot atom types
21
+ NUMBER_OF_ATOM_TYPES = len(ATOM2IDX)
22
+
23
+
24
+ # #################################################################################### #
25
+ # ####################################### GEOM ####################################### #
26
+ # #################################################################################### #
27
+
28
+ # Atom idx for one-hot encoding
29
+ GEOM_ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7, 'P': 8}
30
+ GEOM_IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I', 8: 'P'}
31
+
32
+ # Atomic numbers (Z)
33
+ GEOM_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}
34
+
35
+ # One-hot atom types
36
+ GEOM_NUMBER_OF_ATOM_TYPES = len(GEOM_ATOM2IDX)
37
+
38
+ # Dataset keys
39
+ DATA_LIST_ATTRS = {
40
+ 'uuid', 'name', 'fragments_smi', 'linker_smi', 'num_atoms'
41
+ }
42
+ DATA_ATTRS_TO_PAD = {
43
+ 'positions', 'one_hot', 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask'
44
+ }
45
+ DATA_ATTRS_TO_ADD_LAST_DIM = {
46
+ 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask'
47
+ }
48
+
49
+ # Distribution of linker size in train data
50
+ LINKER_SIZE_DIST = {
51
+ 4: 85540,
52
+ 3: 113928,
53
+ 6: 70946,
54
+ 7: 30408,
55
+ 5: 77671,
56
+ 9: 5177,
57
+ 10: 1214,
58
+ 8: 12712,
59
+ 11: 158,
60
+ 12: 7,
61
+ }
62
+
63
+
64
+ # Bond lengths from:
65
+ # http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
66
+ # And:
67
+ # http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf
68
+ BONDS_1 = {
69
+ 'H': {
70
+ 'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92,
71
+ 'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134,
72
+ 'Cl': 127, 'Br': 141, 'I': 161
73
+ },
74
+ 'C': {
75
+ 'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135,
76
+ 'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194,
77
+ 'I': 214
78
+ },
79
+ 'N': {
80
+ 'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136,
81
+ 'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177
82
+ },
83
+ 'O': {
84
+ 'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142,
85
+ 'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164,
86
+ 'I': 194
87
+ },
88
+ 'F': {
89
+ 'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142,
90
+ 'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156,
91
+ 'I': 187
92
+ },
93
+ 'B': {
94
+ 'H': 119, 'Cl': 175
95
+ },
96
+ 'Si': {
97
+ 'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200,
98
+ 'F': 160, 'Cl': 202, 'Br': 215, 'I': 243,
99
+ },
100
+ 'Cl': {
101
+ 'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164,
102
+ 'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166,
103
+ 'Br': 214
104
+ },
105
+ 'S': {
106
+ 'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204,
107
+ 'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210,
108
+ 'I': 234
109
+ },
110
+ 'Br': {
111
+ 'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214,
112
+ 'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222
113
+ },
114
+ 'P': {
115
+ 'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203,
116
+ 'S': 210, 'F': 156, 'N': 177, 'Br': 222
117
+ },
118
+ 'I': {
119
+ 'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194,
120
+ 'S': 234, 'F': 187, 'I': 266
121
+ },
122
+ 'As': {
123
+ 'H': 152
124
+ }
125
+ }
126
+
127
+ BONDS_2 = {
128
+ 'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160},
129
+ 'N': {'C': 129, 'N': 125, 'O': 121},
130
+ 'O': {'C': 120, 'N': 121, 'O': 121, 'P': 150},
131
+ 'P': {'O': 150, 'S': 186},
132
+ 'S': {'P': 186}
133
+ }
134
+
135
+ BONDS_3 = {
136
+ 'C': {'C': 120, 'N': 116, 'O': 113},
137
+ 'N': {'C': 116, 'N': 110},
138
+ 'O': {'C': 113}
139
+ }
140
+
141
+ BOND_DICT = [
142
+ None,
143
+ Chem.rdchem.BondType.SINGLE,
144
+ Chem.rdchem.BondType.DOUBLE,
145
+ Chem.rdchem.BondType.TRIPLE,
146
+ Chem.rdchem.BondType.AROMATIC,
147
+ ]
148
+
149
+ BOND2IDX = {
150
+ Chem.rdchem.BondType.SINGLE: 1,
151
+ Chem.rdchem.BondType.DOUBLE: 2,
152
+ Chem.rdchem.BondType.TRIPLE: 3,
153
+ Chem.rdchem.BondType.AROMATIC: 4,
154
+ }
155
+
156
+ ALLOWED_BONDS = {
157
+ 'H': 1,
158
+ 'C': 4,
159
+ 'N': 3,
160
+ 'O': 2,
161
+ 'F': 1,
162
+ 'B': 3,
163
+ 'Al': 3,
164
+ 'Si': 4,
165
+ 'P': [3, 5],
166
+ 'S': 4,
167
+ 'Cl': 1,
168
+ 'As': 3,
169
+ 'Br': 1,
170
+ 'I': 1,
171
+ 'Hg': [1, 2],
172
+ 'Bi': [3, 5]
173
+ }
174
+
175
+ MARGINS_EDM = [10, 5, 2]
176
+
177
+ COLORS = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8']
178
+ # RADII = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
179
+ RADII = [0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77]
180
+
181
+ ZINC_TRAIN_LINKER_ID2SIZE = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
182
+ ZINC_TRAIN_LINKER_SIZE2ID = {
183
+ size: idx
184
+ for idx, size in enumerate(ZINC_TRAIN_LINKER_ID2SIZE)
185
+ }
186
+ ZINC_TRAIN_LINKER_SIZE_WEIGHTS = [
187
+ 3.47347831e-01,
188
+ 4.63079100e-01,
189
+ 5.12370917e-01,
190
+ 5.62392614e-01,
191
+ 1.30294388e+00,
192
+ 3.24247801e+00,
193
+ 8.12391184e+00,
194
+ 3.45634358e+01,
195
+ 2.72428571e+02,
196
+ 6.26585714e+03
197
+ ]
198
+
199
+
200
+ GEOM_TRAIN_LINKER_ID2SIZE = [
201
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
202
+ 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 36, 38, 41
203
+ ]
204
+ GEOM_TRAIN_LINKER_SIZE2ID = {
205
+ size: idx
206
+ for idx, size in enumerate(GEOM_TRAIN_LINKER_ID2SIZE)
207
+ }
208
+ GEOM_TRAIN_LINKER_SIZE_WEIGHTS = [
209
+ 1.07790681e+00, 4.54693604e-01, 3.62575713e-01, 3.75199484e-01,
210
+ 3.67812588e-01, 3.92388528e-01, 3.83421054e-01, 4.26924670e-01,
211
+ 4.92768040e-01, 4.99761944e-01, 4.92342726e-01, 5.71456905e-01,
212
+ 7.30631393e-01, 8.45412928e-01, 9.97252243e-01, 1.25423985e+00,
213
+ 1.57316129e+00, 2.19902962e+00, 3.22640431e+00, 4.25481066e+00,
214
+ 6.34749573e+00, 9.00676236e+00, 1.43084017e+01, 2.25763173e+01,
215
+ 3.36867096e+01, 9.50713805e+01, 2.08693274e+02, 2.51659537e+02,
216
+ 7.77856749e+02, 8.55642424e+03, 8.55642424e+03, 4.27821212e+03,
217
+ 4.27821212e+03
218
+ ]
src/datasets.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import pickle
5
+ import torch
6
+
7
+ from rdkit import Chem
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from tqdm import tqdm
10
+ from src import const
11
+
12
+
13
+ from pdb import set_trace
14
+
15
+
16
+ def read_sdf(sdf_path):
17
+ with Chem.SDMolSupplier(sdf_path, sanitize=False) as supplier:
18
+ for molecule in supplier:
19
+ yield molecule
20
+
21
+
22
+ def get_one_hot(atom, atoms_dict):
23
+ one_hot = np.zeros(len(atoms_dict))
24
+ one_hot[atoms_dict[atom]] = 1
25
+ return one_hot
26
+
27
+
28
+ def parse_molecule(mol, is_geom):
29
+ one_hot = []
30
+ charges = []
31
+ atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
32
+ charges_dict = const.GEOM_CHARGES if is_geom else const.CHARGES
33
+ for atom in mol.GetAtoms():
34
+ one_hot.append(get_one_hot(atom.GetSymbol(), atom2idx))
35
+ charges.append(charges_dict[atom.GetSymbol()])
36
+ positions = mol.GetConformer().GetPositions()
37
+ return positions, np.array(one_hot), np.array(charges)
38
+
39
+
40
+ class ZincDataset(Dataset):
41
+ def __init__(self, data_path, prefix, device):
42
+ dataset_path = os.path.join(data_path, f'{prefix}.pt')
43
+ if os.path.exists(dataset_path):
44
+ self.data = torch.load(dataset_path, map_location=device)
45
+ else:
46
+ print(f'Preprocessing dataset with prefix {prefix}')
47
+ self.data = ZincDataset.preprocess(data_path, prefix, device)
48
+ torch.save(self.data, dataset_path)
49
+
50
+ def __len__(self):
51
+ return len(self.data)
52
+
53
+ def __getitem__(self, item):
54
+ return self.data[item]
55
+
56
+ @staticmethod
57
+ def preprocess(data_path, prefix, device):
58
+ data = []
59
+ table_path = os.path.join(data_path, f'{prefix}_table.csv')
60
+ fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
61
+ linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')
62
+
63
+ is_geom = ('geom' in prefix) or ('MOAD' in prefix)
64
+ is_multifrag = 'multifrag' in prefix
65
+
66
+ table = pd.read_csv(table_path)
67
+ generator = tqdm(zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path)), total=len(table))
68
+ for (_, row), fragments, linker in generator:
69
+ uuid = row['uuid']
70
+ name = row['molecule']
71
+ frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
72
+ link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)
73
+
74
+ positions = np.concatenate([frag_pos, link_pos], axis=0)
75
+ one_hot = np.concatenate([frag_one_hot, link_one_hot], axis=0)
76
+ charges = np.concatenate([frag_charges, link_charges], axis=0)
77
+ anchors = np.zeros_like(charges)
78
+
79
+ if is_multifrag:
80
+ for anchor_idx in map(int, row['anchors'].split('-')):
81
+ anchors[anchor_idx] = 1
82
+ else:
83
+ anchors[row['anchor_1']] = 1
84
+ anchors[row['anchor_2']] = 1
85
+ fragment_mask = np.concatenate([np.ones_like(frag_charges), np.zeros_like(link_charges)])
86
+ linker_mask = np.concatenate([np.zeros_like(frag_charges), np.ones_like(link_charges)])
87
+
88
+ data.append({
89
+ 'uuid': uuid,
90
+ 'name': name,
91
+ 'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
92
+ 'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
93
+ 'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
94
+ 'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
95
+ 'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
96
+ 'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
97
+ 'num_atoms': len(positions),
98
+ })
99
+
100
+ return data
101
+
102
+
103
+ class MOADDataset(Dataset):
104
+ def __init__(self, data_path, prefix, device):
105
+ prefix, pocket_mode = prefix.split('.')
106
+
107
+ dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
108
+ if os.path.exists(dataset_path):
109
+ self.data = torch.load(dataset_path, map_location=device)
110
+ else:
111
+ print(f'Preprocessing dataset with prefix {prefix}')
112
+ self.data = MOADDataset.preprocess(data_path, prefix, pocket_mode, device)
113
+ torch.save(self.data, dataset_path)
114
+
115
+ def __len__(self):
116
+ return len(self.data)
117
+
118
+ def __getitem__(self, item):
119
+ return self.data[item]
120
+
121
+ @staticmethod
122
+ def preprocess(data_path, prefix, pocket_mode, device):
123
+ data = []
124
+ table_path = os.path.join(data_path, f'{prefix}_table.csv')
125
+ fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
126
+ linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')
127
+ pockets_path = os.path.join(data_path, f'{prefix}_pockets.pkl')
128
+
129
+ is_geom = True
130
+ is_multifrag = 'multifrag' in prefix
131
+
132
+ with open(pockets_path, 'rb') as f:
133
+ pockets = pickle.load(f)
134
+
135
+ table = pd.read_csv(table_path)
136
+ generator = tqdm(
137
+ zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path), pockets),
138
+ total=len(table)
139
+ )
140
+ for (_, row), fragments, linker, pocket_data in generator:
141
+ uuid = row['uuid']
142
+ name = row['molecule']
143
+ frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
144
+ link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)
145
+
146
+ # Parsing pocket data
147
+ pocket_pos = pocket_data[f'{pocket_mode}_coord']
148
+ pocket_one_hot = []
149
+ pocket_charges = []
150
+ for atom_type in pocket_data[f'{pocket_mode}_types']:
151
+ pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
152
+ pocket_charges.append(const.GEOM_CHARGES[atom_type])
153
+ pocket_one_hot = np.array(pocket_one_hot)
154
+ pocket_charges = np.array(pocket_charges)
155
+
156
+ positions = np.concatenate([frag_pos, pocket_pos, link_pos], axis=0)
157
+ one_hot = np.concatenate([frag_one_hot, pocket_one_hot, link_one_hot], axis=0)
158
+ charges = np.concatenate([frag_charges, pocket_charges, link_charges], axis=0)
159
+ anchors = np.zeros_like(charges)
160
+
161
+ if is_multifrag:
162
+ for anchor_idx in map(int, row['anchors'].split('-')):
163
+ anchors[anchor_idx] = 1
164
+ else:
165
+ anchors[row['anchor_1']] = 1
166
+ anchors[row['anchor_2']] = 1
167
+
168
+ fragment_only_mask = np.concatenate([
169
+ np.ones_like(frag_charges),
170
+ np.zeros_like(pocket_charges),
171
+ np.zeros_like(link_charges)
172
+ ])
173
+ pocket_mask = np.concatenate([
174
+ np.zeros_like(frag_charges),
175
+ np.ones_like(pocket_charges),
176
+ np.zeros_like(link_charges)
177
+ ])
178
+ linker_mask = np.concatenate([
179
+ np.zeros_like(frag_charges),
180
+ np.zeros_like(pocket_charges),
181
+ np.ones_like(link_charges)
182
+ ])
183
+ fragment_mask = np.concatenate([
184
+ np.ones_like(frag_charges),
185
+ np.ones_like(pocket_charges),
186
+ np.zeros_like(link_charges)
187
+ ])
188
+
189
+ data.append({
190
+ 'uuid': uuid,
191
+ 'name': name,
192
+ 'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
193
+ 'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
194
+ 'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
195
+ 'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
196
+ 'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
197
+ 'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
198
+ 'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
199
+ 'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
200
+ 'num_atoms': len(positions),
201
+ })
202
+
203
+ return data
204
+
205
+ @staticmethod
206
+ def create_edges(positions, fragment_mask_only, linker_mask_only):
207
+ ligand_mask = fragment_mask_only.astype(bool) | linker_mask_only.astype(bool)
208
+ ligand_adj = ligand_mask[:, None] & ligand_mask[None, :]
209
+ proximity_adj = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1) <= 6
210
+ full_adj = ligand_adj | proximity_adj
211
+ full_adj &= ~np.eye(len(positions)).astype(bool)
212
+
213
+ curr_rows, curr_cols = np.where(full_adj)
214
+ return [curr_rows, curr_cols]
215
+
216
+
217
+ def collate(batch):
218
+ out = {}
219
+
220
+ # Filter out big molecules
221
+ if 'pocket_mask' not in batch[0].keys():
222
+ batch = [data for data in batch if data['num_atoms'] <= 50]
223
+ else:
224
+ batch = [data for data in batch if data['num_atoms'] <= 1000]
225
+
226
+ for i, data in enumerate(batch):
227
+ for key, value in data.items():
228
+ out.setdefault(key, []).append(value)
229
+
230
+ for key, value in out.items():
231
+ if key in const.DATA_LIST_ATTRS:
232
+ continue
233
+ if key in const.DATA_ATTRS_TO_PAD:
234
+ out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
235
+ continue
236
+ raise Exception(f'Unknown batch key: {key}')
237
+
238
+ atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
239
+ out['atom_mask'] = atom_mask[:, :, None]
240
+
241
+ batch_size, n_nodes = atom_mask.size()
242
+
243
+ # In case of MOAD edge_mask is batch_idx
244
+ if 'pocket_mask' in batch[0].keys():
245
+ batch_mask = torch.cat([
246
+ torch.ones(n_nodes, dtype=const.TORCH_INT) * i
247
+ for i in range(batch_size)
248
+ ]).to(atom_mask.device)
249
+ out['edge_mask'] = batch_mask
250
+ else:
251
+ edge_mask = atom_mask[:, None, :] * atom_mask[:, :, None]
252
+ diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=atom_mask.device).unsqueeze(0)
253
+ edge_mask *= diag_mask
254
+ out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)
255
+
256
+ for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
257
+ if key in out.keys():
258
+ out[key] = out[key][:, :, None]
259
+
260
+ return out
261
+
262
+
263
+ def collate_with_fragment_edges(batch):
264
+ out = {}
265
+
266
+ # Filter out big molecules
267
+ batch = [data for data in batch if data['num_atoms'] <= 50]
268
+
269
+ for i, data in enumerate(batch):
270
+ for key, value in data.items():
271
+ out.setdefault(key, []).append(value)
272
+
273
+ for key, value in out.items():
274
+ if key in const.DATA_LIST_ATTRS:
275
+ continue
276
+ if key in const.DATA_ATTRS_TO_PAD:
277
+ out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
278
+ continue
279
+ raise Exception(f'Unknown batch key: {key}')
280
+
281
+ frag_mask = out['fragment_mask']
282
+ edge_mask = frag_mask[:, None, :] * frag_mask[:, :, None]
283
+ diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=frag_mask.device).unsqueeze(0)
284
+ edge_mask *= diag_mask
285
+
286
+ batch_size, n_nodes = frag_mask.size()
287
+ out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)
288
+
289
+ # Building edges and covalent bond values
290
+ rows, cols, bonds = [], [], []
291
+ for batch_idx in range(batch_size):
292
+ for i in range(n_nodes):
293
+ for j in range(n_nodes):
294
+ rows.append(i + batch_idx * n_nodes)
295
+ cols.append(j + batch_idx * n_nodes)
296
+
297
+ edges = [torch.LongTensor(rows).to(frag_mask.device), torch.LongTensor(cols).to(frag_mask.device)]
298
+ out['edges'] = edges
299
+
300
+ atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
301
+ out['atom_mask'] = atom_mask[:, :, None]
302
+
303
+ for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
304
+ if key in out.keys():
305
+ out[key] = out[key][:, :, None]
306
+
307
+ return out
308
+
309
+
310
+ def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
311
+ return DataLoader(dataset, batch_size, collate_fn=collate_fn, shuffle=shuffle)
312
+
313
+
314
+ def create_template(tensor, fragment_size, linker_size, fill=0):
315
+ values_to_keep = tensor[:fragment_size]
316
+ values_to_add = torch.ones(linker_size, tensor.shape[1], dtype=values_to_keep.dtype, device=values_to_keep.device)
317
+ values_to_add = values_to_add * fill
318
+ return torch.cat([values_to_keep, values_to_add], dim=0)
319
+
320
+
321
+ def create_templates_for_linker_generation(data, linker_sizes):
322
+ """
323
+ Takes data batch and new linker size and returns data batch where fragment-related data is the same
324
+ but linker-related data is replaced with zero templates with new linker sizes
325
+ """
326
+ decoupled_data = []
327
+ for i, linker_size in enumerate(linker_sizes):
328
+ data_dict = {}
329
+ fragment_mask = data['fragment_mask'][i].squeeze()
330
+ fragment_size = fragment_mask.sum().int()
331
+ for k, v in data.items():
332
+ if k == 'num_atoms':
333
+ # Computing new number of atoms (fragment_size + linker_size)
334
+ data_dict[k] = fragment_size + linker_size
335
+ continue
336
+ if k in const.DATA_LIST_ATTRS:
337
+ # These attributes are written without modification
338
+ data_dict[k] = v[i]
339
+ continue
340
+ if k in const.DATA_ATTRS_TO_PAD:
341
+ # Should write fragment-related data + (zeros x linker_size)
342
+ fill_value = 1 if k == 'linker_mask' else 0
343
+ template = create_template(v[i], fragment_size, linker_size, fill=fill_value)
344
+ if k in const.DATA_ATTRS_TO_ADD_LAST_DIM:
345
+ template = template.squeeze(-1)
346
+ data_dict[k] = template
347
+
348
+ decoupled_data.append(data_dict)
349
+
350
+ return collate(decoupled_data)
src/delinker.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import numpy as np
3
+
4
+ from rdkit import Chem
5
+ from rdkit.Chem import MolStandardize
6
+ from src import metrics
7
+ from src.delinker_utils import sascorer, calc_SC_RDKit
8
+ from tqdm import tqdm
9
+
10
+ from pdb import set_trace
11
+
12
+
13
+ def get_valid_as_in_delinker(data, progress=False):
14
+ valid = []
15
+ generator = tqdm(enumerate(data), total=len(data)) if progress else enumerate(data)
16
+ for i, m in generator:
17
+ pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=False)
18
+ true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=False)
19
+ frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=False)
20
+
21
+ pred_mol_frags = Chem.GetMolFrags(pred_mol, asMols=True, sanitizeFrags=False)
22
+ pred_mol_filtered = max(pred_mol_frags, default=pred_mol, key=lambda mol: mol.GetNumAtoms())
23
+
24
+ try:
25
+ Chem.SanitizeMol(pred_mol_filtered)
26
+ Chem.SanitizeMol(true_mol)
27
+ Chem.SanitizeMol(frag)
28
+ except:
29
+ continue
30
+
31
+ if len(pred_mol_filtered.GetSubstructMatch(frag)) > 0:
32
+ valid.append({
33
+ 'pred_mol': m['pred_mol'],
34
+ 'true_mol': m['true_mol'],
35
+ 'pred_mol_smi': Chem.MolToSmiles(pred_mol_filtered),
36
+ 'true_mol_smi': Chem.MolToSmiles(true_mol),
37
+ 'frag_smi': Chem.MolToSmiles(frag)
38
+ })
39
+
40
+ return valid
41
+
42
+
43
+ def extract_linker_smiles(molecule, fragments):
44
+ match = molecule.GetSubstructMatch(fragments)
45
+ elinker = Chem.EditableMol(molecule)
46
+ for atom_id in sorted(match, reverse=True):
47
+ elinker.RemoveAtom(atom_id)
48
+ linker = elinker.GetMol()
49
+ Chem.RemoveStereochemistry(linker)
50
+ try:
51
+ linker = MolStandardize.canonicalize_tautomer_smiles(Chem.MolToSmiles(linker))
52
+ except:
53
+ linker = Chem.MolToSmiles(linker)
54
+ return linker
55
+
56
+
57
+ def compute_and_add_linker_smiles(data, progress=False):
58
+ data_with_linkers = []
59
+ generator = tqdm(data) if progress else data
60
+ for m in generator:
61
+ pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
62
+ true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
63
+ frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=True)
64
+
65
+ pred_linker = extract_linker_smiles(pred_mol, frag)
66
+ true_linker = extract_linker_smiles(true_mol, frag)
67
+ data_with_linkers.append({
68
+ **m,
69
+ 'pred_linker': pred_linker,
70
+ 'true_linker': true_linker,
71
+ })
72
+
73
+ return data_with_linkers
74
+
75
+
76
+ def compute_uniqueness(data, progress=False):
77
+ mol_dictionary = {}
78
+ generator = tqdm(data) if progress else data
79
+ for m in generator:
80
+ frag = m['frag_smi']
81
+ pred_mol = m['pred_mol_smi']
82
+ true_mol = m['true_mol_smi']
83
+
84
+ key = f'{true_mol}.{frag}'
85
+ mol_dictionary.setdefault(key, []).append(pred_mol)
86
+
87
+ total_mol = 0
88
+ unique_mol = 0
89
+ for molecules in mol_dictionary.values():
90
+ total_mol += len(molecules)
91
+ unique_mol += len(set(molecules))
92
+
93
+ return unique_mol / total_mol
94
+
95
+
96
+ def compute_novelty(data, progress=False):
97
+ novel = 0
98
+ true_linkers = set([m['true_linker'] for m in data])
99
+ generator = tqdm(data) if progress else data
100
+ for m in generator:
101
+ pred_linker = m['pred_linker']
102
+ if pred_linker in true_linkers:
103
+ continue
104
+ else:
105
+ novel += 1
106
+
107
+ return novel / len(data)
108
+
109
+
110
+ def compute_recovery_rate(data, progress=False):
111
+ total = set()
112
+ recovered = set()
113
+ generator = tqdm(data) if progress else data
114
+ for m in generator:
115
+ pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
116
+ Chem.RemoveStereochemistry(pred_mol)
117
+ pred_mol = Chem.MolToSmiles(Chem.RemoveHs(pred_mol))
118
+
119
+ true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
120
+ Chem.RemoveStereochemistry(true_mol)
121
+ true_mol = Chem.MolToSmiles(Chem.RemoveHs(true_mol))
122
+
123
+ true_link = m['true_linker']
124
+ total.add(f'{true_mol}.{true_link}')
125
+ if pred_mol == true_mol:
126
+ recovered.add(f'{true_mol}.{true_link}')
127
+
128
+ return len(recovered) / len(total)
129
+
130
+
131
+ def calc_sa_score_mol(mol):
132
+ if mol is None:
133
+ return None
134
+ return sascorer.calculateScore(mol)
135
+
136
+
137
+ def check_ring_filter(linker):
138
+ check = True
139
+ # Get linker rings
140
+ ssr = Chem.GetSymmSSSR(linker)
141
+ # Check rings
142
+ for ring in ssr:
143
+ for atom_idx in ring:
144
+ for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
145
+ if bond.GetBondType() == 2 and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
146
+ check = False
147
+ return check
148
+
149
+
150
+ def check_pains(mol, pains_smarts):
151
+ for pain in pains_smarts:
152
+ if mol.HasSubstructMatch(pain):
153
+ return False
154
+ return True
155
+
156
+
157
+ def calc_2d_filters(toks, pains_smarts):
158
+ pred_mol = Chem.MolFromSmiles(toks['pred_mol_smi'])
159
+ frag = Chem.MolFromSmiles(toks['frag_smi'])
160
+ linker = Chem.MolFromSmiles(toks['pred_linker'])
161
+
162
+ result = [False, False, False]
163
+ if len(pred_mol.GetSubstructMatch(frag)) > 0:
164
+ sa_score = False
165
+ ra_score = False
166
+ pains_score = False
167
+
168
+ try:
169
+ sa_score = calc_sa_score_mol(pred_mol) < calc_sa_score_mol(frag)
170
+ except Exception as e:
171
+ print(f'Could not compute SA score: {e}')
172
+ try:
173
+ ra_score = check_ring_filter(linker)
174
+ except Exception as e:
175
+ print(f'Could not compute RA score: {e}')
176
+ try:
177
+ pains_score = check_pains(pred_mol, pains_smarts)
178
+ except Exception as e:
179
+ print(f'Could not compute PAINS score: {e}')
180
+
181
+ result = [sa_score, ra_score, pains_score]
182
+
183
+ return result
184
+
185
+
186
+ def calc_filters_2d_dataset(data):
187
+ with open('models/wehi_pains.csv', 'r') as f:
188
+ pains_smarts = [Chem.MolFromSmarts(line[0], mergeHs=True) for line in csv.reader(f)]
189
+
190
+ pass_all = pass_SA = pass_RA = pass_PAINS = 0
191
+ for m in data:
192
+ filters_2d = calc_2d_filters(m, pains_smarts)
193
+ pass_all += filters_2d[0] & filters_2d[1] & filters_2d[2]
194
+ pass_SA += filters_2d[0]
195
+ pass_RA += filters_2d[1]
196
+ pass_PAINS += filters_2d[2]
197
+
198
+ return pass_all / len(data), pass_SA / len(data), pass_RA / len(data), pass_PAINS / len(data)
199
+
200
+
201
+ def calc_sc_rdkit_full_mol(gen_mol, ref_mol):
202
+ try:
203
+ score = calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)
204
+ return score
205
+ except:
206
+ return -0.5
207
+
208
+
209
+ def sc_rdkit_score(data):
210
+ scores = []
211
+ for m in data:
212
+ score = calc_sc_rdkit_full_mol(m['pred_mol'], m['true_mol'])
213
+ scores.append(score)
214
+
215
+ return np.mean(scores)
216
+
217
+
218
+ def get_delinker_metrics(pred_molecules, true_molecules, true_fragments):
219
+ default_values = {
220
+ 'DeLinker/validity': 0,
221
+ 'DeLinker/uniqueness': 0,
222
+ 'DeLinker/novelty': 0,
223
+ 'DeLinker/recovery': 0,
224
+ 'DeLinker/2D_filters': 0,
225
+ 'DeLinker/2D_filters_SA': 0,
226
+ 'DeLinker/2D_filters_RA': 0,
227
+ 'DeLinker/2D_filters_PAINS': 0,
228
+ 'DeLinker/SC_RDKit': 0,
229
+ }
230
+ if len(pred_molecules) == 0:
231
+ return default_values
232
+
233
+ data = []
234
+ for pred_mol, true_mol, true_frag in zip(pred_molecules, true_molecules, true_fragments):
235
+ data.append({
236
+ 'pred_mol': pred_mol,
237
+ 'true_mol': true_mol,
238
+ 'pred_mol_smi': Chem.MolToSmiles(pred_mol),
239
+ 'true_mol_smi': Chem.MolToSmiles(true_mol),
240
+ 'frag_smi': Chem.MolToSmiles(true_frag)
241
+ })
242
+
243
+ # Validity according to DeLinker paper:
244
+ # Passing rdkit.Chem.Sanitize and the biggest fragment contains both fragments
245
+ valid_data = get_valid_as_in_delinker(data)
246
+ validity_as_in_delinker = len(valid_data) / len(data)
247
+ if len(valid_data) == 0:
248
+ return default_values
249
+
250
+ # Compute linkers and add to results
251
+ valid_data = compute_and_add_linker_smiles(valid_data)
252
+
253
+ # Compute uniqueness
254
+ uniqueness = compute_uniqueness(valid_data)
255
+
256
+ # Compute novelty
257
+ novelty = compute_novelty(valid_data)
258
+
259
+ # Compute recovered molecules
260
+ recovery_rate = compute_recovery_rate(valid_data)
261
+
262
+ # 2D filters
263
+ pass_all, pass_SA, pass_RA, pass_PAINS = calc_filters_2d_dataset(valid_data)
264
+
265
+ # 3D Filters
266
+ sc_rdkit = sc_rdkit_score(valid_data)
267
+
268
+ return {
269
+ 'DeLinker/validity': validity_as_in_delinker,
270
+ 'DeLinker/uniqueness': uniqueness,
271
+ 'DeLinker/novelty': novelty,
272
+ 'DeLinker/recovery': recovery_rate,
273
+ 'DeLinker/2D_filters': pass_all,
274
+ 'DeLinker/2D_filters_SA': pass_SA,
275
+ 'DeLinker/2D_filters_RA': pass_RA,
276
+ 'DeLinker/2D_filters_PAINS': pass_PAINS,
277
+ 'DeLinker/SC_RDKit': sc_rdkit,
278
+ }
src/delinker_utils/__init__.py ADDED
File without changes
src/delinker_utils/calc_SC_RDKit.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from rdkit import Chem
3
+ from rdkit.Chem import AllChem, rdShapeHelpers
4
+ from rdkit.Chem.FeatMaps import FeatMaps
5
+ from rdkit import RDConfig
6
+
7
+ # Set up features to use in FeatureMap
8
+ fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
9
+ fdef = AllChem.BuildFeatureFactory(fdefName)
10
+
11
+ fmParams = {}
12
+ for k in fdef.GetFeatureFamilies():
13
+ fparams = FeatMaps.FeatMapParams()
14
+ fmParams[k] = fparams
15
+
16
+ keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable',
17
+ 'ZnBinder', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe')
18
+
19
+
20
+ def get_FeatureMapScore(query_mol, ref_mol):
21
+ featLists = []
22
+ for m in [query_mol, ref_mol]:
23
+ rawFeats = fdef.GetFeaturesForMol(m)
24
+ # filter that list down to only include the ones we're intereted in
25
+ featLists.append([f for f in rawFeats if f.GetFamily() in keep])
26
+ fms = [FeatMaps.FeatMap(feats=x, weights=[1] * len(x), params=fmParams) for x in featLists]
27
+ fms[0].scoreMode = FeatMaps.FeatMapScoreMode.Best
28
+ fm_score = fms[0].ScoreFeats(featLists[1]) / min(fms[0].GetNumFeatures(), len(featLists[1]))
29
+
30
+ return fm_score
31
+
32
+
33
+ def calc_SC_RDKit_score(query_mol, ref_mol):
34
+ fm_score = get_FeatureMapScore(query_mol, ref_mol)
35
+
36
+ protrude_dist = rdShapeHelpers.ShapeProtrudeDist(query_mol, ref_mol,
37
+ allowReordering=False)
38
+ SC_RDKit_score = 0.5 * fm_score + 0.5 * (1 - protrude_dist)
39
+
40
+ return SC_RDKit_score
src/delinker_utils/frag_utils.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import networkx as nx
3
+
4
+ from joblib import Parallel, delayed
5
+ from rdkit import Chem
6
+ from rdkit.Chem import AllChem
7
+ from src.delinker_utils import sascorer
8
+
9
+
10
+ def read_triples_file(filename):
11
+ '''Reads .smi file '''
12
+ '''Returns array containing smiles strings of molecules'''
13
+ smiles, names = [], []
14
+ with open(filename, 'r') as f:
15
+ for line in f:
16
+ if line:
17
+ smiles.append(line.strip().split(' ')[0:3])
18
+ return smiles
19
+
20
+
21
+ def remove_dummys(smi_string):
22
+ return Chem.MolToSmiles(Chem.RemoveHs(AllChem.ReplaceSubstructs(Chem.MolFromSmiles(smi_string),Chem.MolFromSmiles('*'),Chem.MolFromSmiles('[H]'),True)[0]))
23
+
24
+
25
+ def sa_filter(results, verbose=True):
26
+ count = 0
27
+ total = 0
28
+ for processed, res in enumerate(results):
29
+ total += len(res)
30
+ for m in res:
31
+ # Check SA score has improved
32
+ if calc_mol_props(m[1])[1] < calc_mol_props(m[0])[1]:
33
+ count += 1
34
+ # Progress
35
+ if verbose:
36
+ if processed % 10 == 0:
37
+ print("\rProcessed %d" % processed, end="")
38
+ print("\r",end="")
39
+ return count/total
40
+
41
+
42
+ def ring_check_res(res, clean_frag):
43
+ check = True
44
+ gen_mol = Chem.MolFromSmiles(res[1])
45
+ linker = Chem.DeleteSubstructs(gen_mol, clean_frag)
46
+
47
+ # Get linker rings
48
+ ssr = Chem.GetSymmSSSR(linker)
49
+ # Check rings
50
+ for ring in ssr:
51
+ for atom_idx in ring:
52
+ for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
53
+ if bond.GetBondType() == 2 and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
54
+ check = False
55
+ return check
56
+
57
+
58
+ def ring_filter(results, verbose=True):
59
+ count = 0
60
+ total = 0
61
+ du = Chem.MolFromSmiles('*')
62
+ for processed, res in enumerate(results):
63
+ total += len(res)
64
+ for m in res:
65
+ # Clean frags
66
+ clean_frag = Chem.RemoveHs(AllChem.ReplaceSubstructs(Chem.MolFromSmiles(m[0]),du,Chem.MolFromSmiles('[H]'),True)[0])
67
+ if ring_check_res(m, clean_frag):
68
+ count += 1
69
+ # Progress
70
+ if verbose:
71
+ if processed % 10 == 0:
72
+ print("\rProcessed %d" % processed, end="")
73
+ print("\r",end="")
74
+ return count/total
75
+
76
+
77
+ def check_ring_filter(linker):
78
+ check = True
79
+ # Get linker rings
80
+ ssr = Chem.GetSymmSSSR(linker)
81
+ # Check rings
82
+ for ring in ssr:
83
+ for atom_idx in ring:
84
+ for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
85
+ if bond.GetBondType() == 2 and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
86
+ check = False
87
+ return check
88
+
89
+
90
+ def check_pains(mol, pains_smarts):
91
+ for pain in pains_smarts:
92
+ if mol.HasSubstructMatch(pain):
93
+ return False
94
+ return True
95
+
96
+
97
+ def calc_2d_filters(toks, pains_smarts):
98
+ try:
99
+ # Input format: (Full Molecule (SMILES), Linker (SMILES), Unlinked Fragments (SMILES))
100
+ frags = Chem.MolFromSmiles(toks[2])
101
+ linker = Chem.MolFromSmiles(toks[1])
102
+ full_mol = Chem.MolFromSmiles(toks[0])
103
+ # Remove dummy atoms from unlinked fragments
104
+ du = Chem.MolFromSmiles('*')
105
+ clean_frag = Chem.RemoveHs(AllChem.ReplaceSubstructs(frags, du, Chem.MolFromSmiles('[H]'), True)[0])
106
+
107
+ res = []
108
+ # Check: Unlinked fragments in full molecule
109
+ if len(full_mol.GetSubstructMatch(clean_frag)) > 0:
110
+ # Check: SA score improved from unlinked fragments to full molecule
111
+ if calc_sa_score_mol(full_mol) < calc_sa_score_mol(frags):
112
+ res.append(True)
113
+ else:
114
+ res.append(False)
115
+ # Check: No non-aromatic rings with double bonds
116
+ if check_ring_filter(linker):
117
+ res.append(True)
118
+ else:
119
+ res.append(False)
120
+ # Check: Pass pains filters
121
+ if check_pains(full_mol, pains_smarts):
122
+ res.append(True)
123
+ else:
124
+ res.append(False)
125
+ return res
126
+ except:
127
+ return [False, False, False]
128
+
129
+
130
+ def calc_filters_2d_dataset(results, pains_smarts_loc, n_cores=1):
131
+ # Load pains filters
132
+ with open(pains_smarts_loc, 'r') as f:
133
+ pains_smarts = [Chem.MolFromSmarts(line[0], mergeHs=True) for line in csv.reader(f)]
134
+ # calc_2d_filters([results[0][2], results[0][4], results[0][1]], pains_smarts)
135
+ with Parallel(n_jobs=n_cores, backend='multiprocessing') as parallel:
136
+ filters_2d = parallel(delayed(calc_2d_filters)([toks[2], toks[4], toks[1]], pains_smarts) for toks in results)
137
+
138
+ return filters_2d
139
+
140
+
141
+ def calc_mol_props(smiles):
142
+ # Create RDKit mol
143
+ mol = Chem.MolFromSmiles(smiles)
144
+ if mol is None:
145
+ print("Error passing: %s" % smiles)
146
+ return None
147
+
148
+ # QED
149
+ qed = Chem.QED.qed(mol)
150
+ # Synthetic accessibility score - number of cycles (rings with > 6 atoms)
151
+ sas = sascorer.calculateScore(mol)
152
+ # Cyles with >6 atoms
153
+ ri = mol.GetRingInfo()
154
+ nMacrocycles = 0
155
+ for x in ri.AtomRings():
156
+ if len(x) > 6:
157
+ nMacrocycles += 1
158
+
159
+ prop_array = [qed, sas]
160
+
161
+ return prop_array
162
+
163
+
164
+ def calc_sa_score_mol(mol, verbose=False):
165
+ if mol is None:
166
+ if verbose:
167
+ print("Error passing: %s" % mol)
168
+ return None
169
+ # Synthetic accessibility score
170
+ return sascorer.calculateScore(mol)
171
+
172
+
173
+ def get_linker(full_mol, clean_frag, starting_point):
174
+ # INPUT FORMAT: molecule (RDKit mol object), clean fragments (RDKit mol object), starting fragments (SMILES)
175
+
176
+ # Get matches of fragments
177
+ matches = list(full_mol.GetSubstructMatches(clean_frag))
178
+
179
+ # If no matches, terminate
180
+ if len(matches) == 0:
181
+ print("No matches")
182
+ return ""
183
+
184
+ # Get number of atoms in linker
185
+ linker_len = full_mol.GetNumHeavyAtoms() - clean_frag.GetNumHeavyAtoms()
186
+ if linker_len == 0:
187
+ return ""
188
+
189
+ # Setup
190
+ mol_to_break = Chem.Mol(full_mol)
191
+ Chem.Kekulize(full_mol, clearAromaticFlags=True)
192
+
193
+ poss_linker = []
194
+
195
+ if len(matches) > 0:
196
+ # Loop over matches
197
+ for match in matches:
198
+ mol_rw = Chem.RWMol(full_mol)
199
+ # Get linker atoms
200
+ linker_atoms = list(set(list(range(full_mol.GetNumHeavyAtoms()))).difference(match))
201
+ linker_bonds = []
202
+ atoms_joined_to_linker = []
203
+ # Loop over starting fragments atoms
204
+ # Get (i) bonds between starting fragments and linker, (ii) atoms joined to linker
205
+ for idx_to_delete in sorted(match, reverse=True):
206
+ nei = [x.GetIdx() for x in mol_rw.GetAtomWithIdx(idx_to_delete).GetNeighbors()]
207
+ intersect = set(nei).intersection(set(linker_atoms))
208
+ if len(intersect) == 1:
209
+ linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, list(intersect)[0]).GetIdx())
210
+ atoms_joined_to_linker.append(idx_to_delete)
211
+ elif len(intersect) > 1:
212
+ for idx_nei in list(intersect):
213
+ linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, idx_nei).GetIdx())
214
+ atoms_joined_to_linker.append(idx_to_delete)
215
+
216
+ # Check number of atoms joined to linker
217
+ # If not == 2, check next match
218
+ if len(set(atoms_joined_to_linker)) != 2:
219
+ continue
220
+
221
+ # Delete starting fragments atoms
222
+ for idx_to_delete in sorted(match, reverse=True):
223
+ mol_rw.RemoveAtom(idx_to_delete)
224
+
225
+ linker = Chem.Mol(mol_rw)
226
+ # Check linker required num atoms
227
+ if linker.GetNumHeavyAtoms() == linker_len:
228
+ mol_rw = Chem.RWMol(full_mol)
229
+ # Delete linker atoms
230
+ for idx_to_delete in sorted(linker_atoms, reverse=True):
231
+ mol_rw.RemoveAtom(idx_to_delete)
232
+ frags = Chem.Mol(mol_rw)
233
+ # Check there are two disconnected fragments
234
+ if len(Chem.rdmolops.GetMolFrags(frags)) == 2:
235
+ # Fragment molecule into starting fragments and linker
236
+ fragmented_mol = Chem.FragmentOnBonds(mol_to_break, linker_bonds)
237
+ # Remove starting fragments from fragmentation
238
+ linker_to_return = Chem.Mol(fragmented_mol)
239
+ qp = Chem.AdjustQueryParameters()
240
+ qp.makeDummiesQueries = True
241
+ for f in starting_point.split('.'):
242
+ qfrag = Chem.AdjustQueryProperties(Chem.MolFromSmiles(f), qp)
243
+ linker_to_return = AllChem.DeleteSubstructs(linker_to_return, qfrag, onlyFrags=True)
244
+
245
+ # Check linker is connected and two bonds to outside molecule
246
+ if len(Chem.rdmolops.GetMolFrags(linker)) == 1 and len(linker_bonds) == 2:
247
+ Chem.Kekulize(linker_to_return, clearAromaticFlags=True)
248
+ # If for some reason a starting fragment isn't removed (and it's larger than the linker), remove (happens v. occassionally)
249
+ if len(Chem.rdmolops.GetMolFrags(linker_to_return)) > 1:
250
+ for frag in Chem.MolToSmiles(linker_to_return).split('.'):
251
+ if Chem.MolFromSmiles(frag).GetNumHeavyAtoms() == linker_len:
252
+ return frag
253
+ return Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(linker_to_return)))
254
+
255
+ # If not, add to possible linkers (above doesn't capture some complex cases)
256
+ else:
257
+ fragmented_mol = Chem.MolFromSmiles(Chem.MolToSmiles(fragmented_mol), sanitize=False)
258
+ linker_to_return = AllChem.DeleteSubstructs(fragmented_mol, Chem.MolFromSmiles(starting_point))
259
+ poss_linker.append(Chem.MolToSmiles(linker_to_return))
260
+
261
+ # If only one possibility, return linker
262
+ if len(poss_linker) == 1:
263
+ return poss_linker[0]
264
+ # If no possibilities, process failed
265
+ elif len(poss_linker) == 0:
266
+ print("FAIL:", Chem.MolToSmiles(full_mol), Chem.MolToSmiles(clean_frag), starting_point)
267
+ return ""
268
+ # If multiple possibilities, process probably failed
269
+ else:
270
+ print("More than one poss linker. ", poss_linker)
271
+ return poss_linker[0]
272
+
273
+
274
+ def get_linker_v2(full_mol, clean_frag):
275
+ # INPUT FORMAT: molecule (RDKit mol object), clean fragments (RDKit mol object), starting fragments (SMILES)
276
+
277
+ # Get matches of fragments
278
+ matches = list(full_mol.GetSubstructMatches(clean_frag))
279
+
280
+ # If no matches, terminate
281
+ if len(matches) == 0:
282
+ print("No matches")
283
+ return ""
284
+
285
+ # Get number of atoms in linker
286
+ linker_len = full_mol.GetNumHeavyAtoms() - clean_frag.GetNumHeavyAtoms()
287
+ if linker_len == 0:
288
+ return ""
289
+
290
+ # Setup
291
+ mol_to_break = Chem.Mol(full_mol)
292
+ Chem.Kekulize(full_mol, clearAromaticFlags=True)
293
+
294
+ poss_linker = []
295
+
296
+ if len(matches) > 0:
297
+ # Loop over matches
298
+ for match in matches:
299
+ mol_rw = Chem.RWMol(full_mol)
300
+ # Get linker atoms
301
+ linker_atoms = list(set(list(range(full_mol.GetNumHeavyAtoms()))).difference(match))
302
+ linker_bonds = []
303
+ atoms_joined_to_linker = []
304
+ # Loop over starting fragments atoms
305
+ # Get (i) bonds between starting fragments and linker, (ii) atoms joined to linker
306
+ for idx_to_delete in sorted(match, reverse=True):
307
+ nei = [x.GetIdx() for x in mol_rw.GetAtomWithIdx(idx_to_delete).GetNeighbors()]
308
+ intersect = set(nei).intersection(set(linker_atoms))
309
+ if len(intersect) == 1:
310
+ linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, list(intersect)[0]).GetIdx())
311
+ atoms_joined_to_linker.append(idx_to_delete)
312
+ elif len(intersect) > 1:
313
+ for idx_nei in list(intersect):
314
+ linker_bonds.append(mol_rw.GetBondBetweenAtoms(idx_to_delete, idx_nei).GetIdx())
315
+ atoms_joined_to_linker.append(idx_to_delete)
316
+
317
+ # Check number of atoms joined to linker
318
+ # If not == 2, check next match
319
+ if len(set(atoms_joined_to_linker)) != 2:
320
+ continue
321
+
322
+ # Delete starting fragments atoms
323
+ for idx_to_delete in sorted(match, reverse=True):
324
+ mol_rw.RemoveAtom(idx_to_delete)
325
+
326
+ linker = Chem.Mol(mol_rw)
327
+ # Check linker required num atoms
328
+ if linker.GetNumHeavyAtoms() == linker_len:
329
+ mol_rw = Chem.RWMol(full_mol)
330
+ # Delete linker atoms
331
+ for idx_to_delete in sorted(linker_atoms, reverse=True):
332
+ mol_rw.RemoveAtom(idx_to_delete)
333
+ frags = Chem.Mol(mol_rw)
334
+
335
+ # Check linker is connected and two bonds to outside molecule
336
+ if len(Chem.rdmolops.GetMolFrags(linker)) == 1 and len(linker_bonds) == 2:
337
+ Chem.Kekulize(linker, clearAromaticFlags=True)
338
+ # If for some reason a starting fragment isn't removed (and it's larger than the linker), remove (happens v. occassionally)
339
+ if len(Chem.rdmolops.GetMolFrags(linker)) > 1:
340
+ for frag in Chem.MolToSmiles(linker).split('.'):
341
+ if Chem.MolFromSmiles(frag).GetNumHeavyAtoms() == linker_len:
342
+ return frag
343
+ return Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(linker)))
344
+
345
+ # If not, add to possible linkers (above doesn't capture some complex cases)
346
+ else:
347
+ poss_linker.append(Chem.MolToSmiles(linker))
348
+
349
+ # If only one possibility, return linker
350
+ if len(poss_linker) == 1:
351
+ return poss_linker[0]
352
+ # If no possibilities, process failed
353
+ elif len(poss_linker) == 0:
354
+ print("FAIL:", Chem.MolToSmiles(full_mol), Chem.MolToSmiles(clean_frag))
355
+ return ""
356
+ # If multiple possibilities, process probably failed
357
+ else:
358
+ print("More than one poss linker. ", poss_linker)
359
+ return poss_linker[0]
360
+
361
+
362
+ def unique(results):
363
+ total_dupes = 0
364
+ total = 0
365
+ for res in results:
366
+ original_num = len(res)
367
+ test_data = set(res)
368
+ new_num = len(test_data)
369
+ total_dupes += original_num - new_num
370
+ total += original_num
371
+ return 1 - total_dupes/float(total)
372
+
373
+
374
+ def check_recovered_original_mol_with_idx(results):
375
+ outcomes = []
376
+ rec_idx = []
377
+ for res in results:
378
+ success = False
379
+ # Load original mol and canonicalise
380
+ orig_mol = Chem.MolFromSmiles(res[0][0][0])
381
+ Chem.RemoveStereochemistry(orig_mol)
382
+ orig_mol = Chem.MolToSmiles(Chem.RemoveHs(orig_mol))
383
+ #orig_mol = MolStandardize.canonicalize_tautomer_smiles(orig_mol)
384
+ # Check generated mols
385
+ for m in res:
386
+ # print(1)
387
+ gen_mol = Chem.MolFromSmiles(m[0][2])
388
+ Chem.RemoveStereochemistry(gen_mol)
389
+ gen_mol = Chem.MolToSmiles(Chem.RemoveHs(gen_mol))
390
+ #gen_mol = MolStandardize.canonicalize_tautomer_smiles(gen_mol)
391
+ if gen_mol == orig_mol:
392
+ # outcomes.append(True)
393
+ success = True
394
+ rec_idx.append(m[1])
395
+ # break
396
+ if not success:
397
+ outcomes.append(False)
398
+ else:
399
+ outcomes.append(True)
400
+ return outcomes, rec_idx
401
+
402
+
403
+ def topology_from_rdkit(rdkit_molecule):
404
+ topology = nx.Graph()
405
+ for atom in rdkit_molecule.GetAtoms():
406
+ # Add the atoms as nodes
407
+ topology.add_node(atom.GetIdx(), atom_type=atom.GetAtomicNum())
408
+
409
+ # Add the bonds as edges
410
+ for bond in rdkit_molecule.GetBonds():
411
+ topology.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond_type=bond.GetBondType())
412
+
413
+ return topology
src/delinker_utils/sascorer.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # calculation of synthetic accessibility score as described in:
3
+ #
4
+ # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5
+ # Peter Ertl and Ansgar Schuffenhauer
6
+ # Journal of Cheminformatics 1:8 (2009)
7
+ # http://www.jcheminf.com/content/1/1/8
8
+ #
9
+ # several small modifications to the original paper are included
10
+ # particularly slightly different formula for marocyclic penalty
11
+ # and taking into account also molecule symmetry (fingerprint density)
12
+ #
13
+ # for a set of 10k diverse molecules the agreement between the original method
14
+ # as implemented in PipelinePilot and this implementation is r2 = 0.97
15
+ #
16
+ # peter ertl & greg landrum, september 2013
17
+ #
18
+ from __future__ import print_function
19
+
20
+ from rdkit import Chem
21
+ from rdkit.Chem import rdMolDescriptors
22
+ from rdkit.six.moves import cPickle
23
+ from rdkit.six import iteritems
24
+
25
+ import math
26
+ from collections import defaultdict
27
+
28
+ import os.path as op
29
+
30
+ _fscores = None
31
+
32
+
33
+ def readFragmentScores(name='models/fpscores'):
34
+ import gzip
35
+ global _fscores
36
+ # generate the full path filename:
37
+ if name == "fpscores":
38
+ name = op.join(op.dirname(__file__), name)
39
+ _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name))
40
+ outDict = {}
41
+ for i in _fscores:
42
+ for j in range(1, len(i)):
43
+ outDict[i[j]] = float(i[0])
44
+ _fscores = outDict
45
+
46
+
47
+ def numBridgeheadsAndSpiro(mol, ri=None):
48
+ nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
49
+ nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
50
+ return nBridgehead, nSpiro
51
+
52
+
53
+ def calculateScore(m):
54
+ if _fscores is None:
55
+ readFragmentScores()
56
+
57
+ # fragment score
58
+ fp = rdMolDescriptors.GetMorganFingerprint(m,
59
+ 2) #<- 2 is the *radius* of the circular fingerprint
60
+ fps = fp.GetNonzeroElements()
61
+ score1 = 0.
62
+ nf = 0
63
+ for bitId, v in iteritems(fps):
64
+ nf += v
65
+ sfp = bitId
66
+ score1 += _fscores.get(sfp, -4) * v
67
+ score1 /= nf
68
+
69
+ # features score
70
+ nAtoms = m.GetNumAtoms()
71
+ nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
72
+ ri = m.GetRingInfo()
73
+ nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
74
+ nMacrocycles = 0
75
+ for x in ri.AtomRings():
76
+ if len(x) > 8:
77
+ nMacrocycles += 1
78
+
79
+ sizePenalty = nAtoms**1.005 - nAtoms
80
+ stereoPenalty = math.log10(nChiralCenters + 1)
81
+ spiroPenalty = math.log10(nSpiro + 1)
82
+ bridgePenalty = math.log10(nBridgeheads + 1)
83
+ macrocyclePenalty = 0.
84
+ # ---------------------------------------
85
+ # This differs from the paper, which defines:
86
+ # macrocyclePenalty = math.log10(nMacrocycles+1)
87
+ # This form generates better results when 2 or more macrocycles are present
88
+ if nMacrocycles > 0:
89
+ macrocyclePenalty = math.log10(2)
90
+
91
+ score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
92
+
93
+ # correction for the fingerprint density
94
+ # not in the original publication, added in version 1.1
95
+ # to make highly symmetrical molecules easier to synthetise
96
+ score3 = 0.
97
+ if nAtoms > len(fps):
98
+ score3 = math.log(float(nAtoms) / len(fps)) * .5
99
+
100
+ sascore = score1 + score2 + score3
101
+
102
+ # need to transform "raw" value into scale between 1 and 10
103
+ min = -4.0
104
+ max = 2.5
105
+ sascore = 11. - (sascore - min + 1) / (max - min) * 9.
106
+ # smooth the 10-end
107
+ if sascore > 8.:
108
+ sascore = 8. + math.log(sascore + 1. - 9.)
109
+ if sascore > 10.:
110
+ sascore = 10.0
111
+ elif sascore < 1.:
112
+ sascore = 1.0
113
+
114
+ return sascore
115
+
116
+
117
+ def processMols(mols):
118
+ print('smiles\tName\tsa_score')
119
+ for i, m in enumerate(mols):
120
+ if m is None:
121
+ continue
122
+
123
+ s = calculateScore(m)
124
+
125
+ smiles = Chem.MolToSmiles(m)
126
+ print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
127
+
128
+
129
+ if __name__ == '__main__':
130
+ import sys, time
131
+
132
+ t1 = time.time()
133
+ readFragmentScores("fpscores")
134
+ t2 = time.time()
135
+
136
+ suppl = Chem.SmilesMolSupplier(sys.argv[1])
137
+ t3 = time.time()
138
+ processMols(suppl)
139
+ t4 = time.time()
140
+
141
+ print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
142
+ file=sys.stderr)
143
+
144
+ #
145
+ # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146
+ # All rights reserved.
147
+ #
148
+ # Redistribution and use in source and binary forms, with or without
149
+ # modification, are permitted provided that the following conditions are
150
+ # met:
151
+ #
152
+ # * Redistributions of source code must retain the above copyright
153
+ # notice, this list of conditions and the following disclaimer.
154
+ # * Redistributions in binary form must reproduce the above
155
+ # copyright notice, this list of conditions and the following
156
+ # disclaimer in the documentation and/or other materials provided
157
+ # with the distribution.
158
+ # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159
+ # nor the names of its contributors may be used to endorse or promote
160
+ # products derived from this software without specific prior written permission.
161
+ #
162
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166
+ # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167
+ # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168
+ # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170
+ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173
+ #
src/edm.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import math
5
+
6
+ from src import utils
7
+ from src.egnn import Dynamics
8
+ from src.noise import GammaNetwork, PredefinedNoiseSchedule
9
+ from typing import Union
10
+
11
+ from pdb import set_trace
12
+
13
+
14
+ class EDM(torch.nn.Module):
15
+ def __init__(
16
+ self,
17
+ dynamics: Union[Dynamics],
18
+ in_node_nf: int,
19
+ n_dims: int,
20
+ timesteps: int = 1000,
21
+ noise_schedule='learned',
22
+ noise_precision=1e-4,
23
+ loss_type='vlb',
24
+ norm_values=(1., 1., 1.),
25
+ norm_biases=(None, 0., 0.),
26
+ ):
27
+ super().__init__()
28
+ if noise_schedule == 'learned':
29
+ assert loss_type == 'vlb', 'A noise schedule can only be learned with a vlb objective'
30
+ self.gamma = GammaNetwork()
31
+ else:
32
+ self.gamma = PredefinedNoiseSchedule(noise_schedule, timesteps=timesteps, precision=noise_precision)
33
+
34
+ self.dynamics = dynamics
35
+ self.in_node_nf = in_node_nf
36
+ self.n_dims = n_dims
37
+ self.T = timesteps
38
+ self.norm_values = norm_values
39
+ self.norm_biases = norm_biases
40
+
41
+ def forward(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context=None):
42
+ # Normalization and concatenation
43
+ x, h = self.normalize(x, h)
44
+ xh = torch.cat([x, h], dim=2)
45
+
46
+ # Volume change loss term
47
+ delta_log_px = self.delta_log_px(linker_mask).mean()
48
+
49
+ # Sample t
50
+ t_int = torch.randint(0, self.T + 1, size=(x.size(0), 1), device=x.device).float()
51
+ s_int = t_int - 1
52
+ t = t_int / self.T
53
+ s = s_int / self.T
54
+
55
+ # Masks for t=0 and t>0
56
+ t_is_zero = (t_int == 0).squeeze().float()
57
+ t_is_not_zero = 1 - t_is_zero
58
+
59
+ # Compute gamma_t and gamma_s according to the noise schedule
60
+ gamma_t = self.inflate_batch_array(self.gamma(t), x)
61
+ gamma_s = self.inflate_batch_array(self.gamma(s), x)
62
+
63
+ # Compute alpha_t and sigma_t from gamma
64
+ alpha_t = self.alpha(gamma_t, x)
65
+ sigma_t = self.sigma(gamma_t, x)
66
+
67
+ # Sample noise
68
+ # Note: only for linker
69
+ eps_t = self.sample_combined_position_feature_noise(n_samples=x.size(0), n_nodes=x.size(1), mask=linker_mask)
70
+
71
+ # Sample z_t given x, h for timestep t, from q(z_t | x, h)
72
+ # Note: keep fragments unchanged
73
+ z_t = alpha_t * xh + sigma_t * eps_t
74
+ z_t = xh * fragment_mask + z_t * linker_mask
75
+
76
+ # Neural net prediction
77
+ eps_t_hat = self.dynamics.forward(
78
+ xh=z_t,
79
+ t=t,
80
+ node_mask=node_mask,
81
+ linker_mask=linker_mask,
82
+ context=context,
83
+ edge_mask=edge_mask,
84
+ )
85
+ eps_t_hat = eps_t_hat * linker_mask
86
+
87
+ # Computing basic error (further used for computing NLL and L2-loss)
88
+ error_t = self.sum_except_batch((eps_t - eps_t_hat) ** 2)
89
+
90
+ # Computing L2-loss for t>0
91
+ normalization = (self.n_dims + self.in_node_nf) * self.numbers_of_nodes(linker_mask)
92
+ l2_loss = error_t / normalization
93
+ l2_loss = l2_loss.mean()
94
+
95
+ # The KL between q(z_T | x) and p(z_T) = Normal(0, 1) (should be close to zero)
96
+ kl_prior = self.kl_prior(xh, linker_mask).mean()
97
+
98
+ # Computing NLL middle term
99
+ SNR_weight = (self.SNR(gamma_s - gamma_t) - 1).squeeze(1).squeeze(1)
100
+ loss_term_t = self.T * 0.5 * SNR_weight * error_t
101
+ loss_term_t = (loss_term_t * t_is_not_zero).sum() / t_is_not_zero.sum()
102
+
103
+ # Computing noise returned by dynamics
104
+ noise = torch.norm(eps_t_hat, dim=[1, 2])
105
+ noise_t = (noise * t_is_not_zero).sum() / t_is_not_zero.sum()
106
+
107
+ if t_is_zero.sum() > 0:
108
+ # The _constants_ depending on sigma_0 from the
109
+ # cross entropy term E_q(z0 | x) [log p(x | z0)]
110
+ neg_log_constants = -self.log_constant_of_p_x_given_z0(x, linker_mask)
111
+
112
+ # Computes the L_0 term (even if gamma_t is not actually gamma_0)
113
+ # and selected only relevant via masking
114
+ loss_term_0 = -self.log_p_xh_given_z0_without_constants(h, z_t, gamma_t, eps_t, eps_t_hat, linker_mask)
115
+ loss_term_0 = loss_term_0 + neg_log_constants
116
+ loss_term_0 = (loss_term_0 * t_is_zero).sum() / t_is_zero.sum()
117
+
118
+ # Computing noise returned by dynamics
119
+ noise_0 = (noise * t_is_zero).sum() / t_is_zero.sum()
120
+ else:
121
+ loss_term_0 = 0.
122
+ noise_0 = 0.
123
+
124
+ return delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0
125
+
126
+ @torch.no_grad()
127
+ def sample_chain(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context, keep_frames=None):
128
+ n_samples = x.size(0)
129
+ n_nodes = x.size(1)
130
+
131
+ # Normalization and concatenation
132
+ x, h, = self.normalize(x, h)
133
+ xh = torch.cat([x, h], dim=2)
134
+
135
+ # Initial linker sampling from N(0, I)
136
+ z = self.sample_combined_position_feature_noise(n_samples, n_nodes, mask=linker_mask)
137
+ z = xh * fragment_mask + z * linker_mask
138
+
139
+ if keep_frames is None:
140
+ keep_frames = self.T
141
+ else:
142
+ assert keep_frames <= self.T
143
+ chain = torch.zeros((keep_frames,) + z.size(), device=z.device)
144
+
145
+ # Sample p(z_s | z_t)
146
+ for s in reversed(range(0, self.T)):
147
+ s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
148
+ t_array = s_array + 1
149
+ s_array = s_array / self.T
150
+ t_array = t_array / self.T
151
+
152
+ z = self.sample_p_zs_given_zt_only_linker(
153
+ s=s_array,
154
+ t=t_array,
155
+ z_t=z,
156
+ node_mask=node_mask,
157
+ fragment_mask=fragment_mask,
158
+ linker_mask=linker_mask,
159
+ edge_mask=edge_mask,
160
+ context=context,
161
+ )
162
+ write_index = (s * keep_frames) // self.T
163
+ chain[write_index] = self.unnormalize_z(z)
164
+
165
+ # Finally sample p(x, h | z_0)
166
+ x, h = self.sample_p_xh_given_z0_only_linker(
167
+ z_0=z,
168
+ node_mask=node_mask,
169
+ fragment_mask=fragment_mask,
170
+ linker_mask=linker_mask,
171
+ edge_mask=edge_mask,
172
+ context=context,
173
+ )
174
+ chain[0] = torch.cat([x, h], dim=2)
175
+
176
+ return chain
177
+
178
+ def sample_p_zs_given_zt_only_linker(self, s, t, z_t, node_mask, fragment_mask, linker_mask, edge_mask, context):
179
+ """Samples from zs ~ p(zs | zt). Only used during sampling. Samples only linker features and coords"""
180
+ gamma_s = self.gamma(s)
181
+ gamma_t = self.gamma(t)
182
+
183
+ sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
184
+ sigma_s = self.sigma(gamma_s, target_tensor=z_t)
185
+ sigma_t = self.sigma(gamma_t, target_tensor=z_t)
186
+
187
+ # Neural net prediction.
188
+ eps_hat = self.dynamics.forward(
189
+ xh=z_t,
190
+ t=t,
191
+ node_mask=node_mask,
192
+ linker_mask=linker_mask,
193
+ context=context,
194
+ edge_mask=edge_mask,
195
+ )
196
+ eps_hat = eps_hat * linker_mask
197
+
198
+ # Compute mu for p(z_s | z_t)
199
+ mu = z_t / alpha_t_given_s - (sigma2_t_given_s / alpha_t_given_s / sigma_t) * eps_hat
200
+
201
+ # Compute sigma for p(z_s | z_t)
202
+ sigma = sigma_t_given_s * sigma_s / sigma_t
203
+
204
+ # Sample z_s given the parameters derived from zt
205
+ z_s = self.sample_normal(mu, sigma, linker_mask)
206
+ z_s = z_t * fragment_mask + z_s * linker_mask
207
+
208
+ return z_s
209
+
210
+ def sample_p_xh_given_z0_only_linker(self, z_0, node_mask, fragment_mask, linker_mask, edge_mask, context):
211
+ """Samples x ~ p(x|z0). Samples only linker features and coords"""
212
+ zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
213
+ gamma_0 = self.gamma(zeros)
214
+
215
+ # Computes sqrt(sigma_0^2 / alpha_0^2)
216
+ sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)
217
+ eps_hat = self.dynamics.forward(
218
+ t=zeros,
219
+ xh=z_0,
220
+ node_mask=node_mask,
221
+ linker_mask=linker_mask,
222
+ edge_mask=edge_mask,
223
+ context=context
224
+ )
225
+ eps_hat = eps_hat * linker_mask
226
+
227
+ mu_x = self.compute_x_pred(eps_t=eps_hat, z_t=z_0, gamma_t=gamma_0)
228
+ xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=linker_mask)
229
+ xh = z_0 * fragment_mask + xh * linker_mask
230
+
231
+ x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
232
+ x, h = self.unnormalize(x, h)
233
+ h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask
234
+
235
+ return x, h
236
+
237
+ def compute_x_pred(self, eps_t, z_t, gamma_t):
238
+ """Computes x_pred, i.e. the most likely prediction of x."""
239
+ sigma_t = self.sigma(gamma_t, target_tensor=eps_t)
240
+ alpha_t = self.alpha(gamma_t, target_tensor=eps_t)
241
+ x_pred = 1. / alpha_t * (z_t - sigma_t * eps_t)
242
+ return x_pred
243
+
244
+ def kl_prior(self, xh, mask):
245
+ """
246
+ Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).
247
+ This is essentially a lot of work for something that is in practice negligible in the loss.
248
+ However, you compute it so that you see it when you've made a mistake in your noise schedule.
249
+ """
250
+ # Compute the last alpha value, alpha_T
251
+ ones = torch.ones((xh.size(0), 1), device=xh.device)
252
+ gamma_T = self.gamma(ones)
253
+ alpha_T = self.alpha(gamma_T, xh)
254
+
255
+ # Compute means
256
+ mu_T = alpha_T * xh
257
+ mu_T_x, mu_T_h = mu_T[:, :, :self.n_dims], mu_T[:, :, self.n_dims:]
258
+
259
+ # Compute standard deviations (only batch axis for x-part, inflated for h-part)
260
+ sigma_T_x = self.sigma(gamma_T, mu_T_x).view(-1) # Remove inflate, only keep batch dimension for x-part
261
+ sigma_T_h = self.sigma(gamma_T, mu_T_h)
262
+
263
+ # Compute KL for h-part
264
+ zeros, ones = torch.zeros_like(mu_T_h), torch.ones_like(sigma_T_h)
265
+ kl_distance_h = self.gaussian_kl(mu_T_h, sigma_T_h, zeros, ones)
266
+
267
+ # Compute KL for x-part
268
+ zeros, ones = torch.zeros_like(mu_T_x), torch.ones_like(sigma_T_x)
269
+ d = self.dimensionality(mask)
270
+ kl_distance_x = self.gaussian_kl_for_dimension(mu_T_x, sigma_T_x, zeros, ones, d=d)
271
+
272
+ return kl_distance_x + kl_distance_h
273
+
274
+ def log_constant_of_p_x_given_z0(self, x, mask):
275
+ batch_size = x.size(0)
276
+ degrees_of_freedom_x = self.dimensionality(mask)
277
+ zeros = torch.zeros((batch_size, 1), device=x.device)
278
+ gamma_0 = self.gamma(zeros)
279
+
280
+ # Recall that sigma_x = sqrt(sigma_0^2 / alpha_0^2) = SNR(-0.5 gamma_0)
281
+ log_sigma_x = 0.5 * gamma_0.view(batch_size)
282
+
283
+ return degrees_of_freedom_x * (- log_sigma_x - 0.5 * np.log(2 * np.pi))
284
+
285
+ def log_p_xh_given_z0_without_constants(self, h, z_0, gamma_0, eps, eps_hat, mask, epsilon=1e-10):
286
+ # Discrete properties are predicted directly from z_0
287
+ z_h = z_0[:, :, self.n_dims:]
288
+
289
+ # Take only part over x
290
+ eps_x = eps[:, :, :self.n_dims]
291
+ eps_hat_x = eps_hat[:, :, :self.n_dims]
292
+
293
+ # Compute sigma_0 and rescale to the integer scale of the data
294
+ sigma_0 = self.sigma(gamma_0, target_tensor=z_0) * self.norm_values[1]
295
+
296
+ # Computes the error for the distribution N(x | 1 / alpha_0 z_0 + sigma_0/alpha_0 eps_0, sigma_0 / alpha_0),
297
+ # the weighting in the epsilon parametrization is exactly '1'
298
+ log_p_x_given_z_without_constants = -0.5 * self.sum_except_batch((eps_x - eps_hat_x) ** 2)
299
+
300
+ # Categorical features
301
+ # Compute delta indicator masks
302
+ h = h * self.norm_values[1] + self.norm_biases[1]
303
+ estimated_h = z_h * self.norm_values[1] + self.norm_biases[1]
304
+
305
+ # Centered h_cat around 1, since onehot encoded
306
+ centered_h = estimated_h - 1
307
+
308
+ # Compute integrals from 0.5 to 1.5 of the normal distribution
309
+ # N(mean=centered_h_cat, stdev=sigma_0_cat)
310
+ log_p_h_proportional = torch.log(
311
+ self.cdf_standard_gaussian((centered_h + 0.5) / sigma_0) -
312
+ self.cdf_standard_gaussian((centered_h - 0.5) / sigma_0) +
313
+ epsilon
314
+ )
315
+
316
+ # Normalize the distribution over the categories
317
+ log_Z = torch.logsumexp(log_p_h_proportional, dim=2, keepdim=True)
318
+ log_probabilities = log_p_h_proportional - log_Z
319
+
320
+ # Select the log_prob of the current category using the onehot representation
321
+ log_p_h_given_z = self.sum_except_batch(log_probabilities * h * mask)
322
+
323
+ # Combine log probabilities for x and h
324
+ log_p_xh_given_z = log_p_x_given_z_without_constants + log_p_h_given_z
325
+
326
+ return log_p_xh_given_z
327
+
328
+ def sample_combined_position_feature_noise(self, n_samples, n_nodes, mask):
329
+ z_x = utils.sample_gaussian_with_mask(
330
+ size=(n_samples, n_nodes, self.n_dims),
331
+ device=mask.device,
332
+ node_mask=mask
333
+ )
334
+ z_h = utils.sample_gaussian_with_mask(
335
+ size=(n_samples, n_nodes, self.in_node_nf),
336
+ device=mask.device,
337
+ node_mask=mask
338
+ )
339
+ z = torch.cat([z_x, z_h], dim=2)
340
+ return z
341
+
342
+ def sample_normal(self, mu, sigma, node_mask):
343
+ """Samples from a Normal distribution."""
344
+ eps = self.sample_combined_position_feature_noise(mu.size(0), mu.size(1), node_mask)
345
+ return mu + sigma * eps
346
+
347
+ def normalize(self, x, h):
348
+ new_x = x / self.norm_values[0]
349
+ new_h = (h.float() - self.norm_biases[1]) / self.norm_values[1]
350
+ return new_x, new_h
351
+
352
+ def unnormalize(self, x, h):
353
+ new_x = x * self.norm_values[0]
354
+ new_h = h * self.norm_values[1] + self.norm_biases[1]
355
+ return new_x, new_h
356
+
357
+ def unnormalize_z(self, z):
358
+ assert z.size(2) == self.n_dims + self.in_node_nf
359
+ x, h = z[:, :, :self.n_dims], z[:, :, self.n_dims:]
360
+ x, h = self.unnormalize(x, h)
361
+ return torch.cat([x, h], dim=2)
362
+
363
+ def delta_log_px(self, mask):
364
+ return -self.dimensionality(mask) * np.log(self.norm_values[0])
365
+
366
+ def dimensionality(self, mask):
367
+ return self.numbers_of_nodes(mask) * self.n_dims
368
+
369
+ def sigma(self, gamma, target_tensor):
370
+ """Computes sigma given gamma."""
371
+ return self.inflate_batch_array(torch.sqrt(torch.sigmoid(gamma)), target_tensor)
372
+
373
+ def alpha(self, gamma, target_tensor):
374
+ """Computes alpha given gamma."""
375
+ return self.inflate_batch_array(torch.sqrt(torch.sigmoid(-gamma)), target_tensor)
376
+
377
+ def SNR(self, gamma):
378
+ """Computes signal to noise ratio (alpha^2/sigma^2) given gamma."""
379
+ return torch.exp(-gamma)
380
+
381
+ def sigma_and_alpha_t_given_s(self, gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_tensor: torch.Tensor):
382
+ """
383
+ Computes sigma t given s, using gamma_t and gamma_s. Used during sampling.
384
+
385
+ These are defined as:
386
+ alpha t given s = alpha t / alpha s,
387
+ sigma t given s = sqrt(1 - (alpha t given s) ^2 ).
388
+ """
389
+ sigma2_t_given_s = self.inflate_batch_array(
390
+ -self.expm1(self.softplus(gamma_s) - self.softplus(gamma_t)),
391
+ target_tensor
392
+ )
393
+
394
+ # alpha_t_given_s = alpha_t / alpha_s
395
+ log_alpha2_t = F.logsigmoid(-gamma_t)
396
+ log_alpha2_s = F.logsigmoid(-gamma_s)
397
+ log_alpha2_t_given_s = log_alpha2_t - log_alpha2_s
398
+
399
+ alpha_t_given_s = torch.exp(0.5 * log_alpha2_t_given_s)
400
+ alpha_t_given_s = self.inflate_batch_array(alpha_t_given_s, target_tensor)
401
+ sigma_t_given_s = torch.sqrt(sigma2_t_given_s)
402
+
403
+ return sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s
404
+
405
+ @staticmethod
406
+ def numbers_of_nodes(mask):
407
+ return torch.sum(mask.squeeze(2), dim=1)
408
+
409
+ @staticmethod
410
+ def inflate_batch_array(array, target):
411
+ """
412
+ Inflates the batch array (array) with only a single axis (i.e. shape = (batch_size,),
413
+ or possibly more empty axes (i.e. shape (batch_size, 1, ..., 1)) to match the target shape.
414
+ """
415
+ target_shape = (array.size(0),) + (1,) * (len(target.size()) - 1)
416
+ return array.view(target_shape)
417
+
418
+ @staticmethod
419
+ def sum_except_batch(x):
420
+ return x.view(x.size(0), -1).sum(-1)
421
+
422
+ @staticmethod
423
+ def expm1(x: torch.Tensor) -> torch.Tensor:
424
+ return torch.expm1(x)
425
+
426
+ @staticmethod
427
+ def softplus(x: torch.Tensor) -> torch.Tensor:
428
+ return F.softplus(x)
429
+
430
+ @staticmethod
431
+ def cdf_standard_gaussian(x):
432
+ return 0.5 * (1. + torch.erf(x / math.sqrt(2)))
433
+
434
+ @staticmethod
435
+ def gaussian_kl(q_mu, q_sigma, p_mu, p_sigma):
436
+ """
437
+ Computes the KL distance between two normal distributions.
438
+ Args:
439
+ q_mu: Mean of distribution q.
440
+ q_sigma: Standard deviation of distribution q.
441
+ p_mu: Mean of distribution p.
442
+ p_sigma: Standard deviation of distribution p.
443
+ Returns:
444
+ The KL distance, summed over all dimensions except the batch dim.
445
+ """
446
+ kl = torch.log(p_sigma / q_sigma) + 0.5 * (q_sigma ** 2 + (q_mu - p_mu) ** 2) / (p_sigma ** 2) - 0.5
447
+ return EDM.sum_except_batch(kl)
448
+
449
+ @staticmethod
450
+ def gaussian_kl_for_dimension(q_mu, q_sigma, p_mu, p_sigma, d):
451
+ """
452
+ Computes the KL distance between two normal distributions taking the dimension into account.
453
+ Args:
454
+ q_mu: Mean of distribution q.
455
+ q_sigma: Standard deviation of distribution q.
456
+ p_mu: Mean of distribution p.
457
+ p_sigma: Standard deviation of distribution p.
458
+ d: dimension
459
+ Returns:
460
+ The KL distance, summed over all dimensions except the batch dim.
461
+ """
462
+ mu_norm_2 = EDM.sum_except_batch((q_mu - p_mu) ** 2)
463
+ return d * torch.log(p_sigma / q_sigma) + 0.5 * (d * q_sigma ** 2 + mu_norm_2) / (p_sigma ** 2) - 0.5 * d
464
+
465
+
466
+ class InpaintingEDM(EDM):
467
+ def forward(self, x, h, node_mask, fragment_mask, linker_mask, edge_mask, context=None):
468
+ # Normalization and concatenation
469
+ x, h = self.normalize(x, h)
470
+ xh = torch.cat([x, h], dim=2)
471
+
472
+ # Volume change loss term
473
+ delta_log_px = self.delta_log_px(node_mask).mean()
474
+
475
+ # Sample t
476
+ t_int = torch.randint(0, self.T + 1, size=(x.size(0), 1), device=x.device).float()
477
+ s_int = t_int - 1
478
+ t = t_int / self.T
479
+ s = s_int / self.T
480
+
481
+ # Masks for t=0 and t>0
482
+ t_is_zero = (t_int == 0).squeeze().float()
483
+ t_is_not_zero = 1 - t_is_zero
484
+
485
+ # Compute gamma_t and gamma_s according to the noise schedule
486
+ gamma_t = self.inflate_batch_array(self.gamma(t), x)
487
+ gamma_s = self.inflate_batch_array(self.gamma(s), x)
488
+
489
+ # Compute alpha_t and sigma_t from gamma
490
+ alpha_t = self.alpha(gamma_t, x)
491
+ sigma_t = self.sigma(gamma_t, x)
492
+
493
+ # Sample noise
494
+ eps_t = self.sample_combined_position_feature_noise(n_samples=x.size(0), n_nodes=x.size(1), mask=node_mask)
495
+
496
+ # Sample z_t given x, h for timestep t, from q(z_t | x, h)
497
+ # Note: keep fragments unchanged
498
+ z_t = alpha_t * xh + sigma_t * eps_t
499
+
500
+ # Neural net prediction
501
+ eps_t_hat = self.dynamics.forward(
502
+ xh=z_t,
503
+ t=t,
504
+ node_mask=node_mask,
505
+ linker_mask=None,
506
+ context=context,
507
+ edge_mask=edge_mask,
508
+ )
509
+
510
+ # Computing basic error (further used for computing NLL and L2-loss)
511
+ error_t = self.sum_except_batch((eps_t - eps_t_hat) ** 2)
512
+
513
+ # Computing L2-loss for t>0
514
+ normalization = (self.n_dims + self.in_node_nf) * self.numbers_of_nodes(node_mask)
515
+ l2_loss = error_t / normalization
516
+ l2_loss = l2_loss.mean()
517
+
518
+ # The KL between q(z_T | x) and p(z_T) = Normal(0, 1) (should be close to zero)
519
+ kl_prior = self.kl_prior(xh, node_mask).mean()
520
+
521
+ # Computing NLL middle term
522
+ SNR_weight = (self.SNR(gamma_s - gamma_t) - 1).squeeze(1).squeeze(1)
523
+ loss_term_t = self.T * 0.5 * SNR_weight * error_t
524
+ loss_term_t = (loss_term_t * t_is_not_zero).sum() / t_is_not_zero.sum()
525
+
526
+ # Computing noise returned by dynamics
527
+ noise = torch.norm(eps_t_hat, dim=[1, 2])
528
+ noise_t = (noise * t_is_not_zero).sum() / t_is_not_zero.sum()
529
+
530
+ if t_is_zero.sum() > 0:
531
+ # The _constants_ depending on sigma_0 from the
532
+ # cross entropy term E_q(z0 | x) [log p(x | z0)]
533
+ neg_log_constants = -self.log_constant_of_p_x_given_z0(x, node_mask)
534
+
535
+ # Computes the L_0 term (even if gamma_t is not actually gamma_0)
536
+ # and selected only relevant via masking
537
+ loss_term_0 = -self.log_p_xh_given_z0_without_constants(h, z_t, gamma_t, eps_t, eps_t_hat, node_mask)
538
+ loss_term_0 = loss_term_0 + neg_log_constants
539
+ loss_term_0 = (loss_term_0 * t_is_zero).sum() / t_is_zero.sum()
540
+
541
+ # Computing noise returned by dynamics
542
+ noise_0 = (noise * t_is_zero).sum() / t_is_zero.sum()
543
+ else:
544
+ loss_term_0 = 0.
545
+ noise_0 = 0.
546
+
547
+ return delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0
548
+
549
+ @torch.no_grad()
550
+ def sample_chain(self, x, h, node_mask, edge_mask, fragment_mask, linker_mask, context, keep_frames=None):
551
+ n_samples = x.size(0)
552
+ n_nodes = x.size(1)
553
+
554
+ # Normalization and concatenation
555
+ x, h, = self.normalize(x, h)
556
+ xh = torch.cat([x, h], dim=2)
557
+
558
+ # Sampling initial noise
559
+ z = self.sample_combined_position_feature_noise(n_samples, n_nodes, node_mask)
560
+
561
+ if keep_frames is None:
562
+ keep_frames = self.T
563
+ else:
564
+ assert keep_frames <= self.T
565
+ chain = torch.zeros((keep_frames,) + z.size(), device=z.device)
566
+
567
+ # Sample p(z_s | z_t)
568
+ for s in reversed(range(0, self.T)):
569
+ s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
570
+ t_array = s_array + 1
571
+ s_array = s_array / self.T
572
+ t_array = t_array / self.T
573
+
574
+ z_linker_only_sampled = self.sample_p_zs_given_zt(
575
+ s=s_array,
576
+ t=t_array,
577
+ z_t=z,
578
+ node_mask=node_mask,
579
+ edge_mask=edge_mask,
580
+ context=context,
581
+ )
582
+ z_fragments_only_sampled = self.sample_q_zs_given_zt_and_x(
583
+ s=s_array,
584
+ t=t_array,
585
+ z_t=z,
586
+ x=xh * fragment_mask,
587
+ node_mask=fragment_mask,
588
+ )
589
+ z = z_linker_only_sampled * linker_mask + z_fragments_only_sampled * fragment_mask
590
+
591
+ # Project down to avoid numerical runaway of the center of gravity
592
+ z_x = utils.remove_mean_with_mask(z[:, :, :self.n_dims], node_mask)
593
+ z_h = z[:, :, self.n_dims:]
594
+ z = torch.cat([z_x, z_h], dim=2)
595
+
596
+ # Saving step to the chain
597
+ write_index = (s * keep_frames) // self.T
598
+ chain[write_index] = self.unnormalize_z(z)
599
+
600
+ # Finally sample p(x, h | z_0)
601
+ x_out_linker, h_out_linker = self.sample_p_xh_given_z0(
602
+ z_0=z,
603
+ node_mask=node_mask,
604
+ edge_mask=edge_mask,
605
+ context=context,
606
+ )
607
+ x_out_fragments, h_out_fragments = self.sample_q_xh_given_z0_and_x(z_0=z, node_mask=node_mask)
608
+
609
+ xh_out_linker = torch.cat([x_out_linker, h_out_linker], dim=2)
610
+ xh_out_fragments = torch.cat([x_out_fragments, h_out_fragments], dim=2)
611
+ xh_out = xh_out_linker * linker_mask + xh_out_fragments * fragment_mask
612
+
613
+ # Overwrite last frame with the resulting x and h
614
+ chain[0] = xh_out
615
+
616
+ return chain
617
+
618
+ def sample_p_zs_given_zt(self, s, t, z_t, node_mask, edge_mask, context):
619
+ """Samples from zs ~ p(zs | zt). Only used during sampling"""
620
+ gamma_s = self.gamma(s)
621
+ gamma_t = self.gamma(t)
622
+ sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
623
+
624
+ sigma_s = self.sigma(gamma_s, target_tensor=z_t)
625
+ sigma_t = self.sigma(gamma_t, target_tensor=z_t)
626
+
627
+ # Neural net prediction.
628
+ eps_hat = self.dynamics.forward(
629
+ xh=z_t,
630
+ t=t,
631
+ node_mask=node_mask,
632
+ linker_mask=None,
633
+ edge_mask=edge_mask,
634
+ context=context
635
+ )
636
+
637
+ # Checking that epsilon is centered around linker COM
638
+ utils.assert_mean_zero_with_mask(eps_hat[:, :, :self.n_dims], node_mask)
639
+
640
+ # Compute mu for p(z_s | z_t)
641
+ mu = z_t / alpha_t_given_s - (sigma2_t_given_s / alpha_t_given_s / sigma_t) * eps_hat
642
+
643
+ # Compute sigma for p(z_s | z_t)
644
+ sigma = sigma_t_given_s * sigma_s / sigma_t
645
+
646
+ # Sample z_s given the parameters derived from z_t
647
+ z_s = self.sample_normal(mu, sigma, node_mask)
648
+ return z_s
649
+
650
+ def sample_q_zs_given_zt_and_x(self, s, t, z_t, x, node_mask):
651
+ """Samples from zs ~ q(zs | zt, x). Only used during sampling. Samples only linker features and coords"""
652
+ gamma_s = self.gamma(s)
653
+ gamma_t = self.gamma(t)
654
+ sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, z_t)
655
+
656
+ sigma_s = self.sigma(gamma_s, target_tensor=z_t)
657
+ sigma_t = self.sigma(gamma_t, target_tensor=z_t)
658
+ alpha_s = self.alpha(gamma_s, x)
659
+
660
+ mu = (
661
+ alpha_t_given_s * (sigma_s ** 2) / (sigma_t ** 2) * z_t +
662
+ alpha_s * sigma2_t_given_s / (sigma_t ** 2) * x
663
+ )
664
+
665
+ # Compute sigma for p(zs | zt)
666
+ sigma = sigma_t_given_s * sigma_s / sigma_t
667
+
668
+ # Sample zs given the parameters derived from zt
669
+ z_s = self.sample_normal(mu, sigma, node_mask)
670
+ return z_s
671
+
672
+ def sample_p_xh_given_z0(self, z_0, node_mask, edge_mask, context):
673
+ """Samples x ~ p(x|z0). Samples only linker features and coords"""
674
+ zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
675
+ gamma_0 = self.gamma(zeros)
676
+
677
+ # Computes sqrt(sigma_0^2 / alpha_0^2)
678
+ sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)
679
+ eps_hat = self.dynamics.forward(
680
+ xh=z_0,
681
+ t=zeros,
682
+ node_mask=node_mask,
683
+ linker_mask=None,
684
+ edge_mask=edge_mask,
685
+ context=context
686
+ )
687
+ utils.assert_mean_zero_with_mask(eps_hat[:, :, :self.n_dims], node_mask)
688
+
689
+ mu_x = self.compute_x_pred(eps_hat, z_0, gamma_0)
690
+ xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=node_mask)
691
+
692
+ x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
693
+ x, h = self.unnormalize(x, h)
694
+ h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask
695
+
696
+ return x, h
697
+
698
+ def sample_q_xh_given_z0_and_x(self, z_0, node_mask):
699
+ """Samples x ~ q(x|z0). Samples only linker features and coords"""
700
+ zeros = torch.zeros(size=(z_0.size(0), 1), device=z_0.device)
701
+ gamma_0 = self.gamma(zeros)
702
+ alpha_0 = self.alpha(gamma_0, z_0)
703
+ sigma_0 = self.sigma(gamma_0, z_0)
704
+
705
+ eps = self.sample_combined_position_feature_noise(z_0.size(0), z_0.size(1), node_mask)
706
+
707
+ xh = (1 / alpha_0) * z_0 - (sigma_0 / alpha_0) * eps
708
+
709
+ x, h = xh[:, :, :self.n_dims], xh[:, :, self.n_dims:]
710
+ x, h = self.unnormalize(x, h)
711
+ h = F.one_hot(torch.argmax(h, dim=2), self.in_node_nf) * node_mask
712
+
713
+ return x, h
714
+
715
+ def sample_combined_position_feature_noise(self, n_samples, n_nodes, mask):
716
+ z_x = utils.sample_center_gravity_zero_gaussian_with_mask(
717
+ size=(n_samples, n_nodes, self.n_dims),
718
+ device=mask.device,
719
+ node_mask=mask
720
+ )
721
+ z_h = utils.sample_gaussian_with_mask(
722
+ size=(n_samples, n_nodes, self.in_node_nf),
723
+ device=mask.device,
724
+ node_mask=mask
725
+ )
726
+ z = torch.cat([z_x, z_h], dim=2)
727
+ return z
728
+
729
+ def dimensionality(self, mask):
730
+ return (self.numbers_of_nodes(mask) - 1) * self.n_dims
src/egnn.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from src import utils
7
+ from pdb import set_trace
8
+
9
+
10
+ class GCL(nn.Module):
11
+ def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method, activation,
12
+ edges_in_d=0, nodes_att_dim=0, attention=False, normalization=None):
13
+ super(GCL, self).__init__()
14
+ input_edge = input_nf * 2
15
+ self.normalization_factor = normalization_factor
16
+ self.aggregation_method = aggregation_method
17
+ self.attention = attention
18
+
19
+ self.edge_mlp = nn.Sequential(
20
+ nn.Linear(input_edge + edges_in_d, hidden_nf),
21
+ activation,
22
+ nn.Linear(hidden_nf, hidden_nf),
23
+ activation)
24
+
25
+ if normalization is None:
26
+ self.node_mlp = nn.Sequential(
27
+ nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
28
+ activation,
29
+ nn.Linear(hidden_nf, output_nf)
30
+ )
31
+ elif normalization == 'batch_norm':
32
+ self.node_mlp = nn.Sequential(
33
+ nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
34
+ nn.BatchNorm1d(hidden_nf),
35
+ activation,
36
+ nn.Linear(hidden_nf, output_nf),
37
+ nn.BatchNorm1d(output_nf),
38
+ )
39
+ else:
40
+ raise NotImplementedError
41
+
42
+ if self.attention:
43
+ self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())
44
+
45
+ def edge_model(self, source, target, edge_attr, edge_mask):
46
+ if edge_attr is None: # Unused.
47
+ out = torch.cat([source, target], dim=1)
48
+ else:
49
+ out = torch.cat([source, target, edge_attr], dim=1)
50
+ mij = self.edge_mlp(out)
51
+
52
+ if self.attention:
53
+ att_val = self.att_mlp(mij)
54
+ out = mij * att_val
55
+ else:
56
+ out = mij
57
+
58
+ if edge_mask is not None:
59
+ out = out * edge_mask
60
+ return out, mij
61
+
62
+ def node_model(self, x, edge_index, edge_attr, node_attr):
63
+ row, col = edge_index
64
+ agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
65
+ normalization_factor=self.normalization_factor,
66
+ aggregation_method=self.aggregation_method)
67
+ if node_attr is not None:
68
+ agg = torch.cat([x, agg, node_attr], dim=1)
69
+ else:
70
+ agg = torch.cat([x, agg], dim=1)
71
+ out = x + self.node_mlp(agg)
72
+ return out, agg
73
+
74
+ def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
75
+ row, col = edge_index
76
+ edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
77
+ h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
78
+ if node_mask is not None:
79
+ h = h * node_mask
80
+ return h, mij
81
+
82
+
83
+ class EquivariantUpdate(nn.Module):
84
+ def __init__(self, hidden_nf, normalization_factor, aggregation_method,
85
+ edges_in_d=1, activation=nn.SiLU(), tanh=False, coords_range=10.0):
86
+ super(EquivariantUpdate, self).__init__()
87
+ self.tanh = tanh
88
+ self.coords_range = coords_range
89
+ input_edge = hidden_nf * 2 + edges_in_d
90
+ layer = nn.Linear(hidden_nf, 1, bias=False)
91
+ torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
92
+ self.coord_mlp = nn.Sequential(
93
+ nn.Linear(input_edge, hidden_nf),
94
+ activation,
95
+ nn.Linear(hidden_nf, hidden_nf),
96
+ activation,
97
+ layer)
98
+ self.normalization_factor = normalization_factor
99
+ self.aggregation_method = aggregation_method
100
+
101
+ def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask):
102
+ row, col = edge_index
103
+ input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
104
+ if self.tanh:
105
+ trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
106
+ else:
107
+ trans = coord_diff * self.coord_mlp(input_tensor)
108
+ if edge_mask is not None:
109
+ trans = trans * edge_mask
110
+ agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
111
+ normalization_factor=self.normalization_factor,
112
+ aggregation_method=self.aggregation_method)
113
+ if linker_mask is not None:
114
+ agg = agg * linker_mask
115
+
116
+ coord = coord + agg
117
+ return coord
118
+
119
+ def forward(
120
+ self, h, coord, edge_index, coord_diff, edge_attr=None, linker_mask=None, node_mask=None, edge_mask=None
121
+ ):
122
+ coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask)
123
+ if node_mask is not None:
124
+ coord = coord * node_mask
125
+ return coord
126
+
127
+
128
+ class EquivariantBlock(nn.Module):
129
+ def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', activation=nn.SiLU(), n_layers=2, attention=True,
130
+ norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
131
+ normalization_factor=100, aggregation_method='sum'):
132
+ super(EquivariantBlock, self).__init__()
133
+ self.hidden_nf = hidden_nf
134
+ self.device = device
135
+ self.n_layers = n_layers
136
+ self.coords_range_layer = float(coords_range)
137
+ self.norm_diff = norm_diff
138
+ self.norm_constant = norm_constant
139
+ self.sin_embedding = sin_embedding
140
+ self.normalization_factor = normalization_factor
141
+ self.aggregation_method = aggregation_method
142
+
143
+ for i in range(0, n_layers):
144
+ self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
145
+ activation=activation, attention=attention,
146
+ normalization_factor=self.normalization_factor,
147
+ aggregation_method=self.aggregation_method))
148
+ self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, activation=activation, tanh=tanh,
149
+ coords_range=self.coords_range_layer,
150
+ normalization_factor=self.normalization_factor,
151
+ aggregation_method=self.aggregation_method))
152
+ self.to(self.device)
153
+
154
+ def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None, edge_attr=None):
155
+ # Edit Emiel: Remove velocity as input
156
+ distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
157
+ if self.sin_embedding is not None:
158
+ distances = self.sin_embedding(distances)
159
+ edge_attr = torch.cat([distances, edge_attr], dim=1)
160
+ for i in range(0, self.n_layers):
161
+ h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
162
+ x = self._modules["gcl_equiv"](
163
+ h, x,
164
+ edge_index=edge_index,
165
+ coord_diff=coord_diff,
166
+ edge_attr=edge_attr,
167
+ linker_mask=linker_mask,
168
+ node_mask=node_mask,
169
+ edge_mask=edge_mask,
170
+ )
171
+
172
+ # Important, the bias of the last linear might be non-zero
173
+ if node_mask is not None:
174
+ h = h * node_mask
175
+ return h, x
176
+
177
+
178
+ class EGNN(nn.Module):
179
+ def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', activation=nn.SiLU(), n_layers=3, attention=False,
180
+ norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
181
+ sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
182
+ super(EGNN, self).__init__()
183
+ if out_node_nf is None:
184
+ out_node_nf = in_node_nf
185
+ self.hidden_nf = hidden_nf
186
+ self.device = device
187
+ self.n_layers = n_layers
188
+ self.coords_range_layer = float(coords_range/n_layers)
189
+ self.norm_diff = norm_diff
190
+ self.normalization_factor = normalization_factor
191
+ self.aggregation_method = aggregation_method
192
+
193
+ if sin_embedding:
194
+ self.sin_embedding = SinusoidsEmbeddingNew()
195
+ edge_feat_nf = self.sin_embedding.dim * 2
196
+ else:
197
+ self.sin_embedding = None
198
+ edge_feat_nf = 2
199
+
200
+ self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
201
+ self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
202
+ for i in range(0, n_layers):
203
+ self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
204
+ activation=activation, n_layers=inv_sublayers,
205
+ attention=attention, norm_diff=norm_diff, tanh=tanh,
206
+ coords_range=coords_range, norm_constant=norm_constant,
207
+ sin_embedding=self.sin_embedding,
208
+ normalization_factor=self.normalization_factor,
209
+ aggregation_method=self.aggregation_method))
210
+ self.to(self.device)
211
+
212
+ def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None):
213
+ # Edit Emiel: Remove velocity as input
214
+ distances, _ = coord2diff(x, edge_index)
215
+ if self.sin_embedding is not None:
216
+ distances = self.sin_embedding(distances)
217
+
218
+ h = self.embedding(h)
219
+ for i in range(0, self.n_layers):
220
+ h, x = self._modules["e_block_%d" % i](
221
+ h, x, edge_index,
222
+ node_mask=node_mask,
223
+ linker_mask=linker_mask,
224
+ edge_mask=edge_mask,
225
+ edge_attr=distances
226
+ )
227
+
228
+ # Important, the bias of the last linear might be non-zero
229
+ h = self.embedding_out(h)
230
+ if node_mask is not None:
231
+ h = h * node_mask
232
+ return h, x
233
+
234
+
235
+ class GNN(nn.Module):
236
+ def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
237
+ activation=nn.SiLU(), n_layers=4, attention=False, normalization_factor=1,
238
+ out_node_nf=None, normalization=None):
239
+ super(GNN, self).__init__()
240
+ if out_node_nf is None:
241
+ out_node_nf = in_node_nf
242
+ self.hidden_nf = hidden_nf
243
+ self.device = device
244
+ self.n_layers = n_layers
245
+
246
+ # Encoder
247
+ self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
248
+ self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
249
+ for i in range(0, n_layers):
250
+ self.add_module("gcl_%d" % i, GCL(
251
+ self.hidden_nf, self.hidden_nf, self.hidden_nf,
252
+ normalization_factor=normalization_factor,
253
+ aggregation_method=aggregation_method,
254
+ edges_in_d=in_edge_nf, activation=activation,
255
+ attention=attention, normalization=normalization))
256
+ self.to(self.device)
257
+
258
+ def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
259
+ # Edit Emiel: Remove velocity as input
260
+ h = self.embedding(h)
261
+ for i in range(0, self.n_layers):
262
+ h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
263
+ h = self.embedding_out(h)
264
+
265
+ # Important, the bias of the last linear might be non-zero
266
+ if node_mask is not None:
267
+ h = h * node_mask
268
+ return h
269
+
270
+
271
+ class SinusoidsEmbeddingNew(nn.Module):
272
+ def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
273
+ super().__init__()
274
+ self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
275
+ self.frequencies = 2 * math.pi * div_factor ** torch.arange(self.n_frequencies)/max_res
276
+ self.dim = len(self.frequencies) * 2
277
+
278
+ def forward(self, x):
279
+ x = torch.sqrt(x + 1e-8)
280
+ emb = x * self.frequencies[None, :].to(x.device)
281
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
282
+ return emb.detach()
283
+
284
+
285
+ def coord2diff(x, edge_index, norm_constant=1):
286
+ row, col = edge_index
287
+ coord_diff = x[row] - x[col]
288
+ radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
289
+ norm = torch.sqrt(radial + 1e-8)
290
+ coord_diff = coord_diff/(norm + norm_constant)
291
+ return radial, coord_diff
292
+
293
+
294
+ def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
295
+ """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
296
+ Normalization: 'sum' or 'mean'.
297
+ """
298
+ result_shape = (num_segments, data.size(1))
299
+ result = data.new_full(result_shape, 0) # Init empty result tensor.
300
+ segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
301
+ result.scatter_add_(0, segment_ids, data)
302
+ if aggregation_method == 'sum':
303
+ result = result / normalization_factor
304
+
305
+ if aggregation_method == 'mean':
306
+ norm = data.new_zeros(result.shape)
307
+ norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
308
+ norm[norm == 0] = 1
309
+ result = result / norm
310
+ return result
311
+
312
+
313
+ class Dynamics(nn.Module):
314
+ def __init__(
315
+ self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
316
+ n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
317
+ sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
318
+ normalization=None, centering=False,
319
+ ):
320
+ super().__init__()
321
+ self.device = device
322
+ self.n_dims = n_dims
323
+ self.context_node_nf = context_node_nf
324
+ self.condition_time = condition_time
325
+ self.model = model
326
+ self.centering = centering
327
+
328
+ in_node_nf = in_node_nf + context_node_nf + condition_time
329
+ if self.model == 'egnn_dynamics':
330
+ self.dynamics = EGNN(
331
+ in_node_nf=in_node_nf,
332
+ in_edge_nf=1,
333
+ hidden_nf=hidden_nf, device=device,
334
+ activation=activation,
335
+ n_layers=n_layers,
336
+ attention=attention,
337
+ tanh=tanh,
338
+ norm_constant=norm_constant,
339
+ inv_sublayers=inv_sublayers,
340
+ sin_embedding=sin_embedding,
341
+ normalization_factor=normalization_factor,
342
+ aggregation_method=aggregation_method,
343
+ )
344
+ elif self.model == 'gnn_dynamics':
345
+ self.dynamics = GNN(
346
+ in_node_nf=in_node_nf+3,
347
+ in_edge_nf=0,
348
+ hidden_nf=hidden_nf,
349
+ out_node_nf=in_node_nf+3,
350
+ device=device,
351
+ activation=activation,
352
+ n_layers=n_layers,
353
+ attention=attention,
354
+ normalization_factor=normalization_factor,
355
+ aggregation_method=aggregation_method,
356
+ normalization=normalization,
357
+ )
358
+ else:
359
+ raise NotImplementedError
360
+
361
+ self.edge_cache = {}
362
+
363
+ def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
364
+ """
365
+ - t: (B)
366
+ - xh: (B, N, D), where D = 3 + nf
367
+ - node_mask: (B, N, 1)
368
+ - edge_mask: (B*N*N, 1)
369
+ - context: (B, N, C)
370
+ """
371
+
372
+ bs, n_nodes = xh.shape[0], xh.shape[1]
373
+ edges = self.get_edges(n_nodes, bs) # (2, B*N)
374
+ node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
375
+
376
+ if linker_mask is not None:
377
+ linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
378
+
379
+ # Reshaping node features & adding time feature
380
+ xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
381
+ x = xh[:, :self.n_dims].clone() # (B*N, 3)
382
+ h = xh[:, self.n_dims:].clone() # (B*N, nf)
383
+ if self.condition_time:
384
+ if np.prod(t.size()) == 1:
385
+ # t is the same for all elements in batch.
386
+ h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
387
+ else:
388
+ # t is different over the batch dimension.
389
+ h_time = t.view(bs, 1).repeat(1, n_nodes)
390
+ h_time = h_time.view(bs * n_nodes, 1)
391
+ h = torch.cat([h, h_time], dim=1) # (B*N, nf+1)
392
+ if context is not None:
393
+ context = context.view(bs*n_nodes, self.context_node_nf)
394
+ h = torch.cat([h, context], dim=1)
395
+
396
+ # Forward EGNN
397
+ # Output: h_final (B*N, nf), x_final (B*N, 3), vel (B*N, 3)
398
+ if self.model == 'egnn_dynamics':
399
+ h_final, x_final = self.dynamics(
400
+ h,
401
+ x,
402
+ edges,
403
+ node_mask=node_mask,
404
+ linker_mask=linker_mask,
405
+ edge_mask=edge_mask
406
+ )
407
+ vel = (x_final - x) * node_mask # This masking operation is redundant but just in case
408
+ elif self.model == 'gnn_dynamics':
409
+ xh = torch.cat([x, h], dim=1)
410
+ output = self.dynamics(xh, edges, node_mask=node_mask)
411
+ vel = output[:, 0:3] * node_mask
412
+ h_final = output[:, 3:]
413
+ else:
414
+ raise NotImplementedError
415
+
416
+ # Slice off context size
417
+ if context is not None:
418
+ h_final = h_final[:, :-self.context_node_nf]
419
+
420
+ # Slice off last dimension which represented time.
421
+ if self.condition_time:
422
+ h_final = h_final[:, :-1]
423
+
424
+ vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
425
+ h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
426
+ node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
427
+
428
+ if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
429
+ raise utils.FoundNaNException(vel, h_final)
430
+
431
+ if self.centering:
432
+ vel = utils.remove_mean_with_mask(vel, node_mask)
433
+
434
+ return torch.cat([vel, h_final], dim=2)
435
+
436
+ def get_edges(self, n_nodes, batch_size):
437
+ if n_nodes in self.edge_cache:
438
+ edges_dic_b = self.edge_cache[n_nodes]
439
+ if batch_size in edges_dic_b:
440
+ return edges_dic_b[batch_size]
441
+ else:
442
+ # get edges for a single sample
443
+ rows, cols = [], []
444
+ for batch_idx in range(batch_size):
445
+ for i in range(n_nodes):
446
+ for j in range(n_nodes):
447
+ rows.append(i + batch_idx * n_nodes)
448
+ cols.append(j + batch_idx * n_nodes)
449
+ edges = [torch.LongTensor(rows).to(self.device), torch.LongTensor(cols).to(self.device)]
450
+ edges_dic_b[batch_size] = edges
451
+ return edges
452
+ else:
453
+ self.edge_cache[n_nodes] = {}
454
+ return self.get_edges(n_nodes, batch_size)
455
+
456
+
457
+ class DynamicsWithPockets(Dynamics):
458
+ def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
459
+ """
460
+ - t: (B)
461
+ - xh: (B, N, D), where D = 3 + nf
462
+ - node_mask: (B, N, 1)
463
+ - edge_mask: (B*N*N, 1)
464
+ - context: (B, N, C)
465
+ """
466
+
467
+ bs, n_nodes = xh.shape[0], xh.shape[1]
468
+ node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
469
+
470
+ if linker_mask is not None:
471
+ linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
472
+
473
+ # Reshaping node features & adding time feature
474
+ xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
475
+ x = xh[:, :self.n_dims].clone() # (B*N, 3)
476
+ h = xh[:, self.n_dims:].clone() # (B*N, nf)
477
+
478
+ edges = self.get_dist_edges(x, node_mask, edge_mask)
479
+ if self.condition_time:
480
+ if np.prod(t.size()) == 1:
481
+ # t is the same for all elements in batch.
482
+ h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
483
+ else:
484
+ # t is different over the batch dimension.
485
+ h_time = t.view(bs, 1).repeat(1, n_nodes)
486
+ h_time = h_time.view(bs * n_nodes, 1)
487
+ h = torch.cat([h, h_time], dim=1) # (B*N, nf+1)
488
+ if context is not None:
489
+ context = context.view(bs*n_nodes, self.context_node_nf)
490
+ h = torch.cat([h, context], dim=1)
491
+
492
+ # Forward EGNN
493
+ # Output: h_final (B*N, nf), x_final (B*N, 3), vel (B*N, 3)
494
+ if self.model == 'egnn_dynamics':
495
+ h_final, x_final = self.dynamics(
496
+ h,
497
+ x,
498
+ edges,
499
+ node_mask=node_mask,
500
+ linker_mask=linker_mask,
501
+ edge_mask=None
502
+ )
503
+ vel = (x_final - x) * node_mask # This masking operation is redundant but just in case
504
+ elif self.model == 'gnn_dynamics':
505
+ xh = torch.cat([x, h], dim=1)
506
+ output = self.dynamics(xh, edges, node_mask=node_mask)
507
+ vel = output[:, 0:3] * node_mask
508
+ h_final = output[:, 3:]
509
+ else:
510
+ raise NotImplementedError
511
+
512
+ # Slice off context size
513
+ if context is not None:
514
+ h_final = h_final[:, :-self.context_node_nf]
515
+
516
+ # Slice off last dimension which represented time.
517
+ if self.condition_time:
518
+ h_final = h_final[:, :-1]
519
+
520
+ vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
521
+ h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
522
+ node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
523
+
524
+ if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
525
+ raise utils.FoundNaNException(vel, h_final)
526
+
527
+ if self.centering:
528
+ vel = utils.remove_mean_with_mask(vel, node_mask)
529
+
530
+ return torch.cat([vel, h_final], dim=2)
531
+
532
+ @staticmethod
533
+ def get_dist_edges(x, node_mask, batch_mask):
534
+ node_mask = node_mask.squeeze().bool()
535
+ batch_adj = (batch_mask[:, None] == batch_mask[None, :])
536
+ nodes_adj = (node_mask[:, None] & node_mask[None, :])
537
+ dists_adj = (torch.cdist(x, x) <= 4)
538
+ rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
539
+ adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
540
+ edges = torch.stack(torch.where(adj))
541
+ return edges
src/lightning.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import wandb
6
+
7
+ from src import metrics, utils, delinker
8
+ from src.const import LINKER_SIZE_DIST
9
+ from src.egnn import Dynamics, DynamicsWithPockets
10
+ from src.edm import EDM, InpaintingEDM
11
+ from src.datasets import (
12
+ ZincDataset, MOADDataset, create_templates_for_linker_generation, get_dataloader, collate
13
+ )
14
+ from src.linker_size import DistributionNodes
15
+ from src.molecule_builder import build_molecules
16
+ from src.visualizer import save_xyz_file, visualize_chain
17
+ from typing import Dict, List, Optional
18
+ from tqdm import tqdm
19
+
20
+ from pdb import set_trace
21
+
22
+
23
+ def get_activation(activation):
24
+ print(activation)
25
+ if activation == 'silu':
26
+ return torch.nn.SiLU()
27
+ else:
28
+ raise Exception("activation fn not supported yet. Add it here.")
29
+
30
+
31
+ class DDPM(pl.LightningModule):
32
+ train_dataset = None
33
+ val_dataset = None
34
+ test_dataset = None
35
+ starting_epoch = None
36
+ metrics: Dict[str, List[float]] = {}
37
+
38
+ FRAMES = 100
39
+
40
+ def __init__(
41
+ self,
42
+ in_node_nf, n_dims, context_node_nf, hidden_nf, activation, tanh, n_layers, attention, norm_constant,
43
+ inv_sublayers, sin_embedding, normalization_factor, aggregation_method,
44
+ diffusion_steps, diffusion_noise_schedule, diffusion_noise_precision, diffusion_loss_type,
45
+ normalize_factors, include_charges, model,
46
+ data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
47
+ normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
48
+ center_of_mass='fragments', inpainting=False, anchors_context=True,
49
+ ):
50
+ super(DDPM, self).__init__()
51
+
52
+ self.save_hyperparameters()
53
+ self.data_path = data_path
54
+ self.train_data_prefix = train_data_prefix
55
+ self.val_data_prefix = val_data_prefix
56
+ self.batch_size = batch_size
57
+ self.lr = lr
58
+ self.torch_device = torch_device
59
+ self.include_charges = include_charges
60
+ self.test_epochs = test_epochs
61
+ self.n_stability_samples = n_stability_samples
62
+ self.log_iterations = log_iterations
63
+ self.samples_dir = samples_dir
64
+ self.data_augmentation = data_augmentation
65
+ self.center_of_mass = center_of_mass
66
+ self.inpainting = inpainting
67
+ self.loss_type = diffusion_loss_type
68
+
69
+ self.n_dims = n_dims
70
+ self.num_classes = in_node_nf - include_charges
71
+ self.include_charges = include_charges
72
+ self.anchors_context = anchors_context
73
+
74
+ self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
75
+
76
+ if type(activation) is str:
77
+ activation = get_activation(activation)
78
+
79
+ dynamics_class = DynamicsWithPockets if '.' in train_data_prefix else Dynamics
80
+ dynamics = dynamics_class(
81
+ in_node_nf=in_node_nf,
82
+ n_dims=n_dims,
83
+ context_node_nf=context_node_nf,
84
+ device=torch_device,
85
+ hidden_nf=hidden_nf,
86
+ activation=activation,
87
+ n_layers=n_layers,
88
+ attention=attention,
89
+ tanh=tanh,
90
+ norm_constant=norm_constant,
91
+ inv_sublayers=inv_sublayers,
92
+ sin_embedding=sin_embedding,
93
+ normalization_factor=normalization_factor,
94
+ aggregation_method=aggregation_method,
95
+ model=model,
96
+ normalization=normalization,
97
+ centering=inpainting,
98
+ )
99
+ edm_class = InpaintingEDM if inpainting else EDM
100
+ self.edm = edm_class(
101
+ dynamics=dynamics,
102
+ in_node_nf=in_node_nf,
103
+ n_dims=n_dims,
104
+ timesteps=diffusion_steps,
105
+ noise_schedule=diffusion_noise_schedule,
106
+ noise_precision=diffusion_noise_precision,
107
+ loss_type=diffusion_loss_type,
108
+ norm_values=normalize_factors,
109
+ )
110
+ self.linker_size_sampler = DistributionNodes(LINKER_SIZE_DIST)
111
+
112
+ def setup(self, stage: Optional[str] = None):
113
+ dataset_type = MOADDataset if '.' in self.train_data_prefix else ZincDataset
114
+ if stage == 'fit':
115
+ self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
116
+ self.train_dataset = dataset_type(
117
+ data_path=self.data_path,
118
+ prefix=self.train_data_prefix,
119
+ device=self.torch_device
120
+ )
121
+ self.val_dataset = dataset_type(
122
+ data_path=self.data_path,
123
+ prefix=self.val_data_prefix,
124
+ device=self.torch_device
125
+ )
126
+ elif stage == 'val':
127
+ self.is_geom = ('geom' in self.val_data_prefix) or ('MOAD' in self.val_data_prefix)
128
+ self.val_dataset = dataset_type(
129
+ data_path=self.data_path,
130
+ prefix=self.val_data_prefix,
131
+ device=self.torch_device
132
+ )
133
+ else:
134
+ raise NotImplementedError
135
+
136
+ def train_dataloader(self, collate_fn=collate):
137
+ return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_fn, shuffle=True)
138
+
139
+ def val_dataloader(self, collate_fn=collate):
140
+ return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_fn)
141
+
142
+ def test_dataloader(self, collate_fn=collate):
143
+ return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_fn)
144
+
145
+ def forward(self, data, training):
146
+ x = data['positions']
147
+ h = data['one_hot']
148
+ node_mask = data['atom_mask']
149
+ edge_mask = data['edge_mask']
150
+ anchors = data['anchors']
151
+ fragment_mask = data['fragment_mask']
152
+ linker_mask = data['linker_mask']
153
+
154
+ # Anchors and fragments labels are used as context
155
+ if self.anchors_context:
156
+ context = torch.cat([anchors, fragment_mask], dim=-1)
157
+ else:
158
+ context = fragment_mask
159
+
160
+ # Add information about pocket to the context
161
+ if '.' in self.train_data_prefix:
162
+ fragment_pocket_mask = fragment_mask
163
+ fragment_only_mask = data['fragment_only_mask']
164
+ pocket_only_mask = fragment_pocket_mask - fragment_only_mask
165
+ if self.anchors_context:
166
+ context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
167
+ else:
168
+ context = torch.cat([fragment_only_mask, pocket_only_mask], dim=-1)
169
+
170
+ # Removing COM of fragment from the atom coordinates
171
+ if self.inpainting:
172
+ center_of_mass_mask = node_mask
173
+ elif self.center_of_mass == 'fragments':
174
+ center_of_mass_mask = fragment_mask
175
+ elif self.center_of_mass == 'anchors':
176
+ center_of_mass_mask = anchors
177
+ else:
178
+ raise NotImplementedError(self.center_of_mass)
179
+ x = utils.remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask)
180
+ utils.assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask)
181
+
182
+ # Applying random rotation
183
+ if training and self.data_augmentation:
184
+ x = utils.random_rotation(x)
185
+
186
+ return self.edm.forward(
187
+ x=x,
188
+ h=h,
189
+ node_mask=node_mask,
190
+ fragment_mask=fragment_mask,
191
+ linker_mask=linker_mask,
192
+ edge_mask=edge_mask,
193
+ context=context
194
+ )
195
+
196
+ def training_step(self, data, *args):
197
+ delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=True)
198
+ vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
199
+ if self.loss_type == 'l2':
200
+ loss = l2_loss
201
+ elif self.loss_type == 'vlb':
202
+ loss = vlb_loss
203
+ else:
204
+ raise NotImplementedError(self.loss_type)
205
+
206
+ training_metrics = {
207
+ 'loss': loss,
208
+ 'delta_log_px': delta_log_px,
209
+ 'kl_prior': kl_prior,
210
+ 'loss_term_t': loss_term_t,
211
+ 'loss_term_0': loss_term_0,
212
+ 'l2_loss': l2_loss,
213
+ 'vlb_loss': vlb_loss,
214
+ 'noise_t': noise_t,
215
+ 'noise_0': noise_0
216
+ }
217
+ if self.log_iterations is not None and self.global_step % self.log_iterations == 0:
218
+ for metric_name, metric in training_metrics.items():
219
+ self.metrics.setdefault(f'{metric_name}/train', []).append(metric)
220
+ self.log(f'{metric_name}/train', metric, prog_bar=True)
221
+ return training_metrics
222
+
223
+ def validation_step(self, data, *args):
224
+ delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
225
+ vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
226
+ if self.loss_type == 'l2':
227
+ loss = l2_loss
228
+ elif self.loss_type == 'vlb':
229
+ loss = vlb_loss
230
+ else:
231
+ raise NotImplementedError(self.loss_type)
232
+ return {
233
+ 'loss': loss,
234
+ 'delta_log_px': delta_log_px,
235
+ 'kl_prior': kl_prior,
236
+ 'loss_term_t': loss_term_t,
237
+ 'loss_term_0': loss_term_0,
238
+ 'l2_loss': l2_loss,
239
+ 'vlb_loss': vlb_loss,
240
+ 'noise_t': noise_t,
241
+ 'noise_0': noise_0
242
+ }
243
+
244
+ def test_step(self, data, *args):
245
+ delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
246
+ vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
247
+ if self.loss_type == 'l2':
248
+ loss = l2_loss
249
+ elif self.loss_type == 'vlb':
250
+ loss = vlb_loss
251
+ else:
252
+ raise NotImplementedError(self.loss_type)
253
+ return {
254
+ 'loss': loss,
255
+ 'delta_log_px': delta_log_px,
256
+ 'kl_prior': kl_prior,
257
+ 'loss_term_t': loss_term_t,
258
+ 'loss_term_0': loss_term_0,
259
+ 'l2_loss': l2_loss,
260
+ 'vlb_loss': vlb_loss,
261
+ 'noise_t': noise_t,
262
+ 'noise_0': noise_0
263
+ }
264
+
265
+ def training_epoch_end(self, training_step_outputs):
266
+ for metric in training_step_outputs[0].keys():
267
+ avg_metric = self.aggregate_metric(training_step_outputs, metric)
268
+ self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
269
+ self.log(f'{metric}/train', avg_metric, prog_bar=True)
270
+
271
+ def validation_epoch_end(self, validation_step_outputs):
272
+ for metric in validation_step_outputs[0].keys():
273
+ avg_metric = self.aggregate_metric(validation_step_outputs, metric)
274
+ self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
275
+ self.log(f'{metric}/val', avg_metric, prog_bar=True)
276
+
277
+ if (self.current_epoch + 1) % self.test_epochs == 0:
278
+ sampling_results = self.sample_and_analyze(self.val_dataloader())
279
+ for metric_name, metric_value in sampling_results.items():
280
+ self.log(f'{metric_name}/val', metric_value, prog_bar=True)
281
+ self.metrics.setdefault(f'{metric_name}/val', []).append(metric_value)
282
+
283
+ # Logging the results corresponding to the best validation_and_connectivity
284
+ best_metrics, best_epoch = self.compute_best_validation_metrics()
285
+ self.log('best_epoch', int(best_epoch), prog_bar=True, batch_size=self.batch_size)
286
+ for metric, value in best_metrics.items():
287
+ self.log(f'best_{metric}', value, prog_bar=True, batch_size=self.batch_size)
288
+
289
+ def test_epoch_end(self, test_step_outputs):
290
+ for metric in test_step_outputs[0].keys():
291
+ avg_metric = self.aggregate_metric(test_step_outputs, metric)
292
+ self.metrics.setdefault(f'{metric}/test', []).append(avg_metric)
293
+ self.log(f'{metric}/test', avg_metric, prog_bar=True)
294
+
295
+ if (self.current_epoch + 1) % self.test_epochs == 0:
296
+ sampling_results = self.sample_and_analyze(self.test_dataloader())
297
+ for metric_name, metric_value in sampling_results.items():
298
+ self.log(f'{metric_name}/test', metric_value, prog_bar=True)
299
+ self.metrics.setdefault(f'{metric_name}/test', []).append(metric_value)
300
+
301
+ def generate_animation(self, chain_batch, node_mask, batch_i):
302
+ batch_indices, mol_indices = utils.get_batch_idx_for_animation(self.batch_size, batch_i)
303
+ for bi, mi in zip(batch_indices, mol_indices):
304
+ chain = chain_batch[:, bi, :, :]
305
+ name = f'mol_{mi}'
306
+ chain_output = os.path.join(self.samples_dir, f'epoch_{self.current_epoch}', name)
307
+ os.makedirs(chain_output, exist_ok=True)
308
+
309
+ one_hot = chain[:, :, 3:-1] if self.include_charges else chain[:, :, 3:]
310
+ positions = chain[:, :, :3]
311
+ chain_node_mask = torch.cat([node_mask[bi].unsqueeze(0) for _ in range(self.FRAMES)], dim=0)
312
+ names = [f'{name}_{j}' for j in range(self.FRAMES)]
313
+
314
+ save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=self.is_geom)
315
+ visualize_chain(chain_output, wandb=wandb, mode=name, is_geom=self.is_geom)
316
+
317
+ def sample_and_analyze(self, dataloader):
318
+ pred_molecules = []
319
+ true_molecules = []
320
+ true_fragments = []
321
+
322
+ for b, data in tqdm(enumerate(dataloader), total=len(dataloader), desc='Sampling'):
323
+ atom_mask = data['atom_mask']
324
+ fragment_mask = data['fragment_mask']
325
+
326
+ # Save molecules without pockets
327
+ if '.' in self.train_data_prefix:
328
+ atom_mask = data['atom_mask'] - data['pocket_mask']
329
+ fragment_mask = data['fragment_only_mask']
330
+
331
+ true_molecules_batch = build_molecules(
332
+ data['one_hot'],
333
+ data['positions'],
334
+ atom_mask,
335
+ is_geom=self.is_geom,
336
+ )
337
+ true_fragments_batch = build_molecules(
338
+ data['one_hot'],
339
+ data['positions'],
340
+ fragment_mask,
341
+ is_geom=self.is_geom,
342
+ )
343
+
344
+ for sample_idx in tqdm(range(self.n_stability_samples)):
345
+ try:
346
+ chain_batch, node_mask = self.sample_chain(data, keep_frames=self.FRAMES)
347
+ except utils.FoundNaNException as e:
348
+ for idx in e.x_h_nan_idx:
349
+ smiles = data['name'][idx]
350
+ print(f'FoundNaNException: [xh], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
351
+ for idx in e.only_x_nan_idx:
352
+ smiles = data['name'][idx]
353
+ print(f'FoundNaNException: [x ], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
354
+ for idx in e.only_h_nan_idx:
355
+ smiles = data['name'][idx]
356
+ print(f'FoundNaNException: [ h], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
357
+ continue
358
+
359
+ # Get final molecules from chains – for computing metrics
360
+ x, h = utils.split_features(
361
+ z=chain_batch[0],
362
+ n_dims=self.n_dims,
363
+ num_classes=self.num_classes,
364
+ include_charges=self.include_charges,
365
+ )
366
+
367
+ # Save molecules without pockets
368
+ if '.' in self.train_data_prefix:
369
+ node_mask = node_mask - data['pocket_mask']
370
+
371
+ one_hot = h['categorical']
372
+ pred_molecules_batch = build_molecules(one_hot, x, node_mask, is_geom=self.is_geom)
373
+
374
+ # Adding only results for valid ground truth molecules
375
+ for pred_mol, true_mol, frag in zip(pred_molecules_batch, true_molecules_batch, true_fragments_batch):
376
+ if metrics.is_valid(true_mol):
377
+ pred_molecules.append(pred_mol)
378
+ true_molecules.append(true_mol)
379
+ true_fragments.append(frag)
380
+
381
+ # Generate animation – will always do it for molecules with idx 0, 110 and 360
382
+ if self.samples_dir is not None and sample_idx == 0:
383
+ self.generate_animation(chain_batch=chain_batch, node_mask=node_mask, batch_i=b)
384
+
385
+ # Our own & DeLinker metrics
386
+ our_metrics = metrics.compute_metrics(
387
+ pred_molecules=pred_molecules,
388
+ true_molecules=true_molecules
389
+ )
390
+ delinker_metrics = delinker.get_delinker_metrics(
391
+ pred_molecules=pred_molecules,
392
+ true_molecules=true_molecules,
393
+ true_fragments=true_fragments
394
+ )
395
+ return {
396
+ **our_metrics,
397
+ **delinker_metrics
398
+ }
399
+
400
+ def sample_chain(self, data, sample_fn=None, keep_frames=None):
401
+ if sample_fn is None:
402
+ linker_sizes = data['linker_mask'].sum(1).view(-1).int()
403
+ else:
404
+ linker_sizes = sample_fn(data)
405
+
406
+ if self.inpainting:
407
+ template_data = data
408
+ else:
409
+ template_data = create_templates_for_linker_generation(data, linker_sizes)
410
+
411
+ x = template_data['positions']
412
+ node_mask = template_data['atom_mask']
413
+ edge_mask = template_data['edge_mask']
414
+ h = template_data['one_hot']
415
+ anchors = template_data['anchors']
416
+ fragment_mask = template_data['fragment_mask']
417
+ linker_mask = template_data['linker_mask']
418
+
419
+ # Anchors and fragments labels are used as context
420
+ if self.anchors_context:
421
+ context = torch.cat([anchors, fragment_mask], dim=-1)
422
+ else:
423
+ context = fragment_mask
424
+
425
+ # Add information about pocket to the context
426
+ if '.' in self.train_data_prefix:
427
+ fragment_pocket_mask = fragment_mask
428
+ fragment_only_mask = data['fragment_only_mask']
429
+ pocket_only_mask = fragment_pocket_mask - fragment_only_mask
430
+ if self.anchors_context:
431
+ context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
432
+ else:
433
+ context = torch.cat([fragment_only_mask, pocket_only_mask], dim=-1)
434
+
435
+ # Removing COM of fragment from the atom coordinates
436
+ if self.inpainting:
437
+ center_of_mass_mask = node_mask
438
+ elif self.center_of_mass == 'fragments':
439
+ center_of_mass_mask = fragment_mask
440
+ elif self.center_of_mass == 'anchors':
441
+ center_of_mass_mask = anchors
442
+ else:
443
+ raise NotImplementedError(self.center_of_mass)
444
+ x = utils.remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask)
445
+
446
+ chain = self.edm.sample_chain(
447
+ x=x,
448
+ h=h,
449
+ node_mask=node_mask,
450
+ edge_mask=edge_mask,
451
+ fragment_mask=fragment_mask,
452
+ linker_mask=linker_mask,
453
+ context=context,
454
+ keep_frames=keep_frames,
455
+ )
456
+ return chain, node_mask
457
+
458
+ def configure_optimizers(self):
459
+ return torch.optim.AdamW(self.edm.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
460
+
461
+ def compute_best_validation_metrics(self):
462
+ loss = self.metrics[f'validity_and_connectivity/val']
463
+ best_epoch = np.argmax(loss)
464
+ best_metrics = {
465
+ metric_name: metric_values[best_epoch]
466
+ for metric_name, metric_values in self.metrics.items()
467
+ if metric_name.endswith('/val')
468
+ }
469
+ return best_metrics, best_epoch
470
+
471
+ @staticmethod
472
+ def aggregate_metric(step_outputs, metric):
473
+ return torch.tensor([out[metric] for out in step_outputs]).mean()
src/linker_size.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from torch.distributions.categorical import Categorical
6
+ from src.egnn import GCL
7
+
8
+
9
+ class DistributionNodes:
10
+ def __init__(self, histogram):
11
+
12
+ self.n_nodes = []
13
+ prob = []
14
+ self.keys = {}
15
+ for i, nodes in enumerate(histogram):
16
+ self.n_nodes.append(nodes)
17
+ self.keys[nodes] = i
18
+ prob.append(histogram[nodes])
19
+ self.n_nodes = torch.tensor(self.n_nodes)
20
+ prob = np.array(prob)
21
+ prob = prob/np.sum(prob)
22
+
23
+ self.prob = torch.from_numpy(prob).float()
24
+
25
+ entropy = torch.sum(self.prob * torch.log(self.prob + 1e-30))
26
+ print("Entropy of n_nodes: H[N]", entropy.item())
27
+
28
+ self.m = Categorical(torch.tensor(prob))
29
+
30
+ def sample(self, n_samples=1):
31
+ idx = self.m.sample((n_samples,))
32
+ return self.n_nodes[idx]
33
+
34
+ def log_prob(self, batch_n_nodes):
35
+ assert len(batch_n_nodes.size()) == 1
36
+
37
+ idcs = [self.keys[i.item()] for i in batch_n_nodes]
38
+ idcs = torch.tensor(idcs).to(batch_n_nodes.device)
39
+
40
+ log_p = torch.log(self.prob + 1e-30)
41
+
42
+ log_p = log_p.to(batch_n_nodes.device)
43
+
44
+ log_probs = log_p[idcs]
45
+
46
+ return log_probs
47
+
48
+
49
+ class SizeGNN(nn.Module):
50
+ def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_layers, normalization, device='cpu'):
51
+ super(SizeGNN, self).__init__()
52
+ self.hidden_nf = hidden_nf
53
+ self.out_node_nf = out_node_nf
54
+ self.device = device
55
+
56
+ self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
57
+ self.gcl1 = GCL(
58
+ input_nf=self.hidden_nf,
59
+ output_nf=self.hidden_nf,
60
+ hidden_nf=self.hidden_nf,
61
+ normalization_factor=1,
62
+ aggregation_method='sum',
63
+ edges_in_d=1,
64
+ activation=nn.ReLU(),
65
+ attention=False,
66
+ normalization=normalization
67
+ )
68
+
69
+ layers = []
70
+ for i in range(n_layers - 1):
71
+ layer = GCL(
72
+ input_nf=self.hidden_nf,
73
+ output_nf=self.hidden_nf,
74
+ hidden_nf=self.hidden_nf,
75
+ normalization_factor=1,
76
+ aggregation_method='sum',
77
+ edges_in_d=1,
78
+ activation=nn.ReLU(),
79
+ attention=False,
80
+ normalization=normalization
81
+ )
82
+ layers.append(layer)
83
+
84
+ self.gcl_layers = nn.ModuleList(layers)
85
+ self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
86
+ self.to(self.device)
87
+
88
+ def forward(self, h, edges, distances, node_mask, edge_mask):
89
+ h = self.embedding_in(h)
90
+ h, _ = self.gcl1(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
91
+ for gcl in self.gcl_layers:
92
+ h, _ = gcl(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
93
+
94
+ h = self.embedding_out(h)
95
+ return h
src/linker_size_lightning.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+
4
+ from src.const import ZINC_TRAIN_LINKER_ID2SIZE, ZINC_TRAIN_LINKER_SIZE2ID
5
+ from src.linker_size import SizeGNN
6
+ from src.egnn import coord2diff
7
+ from src.datasets import ZincDataset, get_dataloader, collate_with_fragment_edges
8
+ from typing import Dict, List, Optional
9
+ from torch.nn.functional import cross_entropy, mse_loss, sigmoid
10
+
11
+ from pdb import set_trace
12
+
13
+
14
+ class SizeClassifier(pl.LightningModule):
15
+ train_dataset = None
16
+ val_dataset = None
17
+ test_dataset = None
18
+ metrics: Dict[str, List[float]] = {}
19
+
20
+ def __init__(
21
+ self, data_path, train_data_prefix, val_data_prefix,
22
+ in_node_nf, hidden_nf, out_node_nf, n_layers, batch_size, lr, torch_device,
23
+ normalization=None,
24
+ loss_weights=None,
25
+ min_linker_size=None,
26
+ linker_size2id=ZINC_TRAIN_LINKER_SIZE2ID,
27
+ linker_id2size=ZINC_TRAIN_LINKER_ID2SIZE,
28
+ task='classification',
29
+ ):
30
+ super(SizeClassifier, self).__init__()
31
+
32
+ self.save_hyperparameters()
33
+ self.data_path = data_path
34
+ self.train_data_prefix = train_data_prefix
35
+ self.val_data_prefix = val_data_prefix
36
+ self.min_linker_size = min_linker_size
37
+ self.linker_size2id = linker_size2id
38
+ self.linker_id2size = linker_id2size
39
+ self.batch_size = batch_size
40
+ self.lr = lr
41
+ self.torch_device = torch_device
42
+ self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=torch_device)
43
+ self.gnn = SizeGNN(
44
+ in_node_nf=in_node_nf,
45
+ hidden_nf=hidden_nf,
46
+ out_node_nf=out_node_nf,
47
+ n_layers=n_layers,
48
+ device=torch_device,
49
+ normalization=normalization,
50
+ )
51
+
52
+ def setup(self, stage: Optional[str] = None):
53
+ if stage == 'fit':
54
+ self.train_dataset = ZincDataset(
55
+ data_path=self.data_path,
56
+ prefix=self.train_data_prefix,
57
+ device=self.torch_device
58
+ )
59
+ self.val_dataset = ZincDataset(
60
+ data_path=self.data_path,
61
+ prefix=self.val_data_prefix,
62
+ device=self.torch_device
63
+ )
64
+ elif stage == 'val':
65
+ self.val_dataset = ZincDataset(
66
+ data_path=self.data_path,
67
+ prefix=self.val_data_prefix,
68
+ device=self.torch_device
69
+ )
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ def train_dataloader(self):
74
+ return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)
75
+
76
+ def val_dataloader(self):
77
+ return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
78
+
79
+ def test_dataloader(self):
80
+ return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
81
+
82
+ def forward(self, data):
83
+ h = data['one_hot']
84
+ x = data['positions']
85
+ fragment_mask = data['fragment_mask']
86
+ linker_mask = data['linker_mask']
87
+ edge_mask = data['edge_mask']
88
+ edges = data['edges']
89
+
90
+ # Considering only fragments
91
+ x = x * fragment_mask
92
+ h = h * fragment_mask
93
+
94
+ # Reshaping
95
+ bs, n_nodes = x.shape[0], x.shape[1]
96
+ fragment_mask = fragment_mask.view(bs * n_nodes, 1)
97
+ x = x.view(bs * n_nodes, -1)
98
+ h = h.view(bs * n_nodes, -1)
99
+
100
+ # Prediction
101
+ distances, _ = coord2diff(x, edges)
102
+ distance_edge_mask = (edge_mask.bool() & (distances < 6)).long()
103
+ output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
104
+ output = output.view(bs, n_nodes, -1).mean(1)
105
+
106
+ true = self.get_true_labels(linker_mask)
107
+ loss = cross_entropy(output, true, weight=self.loss_weights)
108
+
109
+ return output, loss
110
+
111
+ def get_true_labels(self, linker_mask):
112
+ labels = []
113
+ sizes = linker_mask.squeeze().sum(-1).long().detach().cpu().numpy()
114
+ for size in sizes:
115
+ label = self.linker_size2id.get(size)
116
+ if label is None:
117
+ label = self.linker_size2id[max(self.linker_id2size)]
118
+ labels.append(label)
119
+ labels = torch.tensor(labels, device=linker_mask.device, dtype=torch.long)
120
+ return labels
121
+
122
+ def training_step(self, data, *args):
123
+ _, loss = self.forward(data)
124
+ return {'loss': loss}
125
+
126
+ def validation_step(self, data, *args):
127
+ _, loss = self.forward(data)
128
+ return {'loss': loss}
129
+
130
+ def test_step(self, data, *args):
131
+ loss = self.forward(data)
132
+ return {'loss': loss}
133
+
134
+ def training_epoch_end(self, training_step_outputs):
135
+ for metric in training_step_outputs[0].keys():
136
+ avg_metric = self.aggregate_metric(training_step_outputs, metric)
137
+ self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
138
+ self.log(f'{metric}/train', avg_metric, prog_bar=True)
139
+
140
+ def validation_epoch_end(self, validation_step_outputs):
141
+ for metric in validation_step_outputs[0].keys():
142
+ avg_metric = self.aggregate_metric(validation_step_outputs, metric)
143
+ self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
144
+ self.log(f'{metric}/val', avg_metric, prog_bar=True)
145
+
146
+ correct = 0
147
+ total = 0
148
+ for data in self.val_dataloader():
149
+ output, _ = self.forward(data)
150
+ pred = output.argmax(dim=-1)
151
+ true = self.get_true_labels(data['linker_mask'])
152
+ correct += (pred == true).sum()
153
+ total += len(pred)
154
+
155
+ accuracy = correct / total
156
+ self.metrics.setdefault(f'accuracy/val', []).append(accuracy)
157
+ self.log(f'accuracy/val', accuracy, prog_bar=True)
158
+
159
+ def configure_optimizers(self):
160
+ return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
161
+
162
+ @staticmethod
163
+ def aggregate_metric(step_outputs, metric):
164
+ return torch.tensor([out[metric] for out in step_outputs]).mean()
165
+
166
+
167
+ class SizeOrdinalClassifier(pl.LightningModule):
168
+ train_dataset = None
169
+ val_dataset = None
170
+ test_dataset = None
171
+ metrics: Dict[str, List[float]] = {}
172
+
173
+ def __init__(
174
+ self, data_path, train_data_prefix, val_data_prefix,
175
+ in_node_nf, hidden_nf, out_node_nf, n_layers, batch_size, lr, torch_device,
176
+ normalization=None,
177
+ min_linker_size=None,
178
+ linker_size2id=ZINC_TRAIN_LINKER_SIZE2ID,
179
+ linker_id2size=ZINC_TRAIN_LINKER_ID2SIZE,
180
+ task='ordinal',
181
+ ):
182
+ super(SizeOrdinalClassifier, self).__init__()
183
+
184
+ self.save_hyperparameters()
185
+ self.data_path = data_path
186
+ self.train_data_prefix = train_data_prefix
187
+ self.val_data_prefix = val_data_prefix
188
+ self.min_linker_size = min_linker_size
189
+ self.batch_size = batch_size
190
+ self.lr = lr
191
+ self.torch_device = torch_device
192
+ self.linker_size2id = linker_size2id
193
+ self.linker_id2size = linker_id2size
194
+ self.gnn = SizeGNN(
195
+ in_node_nf=in_node_nf,
196
+ hidden_nf=hidden_nf,
197
+ out_node_nf=out_node_nf,
198
+ n_layers=n_layers,
199
+ device=torch_device,
200
+ normalization=normalization,
201
+ )
202
+
203
+ def setup(self, stage: Optional[str] = None):
204
+ if stage == 'fit':
205
+ self.train_dataset = ZincDataset(
206
+ data_path=self.data_path,
207
+ prefix=self.train_data_prefix,
208
+ device=self.torch_device
209
+ )
210
+ self.val_dataset = ZincDataset(
211
+ data_path=self.data_path,
212
+ prefix=self.val_data_prefix,
213
+ device=self.torch_device
214
+ )
215
+ elif stage == 'val':
216
+ self.val_dataset = ZincDataset(
217
+ data_path=self.data_path,
218
+ prefix=self.val_data_prefix,
219
+ device=self.torch_device
220
+ )
221
+ else:
222
+ raise NotImplementedError
223
+
224
+ def train_dataloader(self):
225
+ return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)
226
+
227
+ def val_dataloader(self):
228
+ return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
229
+
230
+ def test_dataloader(self):
231
+ return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
232
+
233
+ def forward(self, data):
234
+ h = data['one_hot']
235
+ x = data['positions']
236
+ fragment_mask = data['fragment_mask']
237
+ linker_mask = data['linker_mask']
238
+ edge_mask = data['edge_mask']
239
+ edges = data['edges']
240
+
241
+ # Considering only fragments
242
+ x = x * fragment_mask
243
+ h = h * fragment_mask
244
+
245
+ # Reshaping
246
+ bs, n_nodes = x.shape[0], x.shape[1]
247
+ fragment_mask = fragment_mask.view(bs * n_nodes, 1)
248
+ x = x.view(bs * n_nodes, -1)
249
+ h = h.view(bs * n_nodes, -1)
250
+
251
+ # Prediction
252
+ distances, _ = coord2diff(x, edges)
253
+ distance_edge_mask = (edge_mask.bool() & (distances < 6)).long()
254
+ output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
255
+ output = output.view(bs, n_nodes, -1).mean(1)
256
+ output = sigmoid(output)
257
+
258
+ true = self.get_true_labels(linker_mask)
259
+ loss = self.ordinal_loss(output, true)
260
+
261
+ return output, loss
262
+
263
+ def ordinal_loss(self, pred, true):
264
+ target = torch.zeros_like(pred, device=self.torch_device)
265
+ for i, label in enumerate(true):
266
+ target[i, 0:label + 1] = 1
267
+
268
+ return mse_loss(pred, target, reduction='none').sum(1).mean()
269
+
270
+ def get_true_labels(self, linker_mask):
271
+ labels = []
272
+ sizes = linker_mask.squeeze().sum(-1).long().detach().cpu().numpy()
273
+ for size in sizes:
274
+ label = self.linker_size2id.get(size)
275
+ if label is None:
276
+ label = self.linker_size2id[max(self.linker_id2size)]
277
+ labels.append(label)
278
+ labels = torch.tensor(labels, device=linker_mask.device, dtype=torch.long)
279
+ return labels
280
+
281
+ @staticmethod
282
+ def prediction2label(pred):
283
+ return torch.cumprod(pred > 0.5, dim=1).sum(dim=1) - 1
284
+
285
+ def training_step(self, data, *args):
286
+ _, loss = self.forward(data)
287
+ return {'loss': loss}
288
+
289
+ def validation_step(self, data, *args):
290
+ _, loss = self.forward(data)
291
+ return {'loss': loss}
292
+
293
+ def test_step(self, data, *args):
294
+ loss = self.forward(data)
295
+ return {'loss': loss}
296
+
297
+ def training_epoch_end(self, training_step_outputs):
298
+ for metric in training_step_outputs[0].keys():
299
+ avg_metric = self.aggregate_metric(training_step_outputs, metric)
300
+ self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
301
+ self.log(f'{metric}/train', avg_metric, prog_bar=True)
302
+
303
+ def validation_epoch_end(self, validation_step_outputs):
304
+ for metric in validation_step_outputs[0].keys():
305
+ avg_metric = self.aggregate_metric(validation_step_outputs, metric)
306
+ self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
307
+ self.log(f'{metric}/val', avg_metric, prog_bar=True)
308
+
309
+ correct = 0
310
+ total = 0
311
+ for data in self.val_dataloader():
312
+ output, _ = self.forward(data)
313
+ pred = self.prediction2label(output)
314
+ true = self.get_true_labels(data['linker_mask'])
315
+ correct += (pred == true).sum()
316
+ total += len(pred)
317
+
318
+ accuracy = correct / total
319
+ self.metrics.setdefault(f'accuracy/val', []).append(accuracy)
320
+ self.log(f'accuracy/val', accuracy, prog_bar=True)
321
+
322
+ def configure_optimizers(self):
323
+ return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
324
+
325
+ @staticmethod
326
+ def aggregate_metric(step_outputs, metric):
327
+ return torch.tensor([out[metric] for out in step_outputs]).mean()
328
+
329
+
330
+ class SizeRegressor(pl.LightningModule):
331
+ train_dataset = None
332
+ val_dataset = None
333
+ test_dataset = None
334
+ metrics: Dict[str, List[float]] = {}
335
+
336
+ def __init__(
337
+ self, data_path, train_data_prefix, val_data_prefix,
338
+ in_node_nf, hidden_nf, n_layers, batch_size, lr, torch_device,
339
+ normalization=None, task='regression',
340
+ ):
341
+ super(SizeRegressor, self).__init__()
342
+
343
+ self.save_hyperparameters()
344
+ self.data_path = data_path
345
+ self.train_data_prefix = train_data_prefix
346
+ self.val_data_prefix = val_data_prefix
347
+ self.batch_size = batch_size
348
+ self.lr = lr
349
+ self.torch_device = torch_device
350
+ self.gnn = SizeGNN(
351
+ in_node_nf=in_node_nf,
352
+ hidden_nf=hidden_nf,
353
+ out_node_nf=1,
354
+ n_layers=n_layers,
355
+ device=torch_device,
356
+ normalization=normalization,
357
+ )
358
+
359
+ def setup(self, stage: Optional[str] = None):
360
+ if stage == 'fit':
361
+ self.train_dataset = ZincDataset(
362
+ data_path=self.data_path,
363
+ prefix=self.train_data_prefix,
364
+ device=self.torch_device
365
+ )
366
+ self.val_dataset = ZincDataset(
367
+ data_path=self.data_path,
368
+ prefix=self.val_data_prefix,
369
+ device=self.torch_device
370
+ )
371
+ elif stage == 'val':
372
+ self.val_dataset = ZincDataset(
373
+ data_path=self.data_path,
374
+ prefix=self.val_data_prefix,
375
+ device=self.torch_device
376
+ )
377
+ else:
378
+ raise NotImplementedError
379
+
380
+ def train_dataloader(self):
381
+ return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)
382
+
383
+ def val_dataloader(self):
384
+ return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
385
+
386
+ def test_dataloader(self):
387
+ return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
388
+
389
+ def forward(self, data):
390
+ h = data['one_hot']
391
+ x = data['positions']
392
+ fragment_mask = data['fragment_mask']
393
+ linker_mask = data['linker_mask']
394
+ edge_mask = data['edge_mask']
395
+ edges = data['edges']
396
+
397
+ # Considering only fragments
398
+ x = x * fragment_mask
399
+ h = h * fragment_mask
400
+
401
+ # Reshaping
402
+ bs, n_nodes = x.shape[0], x.shape[1]
403
+ fragment_mask = fragment_mask.view(bs * n_nodes, 1)
404
+ x = x.view(bs * n_nodes, -1)
405
+ h = h.view(bs * n_nodes, -1)
406
+
407
+ # Prediction
408
+ distances, _ = coord2diff(x, edges)
409
+ distance_edge_mask = (edge_mask.bool() & (distances < 6)).long()
410
+ output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
411
+ output = output.view(bs, n_nodes, -1).mean(1).squeeze()
412
+
413
+ true = linker_mask.squeeze().sum(-1).float()
414
+ loss = mse_loss(output, true)
415
+
416
+ return output, loss
417
+
418
+ def training_step(self, data, *args):
419
+ _, loss = self.forward(data)
420
+ return {'loss': loss}
421
+
422
+ def validation_step(self, data, *args):
423
+ _, loss = self.forward(data)
424
+ return {'loss': loss}
425
+
426
+ def test_step(self, data, *args):
427
+ loss = self.forward(data)
428
+ return {'loss': loss}
429
+
430
+ def training_epoch_end(self, training_step_outputs):
431
+ for metric in training_step_outputs[0].keys():
432
+ avg_metric = self.aggregate_metric(training_step_outputs, metric)
433
+ self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
434
+ self.log(f'{metric}/train', avg_metric, prog_bar=True)
435
+
436
+ def validation_epoch_end(self, validation_step_outputs):
437
+ for metric in validation_step_outputs[0].keys():
438
+ avg_metric = self.aggregate_metric(validation_step_outputs, metric)
439
+ self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
440
+ self.log(f'{metric}/val', avg_metric, prog_bar=True)
441
+
442
+ correct = 0
443
+ total = 0
444
+ for data in self.val_dataloader():
445
+ output, _ = self.forward(data)
446
+ pred = torch.round(output).long()
447
+ true = data['linker_mask'].squeeze().sum(-1).long()
448
+ correct += (pred == true).sum()
449
+ total += len(pred)
450
+
451
+ accuracy = correct / total
452
+ self.metrics.setdefault(f'accuracy/val', []).append(accuracy)
453
+ self.log(f'accuracy/val', accuracy, prog_bar=True)
454
+
455
+ def configure_optimizers(self):
456
+ return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
457
+
458
+ @staticmethod
459
+ def aggregate_metric(step_outputs, metric):
460
+ return torch.tensor([out[metric] for out in step_outputs]).mean()
src/metrics.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from rdkit import Chem
4
+ from rdkit.Chem import AllChem
5
+ from src import const
6
+ from src.molecule_builder import get_bond_order
7
+ from scipy.stats import wasserstein_distance
8
+
9
+ from pdb import set_trace
10
+
11
+
12
+ def is_valid(mol):
13
+ try:
14
+ Chem.SanitizeMol(mol)
15
+ except ValueError:
16
+ return False
17
+ return True
18
+
19
+
20
+ def is_connected(mol):
21
+ try:
22
+ mol_frags = Chem.GetMolFrags(mol, asMols=True)
23
+ except Chem.rdchem.AtomValenceException:
24
+ return False
25
+ if len(mol_frags) != 1:
26
+ return False
27
+ return True
28
+
29
+
30
+ def get_valid_molecules(molecules):
31
+ valid = []
32
+ for mol in molecules:
33
+ if is_valid(mol):
34
+ valid.append(mol)
35
+ return valid
36
+
37
+
38
+ def get_connected_molecules(molecules):
39
+ connected = []
40
+ for mol in molecules:
41
+ if is_connected(mol):
42
+ connected.append(mol)
43
+ return connected
44
+
45
+
46
+ def get_unique_smiles(valid_molecules):
47
+ unique = set()
48
+ for mol in valid_molecules:
49
+ unique.add(Chem.MolToSmiles(mol))
50
+ return list(unique)
51
+
52
+
53
+ def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
54
+ return list(set(unique_pred_smiles).difference(set(unique_true_smiles)))
55
+
56
+
57
+ def compute_energy(mol):
58
+ mp = AllChem.MMFFGetMoleculeProperties(mol)
59
+ energy = AllChem.MMFFGetMoleculeForceField(mol, mp, confId=0).CalcEnergy()
60
+ return energy
61
+
62
+
63
+ def wasserstein_distance_between_energies(true_molecules, pred_molecules):
64
+ true_energy_dist = []
65
+ for mol in true_molecules:
66
+ try:
67
+ energy = compute_energy(mol)
68
+ true_energy_dist.append(energy)
69
+ except:
70
+ continue
71
+
72
+ pred_energy_dist = []
73
+ for mol in pred_molecules:
74
+ try:
75
+ energy = compute_energy(mol)
76
+ pred_energy_dist.append(energy)
77
+ except:
78
+ continue
79
+
80
+ if len(true_energy_dist) > 0 and len(pred_energy_dist) > 0:
81
+ return wasserstein_distance(true_energy_dist, pred_energy_dist)
82
+ else:
83
+ return 0
84
+
85
+
86
+ def compute_metrics(pred_molecules, true_molecules):
87
+ if len(pred_molecules) == 0:
88
+ return {
89
+ 'validity': 0,
90
+ 'validity_and_connectivity': 0,
91
+ 'validity_as_in_delinker': 0,
92
+ 'uniqueness': 0,
93
+ 'novelty': 0,
94
+ 'energies': 0,
95
+ }
96
+
97
+ # Passing rdkit.Chem.Sanitize filter
98
+ true_valid = get_valid_molecules(true_molecules)
99
+ pred_valid = get_valid_molecules(pred_molecules)
100
+ validity = len(pred_valid) / len(pred_molecules)
101
+
102
+ # Checking if molecule consists of a single connected part
103
+ true_valid_and_connected = get_connected_molecules(true_valid)
104
+ pred_valid_and_connected = get_connected_molecules(pred_valid)
105
+ validity_and_connectivity = len(pred_valid_and_connected) / len(pred_molecules)
106
+
107
+ # Unique molecules
108
+ true_unique = get_unique_smiles(true_valid_and_connected)
109
+ pred_unique = get_unique_smiles(pred_valid_and_connected)
110
+ uniqueness = len(pred_unique) / len(pred_valid_and_connected) if len(pred_valid_and_connected) > 0 else 0
111
+
112
+ # Novel molecules
113
+ pred_novel = get_novel_smiles(true_unique, pred_unique)
114
+ novelty = len(pred_novel) / len(pred_unique) if len(pred_unique) > 0 else 0
115
+
116
+ # Difference between Energy distributions
117
+ energies = wasserstein_distance_between_energies(true_valid_and_connected, pred_valid_and_connected)
118
+
119
+ return {
120
+ 'validity': validity,
121
+ 'validity_and_connectivity': validity_and_connectivity,
122
+ 'uniqueness': uniqueness,
123
+ 'novelty': novelty,
124
+ 'energies': energies,
125
+ }
126
+
127
+
128
+ # def check_stability(positions, atom_types):
129
+ # assert len(positions.shape) == 2
130
+ # assert positions.shape[1] == 3
131
+ # x = positions[:, 0]
132
+ # y = positions[:, 1]
133
+ # z = positions[:, 2]
134
+ #
135
+ # nr_bonds = np.zeros(len(x), dtype='int')
136
+ # for i in range(len(x)):
137
+ # for j in range(i + 1, len(x)):
138
+ # p1 = np.array([x[i], y[i], z[i]])
139
+ # p2 = np.array([x[j], y[j], z[j]])
140
+ # dist = np.sqrt(np.sum((p1 - p2) ** 2))
141
+ # atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()]
142
+ # order = get_bond_order(atom1, atom2, dist)
143
+ # nr_bonds[i] += order
144
+ # nr_bonds[j] += order
145
+ # nr_stable_bonds = 0
146
+ # for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds):
147
+ # possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]]
148
+ # if type(possible_bonds) == int:
149
+ # is_stable = possible_bonds == nr_bonds_i
150
+ # else:
151
+ # is_stable = nr_bonds_i in possible_bonds
152
+ # nr_stable_bonds += int(is_stable)
153
+ #
154
+ # molecule_stable = nr_stable_bonds == len(x)
155
+ # return molecule_stable, nr_stable_bonds, len(x)
156
+ #
157
+ #
158
+ # def count_stable_molecules(one_hot, x, node_mask):
159
+ # stable_molecules = 0
160
+ # for i in range(len(one_hot)):
161
+ # mol_size = node_mask[i].sum()
162
+ # atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu()
163
+ # positions = x[i][:mol_size, :].detach().cpu()
164
+ # stable, _, _ = check_stability(positions, atom_types)
165
+ # stable_molecules += int(stable)
166
+ #
167
+ # return stable_molecules
src/molecule_builder.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from rdkit import Chem, Geometry
5
+
6
+ from src import const
7
+
8
+
9
+ def create_conformer(coords):
10
+ conformer = Chem.Conformer()
11
+ for i, (x, y, z) in enumerate(coords):
12
+ conformer.SetAtomPosition(i, Geometry.Point3D(x, y, z))
13
+ return conformer
14
+
15
+
16
+ def build_molecules(one_hot, x, node_mask, is_geom, margins=const.MARGINS_EDM):
17
+ molecules = []
18
+ for i in range(len(one_hot)):
19
+ mask = node_mask[i].squeeze() == 1
20
+ atom_types = one_hot[i][mask].argmax(dim=1).detach().cpu()
21
+ positions = x[i][mask].detach().cpu()
22
+ mol = build_molecule(positions, atom_types, is_geom, margins=margins)
23
+ molecules.append(mol)
24
+
25
+ return molecules
26
+
27
+
28
+ def build_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
29
+ idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
30
+ X, A, E = build_xae_molecule(positions, atom_types, is_geom=is_geom, margins=margins)
31
+ mol = Chem.RWMol()
32
+ for atom in X:
33
+ a = Chem.Atom(idx2atom[atom.item()])
34
+ mol.AddAtom(a)
35
+
36
+ all_bonds = torch.nonzero(A)
37
+ for bond in all_bonds:
38
+ mol.AddBond(bond[0].item(), bond[1].item(), const.BOND_DICT[E[bond[0], bond[1]].item()])
39
+
40
+ mol.AddConformer(create_conformer(positions.detach().cpu().numpy().astype(np.float64)))
41
+ return mol
42
+
43
+
44
+ def build_xae_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
45
+ """ Returns a triplet (X, A, E): atom_types, adjacency matrix, edge_types
46
+ args:
47
+ positions: N x 3 (already masked to keep final number nodes)
48
+ atom_types: N
49
+ returns:
50
+ X: N (int)
51
+ A: N x N (bool) (binary adjacency matrix)
52
+ E: N x N (int) (bond type, 0 if no bond) such that A = E.bool()
53
+ """
54
+ n = positions.shape[0]
55
+ X = atom_types
56
+ A = torch.zeros((n, n), dtype=torch.bool)
57
+ E = torch.zeros((n, n), dtype=torch.int)
58
+
59
+ idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
60
+
61
+ pos = positions.unsqueeze(0)
62
+ dists = torch.cdist(pos, pos, p=2).squeeze(0)
63
+ for i in range(n):
64
+ for j in range(i):
65
+
66
+ pair = sorted([atom_types[i], atom_types[j]])
67
+ order = get_bond_order(idx2atom[pair[0].item()], idx2atom[pair[1].item()], dists[i, j], margins=margins)
68
+
69
+ # TODO: a batched version of get_bond_order to avoid the for loop
70
+ if order > 0:
71
+ # Warning: the graph should be DIRECTED
72
+ A[i, j] = 1
73
+ E[i, j] = order
74
+
75
+ return X, A, E
76
+
77
+
78
+ def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM):
79
+ distance = 100 * distance # We change the metric
80
+
81
+ # Check exists for large molecules where some atom pairs do not have a
82
+ # typical bond length.
83
+ if check_exists:
84
+ if atom1 not in const.BONDS_1:
85
+ return 0
86
+ if atom2 not in const.BONDS_1[atom1]:
87
+ return 0
88
+
89
+ # margin1, margin2 and margin3 have been tuned to maximize the stability of the QM9 true samples
90
+ if distance < const.BONDS_1[atom1][atom2] + margins[0]:
91
+
92
+ # Check if atoms in bonds2 dictionary.
93
+ if atom1 in const.BONDS_2 and atom2 in const.BONDS_2[atom1]:
94
+ thr_bond2 = const.BONDS_2[atom1][atom2] + margins[1]
95
+ if distance < thr_bond2:
96
+ if atom1 in const.BONDS_3 and atom2 in const.BONDS_3[atom1]:
97
+ thr_bond3 = const.BONDS_3[atom1][atom2] + margins[2]
98
+ if distance < thr_bond3:
99
+ return 3 # Triple
100
+ return 2 # Double
101
+ return 1 # Single
102
+ return 0 # No bond
src/noise.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+
6
+
7
+ def clip_noise_schedule(alphas2, clip_value=0.001):
8
+ """
9
+ For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. This may help improve stability during
10
+ sampling.
11
+ """
12
+ alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
13
+
14
+ alphas_step = (alphas2[1:] / alphas2[:-1])
15
+
16
+ alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
17
+ alphas2 = np.cumprod(alphas_step, axis=0)
18
+
19
+ return alphas2
20
+
21
+
22
+ def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
23
+ """
24
+ A noise schedule based on a simple polynomial equation: 1 - x^power.
25
+ """
26
+ steps = timesteps + 1
27
+ x = np.linspace(0, steps, steps)
28
+ alphas2 = (1 - np.power(x / steps, power)) ** 2
29
+
30
+ alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)
31
+
32
+ precision = 1 - 2 * s
33
+
34
+ alphas2 = precision * alphas2 + s
35
+
36
+ return alphas2
37
+
38
+
39
+ def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
40
+ """
41
+ cosine schedule
42
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
43
+ """
44
+ steps = timesteps + 2
45
+ x = np.linspace(0, steps, steps)
46
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
47
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
48
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
49
+ betas = np.clip(betas, a_min=0, a_max=0.999)
50
+ alphas = 1. - betas
51
+ alphas_cumprod = np.cumprod(alphas, axis=0)
52
+
53
+ if raise_to_power != 1:
54
+ alphas_cumprod = np.power(alphas_cumprod, raise_to_power)
55
+
56
+ return alphas_cumprod
57
+
58
+
59
+ class PositiveLinear(torch.nn.Module):
60
+ """Linear layer with weights forced to be positive."""
61
+
62
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
63
+ weight_init_offset: int = -2):
64
+ super(PositiveLinear, self).__init__()
65
+ self.in_features = in_features
66
+ self.out_features = out_features
67
+ self.weight = torch.nn.Parameter(
68
+ torch.empty((out_features, in_features)))
69
+ if bias:
70
+ self.bias = torch.nn.Parameter(torch.empty(out_features))
71
+ else:
72
+ self.register_parameter('bias', None)
73
+ self.weight_init_offset = weight_init_offset
74
+ self.reset_parameters()
75
+
76
+ def reset_parameters(self) -> None:
77
+ torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
78
+
79
+ with torch.no_grad():
80
+ self.weight.add_(self.weight_init_offset)
81
+
82
+ if self.bias is not None:
83
+ fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
84
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
85
+ torch.nn.init.uniform_(self.bias, -bound, bound)
86
+
87
+ def forward(self, x):
88
+ positive_weight = F.softplus(self.weight)
89
+ return F.linear(x, positive_weight, self.bias)
90
+
91
+
92
+ class PredefinedNoiseSchedule(torch.nn.Module):
93
+ """
94
+ Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.
95
+ """
96
+
97
+ def __init__(self, noise_schedule, timesteps, precision):
98
+ super(PredefinedNoiseSchedule, self).__init__()
99
+ self.timesteps = timesteps
100
+
101
+ if noise_schedule == 'cosine':
102
+ alphas2 = cosine_beta_schedule(timesteps)
103
+ elif 'polynomial' in noise_schedule:
104
+ splits = noise_schedule.split('_')
105
+ assert len(splits) == 2
106
+ power = float(splits[1])
107
+ alphas2 = polynomial_schedule(timesteps, s=precision, power=power)
108
+ else:
109
+ raise ValueError(noise_schedule)
110
+
111
+ # print('alphas2', alphas2)
112
+
113
+ sigmas2 = 1 - alphas2
114
+
115
+ log_alphas2 = np.log(alphas2)
116
+ log_sigmas2 = np.log(sigmas2)
117
+
118
+ log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2
119
+
120
+ # print('gamma', -log_alphas2_to_sigmas2)
121
+
122
+ self.gamma = torch.nn.Parameter(
123
+ torch.from_numpy(-log_alphas2_to_sigmas2).float(),
124
+ requires_grad=False)
125
+
126
+ def forward(self, t):
127
+ t_int = torch.round(t * self.timesteps).long()
128
+ return self.gamma[t_int]
129
+
130
+
131
+ class GammaNetwork(torch.nn.Module):
132
+ """The gamma network models a monotonic increasing function. Construction as in the VDM paper."""
133
+
134
+ def __init__(self):
135
+ super().__init__()
136
+
137
+ self.l1 = PositiveLinear(1, 1)
138
+ self.l2 = PositiveLinear(1, 1024)
139
+ self.l3 = PositiveLinear(1024, 1)
140
+
141
+ self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
142
+ self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
143
+ self.show_schedule()
144
+
145
+ def show_schedule(self, num_steps=50):
146
+ t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
147
+ gamma = self.forward(t)
148
+ print('Gamma schedule:')
149
+ print(gamma.detach().cpu().numpy().reshape(num_steps))
150
+
151
+ def gamma_tilde(self, t):
152
+ l1_t = self.l1(t)
153
+ return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))
154
+
155
+ def forward(self, t):
156
+ zeros, ones = torch.zeros_like(t), torch.ones_like(t)
157
+ # Not super efficient.
158
+ gamma_tilde_0 = self.gamma_tilde(zeros)
159
+ gamma_tilde_1 = self.gamma_tilde(ones)
160
+ gamma_tilde_t = self.gamma_tilde(t)
161
+
162
+ # Normalize to [0, 1]
163
+ normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
164
+ gamma_tilde_1 - gamma_tilde_0)
165
+
166
+ # Rescale to [gamma_0, gamma_1]
167
+ gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma
168
+
169
+ return gamma
src/utils.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from datetime import datetime
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ class Logger(object):
8
+ def __init__(self, logpath, syspart=sys.stdout):
9
+ self.terminal = syspart
10
+ self.log = open(logpath, "a")
11
+
12
+ def write(self, message):
13
+
14
+ self.terminal.write(message)
15
+ self.log.write(message)
16
+ self.log.flush()
17
+
18
+ def flush(self):
19
+ # this flush method is needed for python 3 compatibility.
20
+ # this handles the flush command by doing nothing.
21
+ # you might want to specify some extra behavior here.
22
+ pass
23
+
24
+ def log(*args):
25
+ print(f'[{datetime.now()}]', *args)
26
+
27
+ class EMA:
28
+ def __init__(self, beta):
29
+ super().__init__()
30
+ self.beta = beta
31
+
32
+ def update_model_average(self, ma_model, current_model):
33
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
34
+ old_weight, up_weight = ma_params.data, current_params.data
35
+ ma_params.data = self.update_average(old_weight, up_weight)
36
+
37
+ def update_average(self, old, new):
38
+ if old is None:
39
+ return new
40
+ return old * self.beta + (1 - self.beta) * new
41
+
42
+
43
+ def sum_except_batch(x):
44
+ return x.reshape(x.size(0), -1).sum(dim=-1)
45
+
46
+
47
+ def remove_mean(x):
48
+ mean = torch.mean(x, dim=1, keepdim=True)
49
+ x = x - mean
50
+ return x
51
+
52
+
53
+ def remove_mean_with_mask(x, node_mask):
54
+ masked_max_abs_value = (x * (1 - node_mask)).abs().sum().item()
55
+ assert masked_max_abs_value < 1e-5, f'Error {masked_max_abs_value} too high'
56
+ N = node_mask.sum(1, keepdims=True)
57
+
58
+ mean = torch.sum(x, dim=1, keepdim=True) / N
59
+ x = x - mean * node_mask
60
+ return x
61
+
62
+
63
+ def remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask):
64
+ """
65
+ Subtract center of mass of fragments from coordinates of all atoms
66
+ """
67
+ x_masked = x * center_of_mass_mask
68
+ N = center_of_mass_mask.sum(1, keepdims=True)
69
+ mean = torch.sum(x_masked, dim=1, keepdim=True) / N
70
+ x = x - mean * node_mask
71
+ return x
72
+
73
+
74
+ def assert_mean_zero(x):
75
+ mean = torch.mean(x, dim=1, keepdim=True)
76
+ assert mean.abs().max().item() < 1e-4
77
+
78
+
79
+ def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
80
+ assert_correctly_masked(x, node_mask)
81
+ largest_value = x.abs().max().item()
82
+ error = torch.sum(x, dim=1, keepdim=True).abs().max().item()
83
+ rel_error = error / (largest_value + eps)
84
+ assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}'
85
+
86
+
87
+ def assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask, eps=1e-10):
88
+ assert_correctly_masked(x, node_mask)
89
+ x_masked = x * center_of_mass_mask
90
+ largest_value = x_masked.abs().max().item()
91
+ error = torch.sum(x_masked, dim=1, keepdim=True).abs().max().item()
92
+ rel_error = error / (largest_value + eps)
93
+ assert rel_error < 1e-2, f'Partial mean is not zero, relative_error {rel_error}'
94
+
95
+
96
+ def assert_correctly_masked(variable, node_mask):
97
+ assert (variable * (1 - node_mask)).abs().max().item() < 1e-4, \
98
+ 'Variables not masked properly.'
99
+
100
+
101
+ def check_mask_correct(variables, node_mask):
102
+ for i, variable in enumerate(variables):
103
+ if len(variable) > 0:
104
+ assert_correctly_masked(variable, node_mask)
105
+
106
+
107
+ def center_gravity_zero_gaussian_log_likelihood(x):
108
+ assert len(x.size()) == 3
109
+ B, N, D = x.size()
110
+ assert_mean_zero(x)
111
+
112
+ # r is invariant to a basis change in the relevant hyperplane.
113
+ r2 = sum_except_batch(x.pow(2))
114
+
115
+ # The relevant hyperplane is (N-1) * D dimensional.
116
+ degrees_of_freedom = (N-1) * D
117
+
118
+ # Normalizing constant and logpx are computed:
119
+ log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2*np.pi)
120
+ log_px = -0.5 * r2 + log_normalizing_constant
121
+
122
+ return log_px
123
+
124
+
125
+ def sample_center_gravity_zero_gaussian(size, device):
126
+ assert len(size) == 3
127
+ x = torch.randn(size, device=device)
128
+
129
+ # This projection only works because Gaussian is rotation invariant around
130
+ # zero and samples are independent!
131
+ x_projected = remove_mean(x)
132
+ return x_projected
133
+
134
+
135
+ def center_gravity_zero_gaussian_log_likelihood_with_mask(x, node_mask):
136
+ assert len(x.size()) == 3
137
+ B, N_embedded, D = x.size()
138
+ assert_mean_zero_with_mask(x, node_mask)
139
+
140
+ # r is invariant to a basis change in the relevant hyperplane, the masked
141
+ # out values will have zero contribution.
142
+ r2 = sum_except_batch(x.pow(2))
143
+
144
+ # The relevant hyperplane is (N-1) * D dimensional.
145
+ N = node_mask.squeeze(2).sum(1) # N has shape [B]
146
+ degrees_of_freedom = (N-1) * D
147
+
148
+ # Normalizing constant and logpx are computed:
149
+ log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2*np.pi)
150
+ log_px = -0.5 * r2 + log_normalizing_constant
151
+
152
+ return log_px
153
+
154
+
155
+ def sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask):
156
+ assert len(size) == 3
157
+ x = torch.randn(size, device=device)
158
+
159
+ x_masked = x * node_mask
160
+
161
+ # This projection only works because Gaussian is rotation invariant around
162
+ # zero and samples are independent!
163
+ # TODO: check it
164
+ x_projected = remove_mean_with_mask(x_masked, node_mask)
165
+ return x_projected
166
+
167
+
168
+ def standard_gaussian_log_likelihood(x):
169
+ # Normalizing constant and logpx are computed:
170
+ log_px = sum_except_batch(-0.5 * x * x - 0.5 * np.log(2*np.pi))
171
+ return log_px
172
+
173
+
174
+ def sample_gaussian(size, device):
175
+ x = torch.randn(size, device=device)
176
+ return x
177
+
178
+
179
+ def standard_gaussian_log_likelihood_with_mask(x, node_mask):
180
+ # Normalizing constant and logpx are computed:
181
+ log_px_elementwise = -0.5 * x * x - 0.5 * np.log(2*np.pi)
182
+ log_px = sum_except_batch(log_px_elementwise * node_mask)
183
+ return log_px
184
+
185
+
186
+ def sample_gaussian_with_mask(size, device, node_mask):
187
+ x = torch.randn(size, device=device)
188
+ x_masked = x * node_mask
189
+ return x_masked
190
+
191
+
192
+ def concatenate_features(x, h):
193
+ xh = torch.cat([x, h['categorical']], dim=2)
194
+ if 'integer' in h:
195
+ xh = torch.cat([xh, h['integer']], dim=2)
196
+ return xh
197
+
198
+
199
+ def split_features(z, n_dims, num_classes, include_charges):
200
+ assert z.size(2) == n_dims + num_classes + include_charges
201
+ x = z[:, :, 0:n_dims]
202
+ h = {'categorical': z[:, :, n_dims:n_dims+num_classes]}
203
+ if include_charges:
204
+ h['integer'] = z[:, :, n_dims+num_classes:n_dims+num_classes+1]
205
+
206
+ return x, h
207
+
208
+
209
+ # For gradient clipping
210
+
211
+ class Queue:
212
+ def __init__(self, max_len=50):
213
+ self.items = []
214
+ self.max_len = max_len
215
+
216
+ def __len__(self):
217
+ return len(self.items)
218
+
219
+ def add(self, item):
220
+ self.items.insert(0, item)
221
+ if len(self) > self.max_len:
222
+ self.items.pop()
223
+
224
+ def mean(self):
225
+ return np.mean(self.items)
226
+
227
+ def std(self):
228
+ return np.std(self.items)
229
+
230
+
231
+ def gradient_clipping(flow, gradnorm_queue):
232
+ # Allow gradient norm to be 150% + 2 * stdev of the recent history.
233
+ max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()
234
+
235
+ # Clips gradient and returns the norm
236
+ grad_norm = torch.nn.utils.clip_grad_norm_(
237
+ flow.parameters(), max_norm=max_grad_norm, norm_type=2.0)
238
+
239
+ if float(grad_norm) > max_grad_norm:
240
+ gradnorm_queue.add(float(max_grad_norm))
241
+ else:
242
+ gradnorm_queue.add(float(grad_norm))
243
+
244
+ if float(grad_norm) > max_grad_norm:
245
+ print(f'Clipped gradient with value {grad_norm:.1f} while allowed {max_grad_norm:.1f}')
246
+ return grad_norm
247
+
248
+
249
+ def disable_rdkit_logging():
250
+ """
251
+ Disables RDKit whiny logging.
252
+ """
253
+ import rdkit.rdBase as rkrb
254
+ import rdkit.RDLogger as rkl
255
+ logger = rkl.logger()
256
+ logger.setLevel(rkl.ERROR)
257
+ rkrb.DisableLog('rdApp.error')
258
+
259
+
260
+ class FoundNaNException(Exception):
261
+ def __init__(self, x, h):
262
+ x_nan_idx = self.find_nan_idx(x)
263
+ h_nan_idx = self.find_nan_idx(h)
264
+
265
+ self.x_h_nan_idx = x_nan_idx & h_nan_idx
266
+ self.only_x_nan_idx = x_nan_idx.difference(h_nan_idx)
267
+ self.only_h_nan_idx = h_nan_idx.difference(x_nan_idx)
268
+
269
+ @staticmethod
270
+ def find_nan_idx(z):
271
+ idx = set()
272
+ for i in range(z.shape[0]):
273
+ if torch.any(torch.isnan(z[i])):
274
+ idx.add(i)
275
+ return idx
276
+
277
+
278
+ def get_batch_idx_for_animation(batch_size, batch_idx):
279
+ batch_indices = []
280
+ mol_indices = []
281
+ for idx in [0, 110, 360]:
282
+ if idx // batch_size == batch_idx:
283
+ batch_indices.append(idx % batch_size)
284
+ mol_indices.append(idx)
285
+ return batch_indices, mol_indices
286
+
287
+
288
+ # Rotation data augmntation
289
+ def random_rotation(x):
290
+ bs, n_nodes, n_dims = x.size()
291
+ device = x.device
292
+ angle_range = np.pi * 2
293
+ if n_dims == 2:
294
+ theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
295
+ cos_theta = torch.cos(theta)
296
+ sin_theta = torch.sin(theta)
297
+ R_row0 = torch.cat([cos_theta, -sin_theta], dim=2)
298
+ R_row1 = torch.cat([sin_theta, cos_theta], dim=2)
299
+ R = torch.cat([R_row0, R_row1], dim=1)
300
+
301
+ x = x.transpose(1, 2)
302
+ x = torch.matmul(R, x)
303
+ x = x.transpose(1, 2)
304
+
305
+ elif n_dims == 3:
306
+
307
+ # Build Rx
308
+ Rx = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
309
+ theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
310
+ cos = torch.cos(theta)
311
+ sin = torch.sin(theta)
312
+ Rx[:, 1:2, 1:2] = cos
313
+ Rx[:, 1:2, 2:3] = sin
314
+ Rx[:, 2:3, 1:2] = - sin
315
+ Rx[:, 2:3, 2:3] = cos
316
+
317
+ # Build Ry
318
+ Ry = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
319
+ theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
320
+ cos = torch.cos(theta)
321
+ sin = torch.sin(theta)
322
+ Ry[:, 0:1, 0:1] = cos
323
+ Ry[:, 0:1, 2:3] = -sin
324
+ Ry[:, 2:3, 0:1] = sin
325
+ Ry[:, 2:3, 2:3] = cos
326
+
327
+ # Build Rz
328
+ Rz = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
329
+ theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
330
+ cos = torch.cos(theta)
331
+ sin = torch.sin(theta)
332
+ Rz[:, 0:1, 0:1] = cos
333
+ Rz[:, 0:1, 1:2] = sin
334
+ Rz[:, 1:2, 0:1] = -sin
335
+ Rz[:, 1:2, 1:2] = cos
336
+
337
+ x = x.transpose(1, 2)
338
+ x = torch.matmul(Rx, x)
339
+ #x = torch.matmul(Rx.transpose(1, 2), x)
340
+ x = torch.matmul(Ry, x)
341
+ #x = torch.matmul(Ry.transpose(1, 2), x)
342
+ x = torch.matmul(Rz, x)
343
+ #x = torch.matmul(Rz.transpose(1, 2), x)
344
+ x = x.transpose(1, 2)
345
+ else:
346
+ raise Exception("Not implemented Error")
347
+
348
+ return x.contiguous()
src/visualizer.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import imageio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import glob
7
+ import random
8
+
9
+ from sklearn.decomposition import PCA
10
+ from src import const
11
+ from src.molecule_builder import get_bond_order
12
+
13
+
14
+ def save_xyz_file(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
15
+ idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
16
+
17
+ for batch_i in range(one_hot.size(0)):
18
+ mask = node_mask[batch_i].squeeze()
19
+ n_atoms = mask.sum()
20
+ atom_idx = torch.where(mask)[0]
21
+
22
+ f = open(os.path.join(path, f'{names[batch_i]}_{suffix}.xyz'), "w")
23
+ f.write("%d\n\n" % n_atoms)
24
+ atoms = torch.argmax(one_hot[batch_i], dim=1)
25
+ for atom_i in atom_idx:
26
+ atom = atoms[atom_i].item()
27
+ atom = idx2atom[atom]
28
+ f.write("%s %.9f %.9f %.9f\n" % (
29
+ atom, positions[batch_i, atom_i, 0], positions[batch_i, atom_i, 1], positions[batch_i, atom_i, 2]
30
+ ))
31
+ f.close()
32
+
33
+
34
+ def load_xyz_files(path, suffix=''):
35
+ files = []
36
+ for fname in os.listdir(path):
37
+ if fname.endswith(f'_{suffix}.xyz'):
38
+ files.append(fname)
39
+ files = sorted(files, key=lambda f: -int(f.replace(f'_{suffix}.xyz', '').split('_')[-1]))
40
+ return [os.path.join(path, fname) for fname in files]
41
+
42
+
43
+ def load_molecule_xyz(file, is_geom):
44
+ atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
45
+ idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
46
+ with open(file, encoding='utf8') as f:
47
+ n_atoms = int(f.readline())
48
+ one_hot = torch.zeros(n_atoms, len(idx2atom))
49
+ charges = torch.zeros(n_atoms, 1)
50
+ positions = torch.zeros(n_atoms, 3)
51
+ f.readline()
52
+ atoms = f.readlines()
53
+ for i in range(n_atoms):
54
+ atom = atoms[i].split(' ')
55
+ atom_type = atom[0]
56
+ one_hot[i, atom2idx[atom_type]] = 1
57
+ position = torch.Tensor([float(e) for e in atom[1:]])
58
+ positions[i, :] = position
59
+ return positions, one_hot, charges
60
+
61
+
62
+ def draw_sphere(ax, x, y, z, size, color, alpha):
63
+ u = np.linspace(0, 2 * np.pi, 100)
64
+ v = np.linspace(0, np.pi, 100)
65
+
66
+ xs = size * np.outer(np.cos(u), np.sin(v))
67
+ ys = size * np.outer(np.sin(u), np.sin(v)) #* 0.8
68
+ zs = size * np.outer(np.ones(np.size(u)), np.cos(v))
69
+ ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, alpha=alpha)
70
+
71
+
72
+ def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
73
+ x = positions[:, 0]
74
+ y = positions[:, 1]
75
+ z = positions[:, 2]
76
+ # Hydrogen, Carbon, Nitrogen, Oxygen, Flourine
77
+
78
+ idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
79
+
80
+ colors_dic = np.array(const.COLORS)
81
+ radius_dic = np.array(const.RADII)
82
+ area_dic = 1500 * radius_dic ** 2
83
+
84
+ areas = area_dic[atom_type]
85
+ radii = radius_dic[atom_type]
86
+ colors = colors_dic[atom_type]
87
+
88
+ if fragment_mask is None:
89
+ fragment_mask = torch.ones(len(x))
90
+
91
+ for i in range(len(x)):
92
+ for j in range(i + 1, len(x)):
93
+ p1 = np.array([x[i], y[i], z[i]])
94
+ p2 = np.array([x[j], y[j], z[j]])
95
+ dist = np.sqrt(np.sum((p1 - p2) ** 2))
96
+ atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
97
+ draw_edge_int = get_bond_order(atom1, atom2, dist)
98
+ line_width = (3 - 2) * 2 * 2
99
+ draw_edge = draw_edge_int > 0
100
+ if draw_edge:
101
+ if draw_edge_int == 4:
102
+ linewidth_factor = 1.5
103
+ else:
104
+ linewidth_factor = 1
105
+ linewidth_factor *= 0.5
106
+ ax.plot(
107
+ [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
108
+ linewidth=line_width * linewidth_factor * 2,
109
+ c=hex_bg_color,
110
+ alpha=alpha
111
+ )
112
+
113
+ # from pdb import set_trace
114
+ # set_trace()
115
+
116
+ if spheres_3d:
117
+ # idx = torch.where(fragment_mask[:len(x)] == 0)[0]
118
+ # ax.scatter(
119
+ # x[idx],
120
+ # y[idx],
121
+ # z[idx],
122
+ # alpha=0.9 * alpha,
123
+ # edgecolors='#FCBA03',
124
+ # facecolors='none',
125
+ # linewidths=2,
126
+ # s=900
127
+ # )
128
+ for i, j, k, s, c, f in zip(x, y, z, radii, colors, fragment_mask):
129
+ if f == 1:
130
+ alpha = 1.0
131
+
132
+ draw_sphere(ax, i.item(), j.item(), k.item(), 0.5 * s, c, alpha)
133
+
134
+ else:
135
+ ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
136
+
137
+
138
+ def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
139
+ bg='black', alpha=1., fragment_mask=None):
140
+ black = (0, 0, 0)
141
+ white = (1, 1, 1)
142
+ hex_bg_color = '#FFFFFF' if bg == 'black' else '#000000' #'#666666'
143
+
144
+ fig = plt.figure(figsize=(10, 10))
145
+ ax = fig.add_subplot(projection='3d')
146
+ ax.set_aspect('auto')
147
+ ax.view_init(elev=camera_elev, azim=camera_azim)
148
+ if bg == 'black':
149
+ ax.set_facecolor(black)
150
+ else:
151
+ ax.set_facecolor(white)
152
+ ax.xaxis.pane.set_alpha(0)
153
+ ax.yaxis.pane.set_alpha(0)
154
+ ax.zaxis.pane.set_alpha(0)
155
+ ax._axis3don = False
156
+
157
+ if bg == 'black':
158
+ ax.w_xaxis.line.set_color("black")
159
+ else:
160
+ ax.w_xaxis.line.set_color("white")
161
+
162
+ plot_molecule(
163
+ ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask
164
+ )
165
+
166
+ max_value = positions.abs().max().item()
167
+ axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
168
+ ax.set_xlim(-axis_lim, axis_lim)
169
+ ax.set_ylim(-axis_lim, axis_lim)
170
+ ax.set_zlim(-axis_lim, axis_lim)
171
+ dpi = 120 if spheres_3d else 50
172
+
173
+ if save_path is not None:
174
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
175
+ # plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi, transparent=True)
176
+
177
+ if spheres_3d:
178
+ img = imageio.imread(save_path)
179
+ img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
180
+ imageio.imsave(save_path, img_brighter)
181
+ else:
182
+ plt.show()
183
+ plt.close()
184
+
185
+
186
+ def visualize_chain(
187
+ path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None
188
+ ):
189
+ files = load_xyz_files(path)
190
+ save_paths = []
191
+
192
+ # Fit PCA to the final molecule – to obtain the best orientation for visualization
193
+ positions, one_hot, charges = load_molecule_xyz(files[-1], is_geom=is_geom)
194
+ pca = PCA(n_components=3)
195
+ pca.fit(positions)
196
+
197
+ for i in range(len(files)):
198
+ file = files[i]
199
+
200
+ positions, one_hot, charges = load_molecule_xyz(file, is_geom=is_geom)
201
+ atom_type = torch.argmax(one_hot, dim=1).numpy()
202
+
203
+ # Transform positions of each frame according to the best orientation of the last frame
204
+ positions = pca.transform(positions)
205
+ positions = torch.tensor(positions)
206
+
207
+ fn = file[:-4] + '.png'
208
+ plot_data3d(
209
+ positions, atom_type,
210
+ save_path=fn,
211
+ spheres_3d=spheres_3d,
212
+ alpha=alpha,
213
+ bg=bg,
214
+ camera_elev=90,
215
+ camera_azim=90,
216
+ is_geom=is_geom,
217
+ fragment_mask=fragment_mask,
218
+ )
219
+ save_paths.append(fn)
220
+
221
+ imgs = [imageio.imread(fn) for fn in save_paths]
222
+ dirname = os.path.dirname(save_paths[0])
223
+ gif_path = dirname + '/output.gif'
224
+ imageio.mimsave(gif_path, imgs, subrectangles=True)
225
+
226
+ if wandb is not None:
227
+ wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})