Simon Duerr commited on
Commit
486fd8a
1 Parent(s): 3e6dce4

gradio update

Browse files
app.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ import copy
5
+ import os
6
+ import torch
7
+
8
+ import time
9
+ from argparse import ArgumentParser, Namespace, FileType
10
+ from rdkit.Chem import RemoveHs
11
+ from functools import partial
12
+ import numpy as np
13
+ import pandas as pd
14
+ from rdkit import RDLogger
15
+ from rdkit.Chem import MolFromSmiles, AddHs
16
+ from torch_geometric.loader import DataLoader
17
+ import yaml
18
+
19
+ from datasets.process_mols import (
20
+ read_molecule,
21
+ generate_conformer,
22
+ write_mol_with_coords,
23
+ )
24
+ from datasets.pdbbind import PDBBind
25
+ from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
26
+ from utils.sampling import randomize_position, sampling
27
+ from utils.utils import get_model
28
+ from utils.visualise import PDBFile
29
+ from tqdm import tqdm
30
+ from datasets.esm_embedding_preparation import esm_embedding_prep
31
+ import subprocess
32
+
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+
35
+ with open(f"workdir/paper_score_model/model_parameters.yml") as f:
36
+ score_model_args = Namespace(**yaml.full_load(f))
37
+
38
+ with open(f"workdir/paper_confidence_model/model_parameters.yml") as f:
39
+ confidence_args = Namespace(**yaml.full_load(f))
40
+
41
+ t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)
42
+
43
+ model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True)
44
+ state_dict = torch.load(
45
+ f"workdir/paper_score_model/best_ema_inference_epoch_model.pt",
46
+ map_location=torch.device("cpu"),
47
+ )
48
+ model.load_state_dict(state_dict, strict=True)
49
+ model = model.to(device)
50
+ model.eval()
51
+
52
+ confidence_model = get_model(
53
+ confidence_args,
54
+ device,
55
+ t_to_sigma=t_to_sigma,
56
+ no_parallel=True,
57
+ confidence_mode=True,
58
+ )
59
+ state_dict = torch.load(
60
+ f"workdir/paper_confidence_model/best_model_epoch75.pt",
61
+ map_location=torch.device("cpu"),
62
+ )
63
+ confidence_model.load_state_dict(state_dict, strict=True)
64
+ confidence_model = confidence_model.to(device)
65
+ confidence_model.eval()
66
+ tr_schedule = get_t_schedule(inference_steps=10)
67
+ rot_schedule = tr_schedule
68
+ tor_schedule = tr_schedule
69
+ print("common t schedule", tr_schedule)
70
+ failures, skipped, confidences_list, names_list, run_times, min_self_distances_list = (
71
+ 0,
72
+ 0,
73
+ [],
74
+ [],
75
+ [],
76
+ [],
77
+ )
78
+ N = 10
79
+
80
+
81
+ def get_pdb(pdb_code="", filepath=""):
82
+ if pdb_code is None or pdb_code == "":
83
+ try:
84
+ return filepath.name
85
+ except AttributeError as e:
86
+ return None
87
+ else:
88
+ os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
89
+ return f"{pdb_code}.pdb"
90
+
91
+
92
+ def get_ligand(smiles="", filepath=""):
93
+ if smiles is None or smiles == "":
94
+ try:
95
+ return filepath.name
96
+ except AttributeError as e:
97
+ return None
98
+ else:
99
+ return smiles
100
+
101
+
102
+ def read_mol(molpath):
103
+ with open(molpath, "r") as fp:
104
+ lines = fp.readlines()
105
+ mol = ""
106
+ for l in lines:
107
+ mol += l
108
+ return mol
109
+
110
+
111
+ def molecule(input_pdb, ligand_pdb):
112
+
113
+ structure = read_mol(input_pdb)
114
+ mol = read_mol(ligand_pdb)
115
+
116
+ x = (
117
+ """<!DOCTYPE html>
118
+ <html>
119
+ <head>
120
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
121
+ <style>
122
+ body{
123
+ font-family:sans-serif
124
+ }
125
+ .mol-container {
126
+ width: 600px;
127
+ height: 600px;
128
+ position: relative;
129
+ mx-auto:0
130
+ }
131
+ .mol-container select{
132
+ background-image:None;
133
+ }
134
+ </style>
135
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
136
+ </head>
137
+ <body>
138
+ <button id="startanimation">Replay diffusion process</button>
139
+ <div id="container" class="mol-container"></div>
140
+
141
+ <script>
142
+ let ligand = `"""
143
+ + mol
144
+ + """`
145
+ let structure = `"""
146
+ + structure
147
+ + """`
148
+
149
+ let viewer = null;
150
+
151
+ $(document).ready(function () {
152
+ let element = $("#container");
153
+ let config = { backgroundColor: "white" };
154
+ viewer = $3Dmol.createViewer(element, config);
155
+ viewer.addModel( structure, "pdb" );
156
+ viewer.setStyle({}, {cartoon: {color: "gray"}});
157
+ viewer.zoomTo();
158
+ viewer.zoom(0.7);
159
+ viewer.addModelsAsFrames(ligand, "pdb");
160
+ viewer.animate({loop: "forward",reps: 1});
161
+
162
+ viewer.getModel(1).setStyle({stick:{colorscheme:"magentaCarbon"}});
163
+ viewer.render();
164
+
165
+ })
166
+
167
+ $("#startanimation").click(function() {
168
+ viewer.animate({loop: "forward",reps: 1});
169
+ });
170
+ </script>
171
+ </body></html>"""
172
+ )
173
+
174
+ return f"""<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
175
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
176
+ allow-scripts allow-same-origin allow-popups
177
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
178
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
179
+
180
+
181
+ def esm(protein_path, out_file):
182
+ esm_embedding_prep(out_file, protein_path)
183
+ # create args object with defaults
184
+ os.environ["HOME"] = "esm/model_weights"
185
+
186
+ subprocess.call(
187
+ f"python esm/scripts/extract.py esm2_t33_650M_UR50D {out_file} data/esm2_output --repr_layers 33 --include per_tok",
188
+ shell=True,
189
+ )
190
+
191
+
192
+ def update(inp, file, ligand_inp, ligand_file):
193
+ pdb_path = get_pdb(inp, file)
194
+ ligand_path = get_ligand(ligand_inp, ligand_file)
195
+
196
+ esm(
197
+ pdb_path,
198
+ f"data/{os.path.basename(pdb_path)}_prepared_for_esm.fasta",
199
+ )
200
+
201
+ protein_path_list = [pdb_path]
202
+ ligand_descriptions = [ligand_path]
203
+ no_random = False
204
+ ode = False
205
+ no_final_step_noise = False
206
+ out_dir = "results/test"
207
+ test_dataset = PDBBind(
208
+ transform=None,
209
+ root="",
210
+ protein_path_list=protein_path_list,
211
+ ligand_descriptions=ligand_descriptions,
212
+ receptor_radius=score_model_args.receptor_radius,
213
+ cache_path="data/cache",
214
+ remove_hs=score_model_args.remove_hs,
215
+ max_lig_size=None,
216
+ c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
217
+ matching=False,
218
+ keep_original=False,
219
+ popsize=score_model_args.matching_popsize,
220
+ maxiter=score_model_args.matching_maxiter,
221
+ all_atoms=score_model_args.all_atoms,
222
+ atom_radius=score_model_args.atom_radius,
223
+ atom_max_neighbors=score_model_args.atom_max_neighbors,
224
+ esm_embeddings_path="data/esm2_output",
225
+ require_ligand=True,
226
+ num_workers=1,
227
+ keep_local_structures=False,
228
+ )
229
+ test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
230
+ confidence_test_dataset = PDBBind(
231
+ transform=None,
232
+ root="",
233
+ protein_path_list=protein_path_list,
234
+ ligand_descriptions=ligand_descriptions,
235
+ receptor_radius=confidence_args.receptor_radius,
236
+ cache_path="data/cache",
237
+ remove_hs=confidence_args.remove_hs,
238
+ max_lig_size=None,
239
+ c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
240
+ matching=False,
241
+ keep_original=False,
242
+ popsize=confidence_args.matching_popsize,
243
+ maxiter=confidence_args.matching_maxiter,
244
+ all_atoms=confidence_args.all_atoms,
245
+ atom_radius=confidence_args.atom_radius,
246
+ atom_max_neighbors=confidence_args.atom_max_neighbors,
247
+ esm_embeddings_path="data/esm2_output",
248
+ require_ligand=True,
249
+ num_workers=1,
250
+ )
251
+ confidence_complex_dict = {d.name: d for d in confidence_test_dataset}
252
+ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
253
+ if (
254
+ confidence_model is not None
255
+ and not (
256
+ confidence_args.use_original_model_cache
257
+ or confidence_args.transfer_weights
258
+ )
259
+ and orig_complex_graph.name[0] not in confidence_complex_dict.keys()
260
+ ):
261
+ skipped += 1
262
+ print(
263
+ f"HAPPENING | The confidence dataset did not contain {orig_complex_graph.name[0]}. We are skipping this complex."
264
+ )
265
+ continue
266
+ try:
267
+ data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)]
268
+ randomize_position(
269
+ data_list,
270
+ score_model_args.no_torsion,
271
+ no_random,
272
+ score_model_args.tr_sigma_max,
273
+ )
274
+ pdb = None
275
+ lig = orig_complex_graph.mol[0]
276
+ visualization_list = []
277
+ for graph in data_list:
278
+ pdb = PDBFile(lig)
279
+ pdb.add(lig, 0, 0)
280
+ pdb.add(
281
+ (
282
+ orig_complex_graph["ligand"].pos
283
+ + orig_complex_graph.original_center
284
+ )
285
+ .detach()
286
+ .cpu(),
287
+ 1,
288
+ 0,
289
+ )
290
+ pdb.add(
291
+ (graph["ligand"].pos + graph.original_center).detach().cpu(),
292
+ part=1,
293
+ order=1,
294
+ )
295
+ visualization_list.append(pdb)
296
+
297
+ start_time = time.time()
298
+ if confidence_model is not None and not (
299
+ confidence_args.use_original_model_cache
300
+ or confidence_args.transfer_weights
301
+ ):
302
+ confidence_data_list = [
303
+ copy.deepcopy(confidence_complex_dict[orig_complex_graph.name[0]])
304
+ for _ in range(N)
305
+ ]
306
+ else:
307
+ confidence_data_list = None
308
+
309
+ data_list, confidence = sampling(
310
+ data_list=data_list,
311
+ model=model,
312
+ inference_steps=10,
313
+ tr_schedule=tr_schedule,
314
+ rot_schedule=rot_schedule,
315
+ tor_schedule=tor_schedule,
316
+ device=device,
317
+ t_to_sigma=t_to_sigma,
318
+ model_args=score_model_args,
319
+ no_random=no_random,
320
+ ode=ode,
321
+ visualization_list=visualization_list,
322
+ confidence_model=confidence_model,
323
+ confidence_data_list=confidence_data_list,
324
+ confidence_model_args=confidence_args,
325
+ batch_size=1,
326
+ no_final_step_noise=no_final_step_noise,
327
+ )
328
+ ligand_pos = np.asarray(
329
+ [
330
+ complex_graph["ligand"].pos.cpu().numpy()
331
+ + orig_complex_graph.original_center.cpu().numpy()
332
+ for complex_graph in data_list
333
+ ]
334
+ )
335
+ run_times.append(time.time() - start_time)
336
+
337
+ if confidence is not None and isinstance(
338
+ confidence_args.rmsd_classification_cutoff, list
339
+ ):
340
+ confidence = confidence[:, 0]
341
+ if confidence is not None:
342
+ confidence = confidence.cpu().numpy()
343
+ re_order = np.argsort(confidence)[::-1]
344
+ confidence = confidence[re_order]
345
+ confidences_list.append(confidence)
346
+ ligand_pos = ligand_pos[re_order]
347
+ write_dir = (
348
+ f'{out_dir}/index{idx}_{data_list[0]["name"][0].replace("/","-")}'
349
+ )
350
+ os.makedirs(write_dir, exist_ok=True)
351
+ for rank, pos in enumerate(ligand_pos):
352
+ mol_pred = copy.deepcopy(lig)
353
+ if score_model_args.remove_hs:
354
+ mol_pred = RemoveHs(mol_pred)
355
+ if rank == 0:
356
+ write_mol_with_coords(
357
+ mol_pred, pos, os.path.join(write_dir, f"rank{rank+1}.sdf")
358
+ )
359
+ write_mol_with_coords(
360
+ mol_pred,
361
+ pos,
362
+ os.path.join(
363
+ write_dir, f"rank{rank+1}_confidence{confidence[rank]:.2f}.sdf"
364
+ ),
365
+ )
366
+ self_distances = np.linalg.norm(
367
+ ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1
368
+ )
369
+ self_distances = np.where(
370
+ np.eye(self_distances.shape[2]), np.inf, self_distances
371
+ )
372
+ min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
373
+
374
+ filenames = []
375
+ if confidence is not None:
376
+ for rank, batch_idx in enumerate(re_order):
377
+ visualization_list[batch_idx].write(
378
+ os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
379
+ )
380
+ filenames.append(
381
+ os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
382
+ )
383
+ else:
384
+ for rank, batch_idx in enumerate(ligand_pos):
385
+ visualization_list[batch_idx].write(
386
+ os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
387
+ )
388
+ filenames.append(
389
+ os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb")
390
+ )
391
+ names_list.append(orig_complex_graph.name[0])
392
+ except Exception as e:
393
+ print("Failed on", orig_complex_graph["name"], e)
394
+ failures += 1
395
+ return None
396
+
397
+ labels = [f"rank {i+1}" for i in range(len(filenames))]
398
+ return (
399
+ molecule(pdb_path, filenames[0]),
400
+ gr.Dropdown.update(choices=labels, value="rank 1"),
401
+ filenames,
402
+ pdb_path,
403
+ )
404
+
405
+
406
+ def updateView(out, filenames, pdb):
407
+ i = int(out.replace("rank", ""))
408
+ return molecule(pdb, filenames[i])
409
+
410
+
411
+ demo = gr.Blocks()
412
+
413
+ with demo:
414
+ gr.Markdown("# DiffDock")
415
+ gr.Markdown(
416
+ ">**DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking**, Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi, arXiv:2210.01776 [GitHub](https://github.com/gcorso/diffdock)"
417
+ )
418
+ gr.Markdown("Runs the diffusion model `10` times with `10` inference steps")
419
+ with gr.Box():
420
+ with gr.Row():
421
+ with gr.Column():
422
+ gr.Markdown("## Protein")
423
+ inp = gr.Textbox(
424
+ placeholder="PDB Code or upload file below", label="Input structure"
425
+ )
426
+ file = gr.File(file_count="single", label="Input PDB")
427
+ with gr.Column():
428
+ gr.Markdown("## Ligand")
429
+ ligand_inp = gr.Textbox(
430
+ placeholder="Provide SMILES input or upload mol2/sdf file below",
431
+ label="SMILES string",
432
+ )
433
+ ligand_file = gr.File(file_count="single", label="Input Ligand")
434
+
435
+ btn = gr.Button("Run predictions")
436
+
437
+ gr.Markdown("## Output")
438
+ pdb = gr.Variable()
439
+ filenames = gr.Variable()
440
+ out = gr.Dropdown(interactive=True, label="Ranked samples")
441
+ mol = gr.HTML()
442
+ gr.Examples(
443
+ [
444
+ [
445
+ None,
446
+ "examples/1a46_protein_processed.pdb",
447
+ None,
448
+ "examples/1a46_ligand.sdf",
449
+ ]
450
+ ],
451
+ [inp, file, ligand_inp, ligand_file],
452
+ [mol, out],
453
+ # cache_examples=True,
454
+ )
455
+ btn.click(
456
+ fn=update,
457
+ inputs=[inp, file, ligand_inp, ligand_file],
458
+ outputs=[mol, out, filenames, pdb],
459
+ )
460
+ out.change(fn=updateView, inputs=[out, filenames, pdb], outputs=mol)
461
+ demo.launch()
datasets/esm_embedding_preparation.py CHANGED
@@ -9,79 +9,80 @@ from Bio.SeqRecord import SeqRecord
9
  from tqdm import tqdm
10
  from Bio import SeqIO
11
 
12
- parser = ArgumentParser()
13
- parser.add_argument('--out_file', type=str, default="data/prepared_for_esm.fasta")
14
- parser.add_argument('--protein_ligand_csv', type=str, default='data/protein_ligand_example_csv.csv', help='Path to a .csv specifying the input as described in the main README')
15
- parser.add_argument('--protein_path', type=str, default=None, help='Path to a single PDB file. If this is not None then it will be used instead of the --protein_ligand_csv')
16
- args = parser.parse_args()
17
 
18
- biopython_parser = PDBParser()
19
 
20
- three_to_one = {'ALA': 'A',
21
- 'ARG': 'R',
22
- 'ASN': 'N',
23
- 'ASP': 'D',
24
- 'CYS': 'C',
25
- 'GLN': 'Q',
26
- 'GLU': 'E',
27
- 'GLY': 'G',
28
- 'HIS': 'H',
29
- 'ILE': 'I',
30
- 'LEU': 'L',
31
- 'LYS': 'K',
32
- 'MET': 'M',
33
- 'MSE': 'M', # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
34
- 'PHE': 'F',
35
- 'PRO': 'P',
36
- 'PYL': 'O',
37
- 'SER': 'S',
38
- 'SEC': 'U',
39
- 'THR': 'T',
40
- 'TRP': 'W',
41
- 'TYR': 'Y',
42
- 'VAL': 'V',
43
- 'ASX': 'B',
44
- 'GLX': 'Z',
45
- 'XAA': 'X',
46
- 'XLE': 'J'}
47
-
48
- if args.protein_path is not None:
49
- file_paths = [args.protein_path]
50
- else:
51
- df = pd.read_csv(args.protein_ligand_csv)
52
- file_paths = list(set(df['protein_path'].tolist()))
53
- sequences = []
54
- ids = []
55
- for file_path in tqdm(file_paths):
56
- structure = biopython_parser.get_structure('random_id', file_path)
57
- structure = structure[0]
58
- for i, chain in enumerate(structure):
59
- seq = ''
60
- for res_idx, residue in enumerate(chain):
61
- if residue.get_resname() == 'HOH':
62
- continue
63
- residue_coords = []
64
- c_alpha, n, c = None, None, None
65
- for atom in residue:
66
- if atom.name == 'CA':
67
- c_alpha = list(atom.get_vector())
68
- if atom.name == 'N':
69
- n = list(atom.get_vector())
70
- if atom.name == 'C':
71
- c = list(atom.get_vector())
72
- if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid
73
- try:
74
- seq += three_to_one[residue.get_resname()]
75
- except Exception as e:
76
- seq += '-'
77
- print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', file_path, '. Replacing it with a dash - .')
78
- sequences.append(seq)
79
- ids.append(f'{os.path.basename(file_path)}_chain_{i}')
80
- records = []
81
- for (index, seq) in zip(ids,sequences):
82
- record = SeqRecord(Seq(seq), str(index))
83
- record.description = ''
84
- records.append(record)
85
- SeqIO.write(records, args.out_file, "fasta")
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from tqdm import tqdm
10
  from Bio import SeqIO
11
 
 
 
 
 
 
12
 
 
13
 
14
+ def esm_embedding_prep(out_file, protein_path):
15
+ biopython_parser = PDBParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ three_to_one = {
18
+ "ALA": "A",
19
+ "ARG": "R",
20
+ "ASN": "N",
21
+ "ASP": "D",
22
+ "CYS": "C",
23
+ "GLN": "Q",
24
+ "GLU": "E",
25
+ "GLY": "G",
26
+ "HIS": "H",
27
+ "ILE": "I",
28
+ "LEU": "L",
29
+ "LYS": "K",
30
+ "MET": "M",
31
+ "MSE": "M", # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
32
+ "PHE": "F",
33
+ "PRO": "P",
34
+ "PYL": "O",
35
+ "SER": "S",
36
+ "SEC": "U",
37
+ "THR": "T",
38
+ "TRP": "W",
39
+ "TYR": "Y",
40
+ "VAL": "V",
41
+ "ASX": "B",
42
+ "GLX": "Z",
43
+ "XAA": "X",
44
+ "XLE": "J",
45
+ }
46
 
47
+ file_paths = [protein_path]
48
+ sequences = []
49
+ ids = []
50
+ for file_path in tqdm(file_paths):
51
+ structure = biopython_parser.get_structure("random_id", file_path)
52
+ structure = structure[0]
53
+ for i, chain in enumerate(structure):
54
+ seq = ""
55
+ for res_idx, residue in enumerate(chain):
56
+ if residue.get_resname() == "HOH":
57
+ continue
58
+ residue_coords = []
59
+ c_alpha, n, c = None, None, None
60
+ for atom in residue:
61
+ if atom.name == "CA":
62
+ c_alpha = list(atom.get_vector())
63
+ if atom.name == "N":
64
+ n = list(atom.get_vector())
65
+ if atom.name == "C":
66
+ c = list(atom.get_vector())
67
+ if (
68
+ c_alpha != None and n != None and c != None
69
+ ): # only append residue if it is an amino acid
70
+ try:
71
+ seq += three_to_one[residue.get_resname()]
72
+ except Exception as e:
73
+ seq += "-"
74
+ print(
75
+ "encountered unknown AA: ",
76
+ residue.get_resname(),
77
+ " in the complex ",
78
+ file_path,
79
+ ". Replacing it with a dash - .",
80
+ )
81
+ sequences.append(seq)
82
+ ids.append(f"{os.path.basename(file_path)}_chain_{i}")
83
+ records = []
84
+ for (index, seq) in zip(ids, sequences):
85
+ record = SeqRecord(Seq(seq), str(index))
86
+ record.description = ""
87
+ records.append(record)
88
+ SeqIO.write(records, out_file, "fasta")
datasets/pdbbind.py CHANGED
@@ -16,8 +16,15 @@ from torch_geometric.loader import DataLoader, DataListLoader
16
  from torch_geometric.transforms import BaseTransform
17
  from tqdm import tqdm
18
 
19
- from datasets.process_mols import read_molecule, get_rec_graph, generate_conformer, \
20
- get_lig_graph_with_matching, extract_receptor_structure, parse_receptor, parse_pdb_from_path
 
 
 
 
 
 
 
21
  from utils.diffusion_utils import modify_conformer, set_time
22
  from utils.utils import read_strings_from_txt
23
  from utils import so3, torus
@@ -34,32 +41,87 @@ class NoiseTransform(BaseTransform):
34
  t_tr, t_rot, t_tor = t, t, t
35
  return self.apply_noise(data, t_tr, t_rot, t_tor)
36
 
37
- def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update = None, rot_update=None, torsion_updates=None):
38
- if not torch.is_tensor(data['ligand'].pos):
39
- data['ligand'].pos = random.choice(data['ligand'].pos)
 
 
 
 
 
 
 
 
 
40
 
41
  tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
42
  set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)
43
 
44
- tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update
 
 
 
 
45
  rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
46
- torsion_updates = np.random.normal(loc=0.0, scale=tor_sigma, size=data['ligand'].edge_mask.sum()) if torsion_updates is None else torsion_updates
 
 
 
 
 
 
47
  torsion_updates = None if self.no_torsion else torsion_updates
48
- modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
49
-
50
- data.tr_score = -tr_update / tr_sigma ** 2
51
- data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0)
52
- data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
53
- data.tor_sigma_edge = None if self.no_torsion else np.ones(data['ligand'].edge_mask.sum()) * tor_sigma
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  return data
55
 
56
 
57
  class PDBBind(Dataset):
58
- def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0,
59
- receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
60
- matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
61
- atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, require_ligand=False,
62
- ligands_list=None, protein_path_list=None, ligand_descriptions=None, keep_local_structures=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  super(PDBBind, self).__init__(root, transform)
65
  self.pdbbind_dir = root
@@ -75,37 +137,67 @@ class PDBBind(Dataset):
75
  self.protein_path_list = protein_path_list
76
  self.ligand_descriptions = ligand_descriptions
77
  self.keep_local_structures = keep_local_structures
78
- if matching or protein_path_list is not None and ligand_descriptions is not None:
79
- cache_path += '_torsion'
 
 
 
 
80
  if all_atoms:
81
- cache_path += '_allatoms'
82
- self.full_cache_path = os.path.join(cache_path, f'limit{self.limit_complexes}'
83
- f'_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}'
84
- f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
85
- f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
86
- + ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
87
- + ('' if not matching or num_conformers == 1 else f'_confs{num_conformers}')
88
- + ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
89
- + ('' if not keep_local_structures else f'_keptLocalStruct')
90
- + ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode()))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  self.popsize, self.maxiter = popsize, maxiter
92
  self.matching, self.keep_original = matching, keep_original
93
  self.num_conformers = num_conformers
94
  self.all_atoms = all_atoms
95
  self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
96
- if not os.path.exists(os.path.join(self.full_cache_path, "heterographs.pkl"))\
97
- or (require_ligand and not os.path.exists(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"))):
 
 
 
 
 
 
98
  os.makedirs(self.full_cache_path, exist_ok=True)
99
  if protein_path_list is None or ligand_descriptions is None:
100
  self.preprocessing()
101
  else:
102
  self.inference_preprocessing()
103
 
104
- print('loading data from memory: ', os.path.join(self.full_cache_path, "heterographs.pkl"))
105
- with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'rb') as f:
 
 
 
106
  self.complex_graphs = pickle.load(f)
107
  if require_ligand:
108
- with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
 
 
109
  self.rdkit_ligands = pickle.load(f)
110
 
111
  print_statistics(self.complex_graphs)
@@ -122,18 +214,20 @@ class PDBBind(Dataset):
122
  return copy.deepcopy(self.complex_graphs[idx])
123
 
124
  def preprocessing(self):
125
- print(f'Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]')
 
 
126
 
127
  complex_names_all = read_strings_from_txt(self.split_path)
128
  if self.limit_complexes is not None and self.limit_complexes != 0:
129
- complex_names_all = complex_names_all[:self.limit_complexes]
130
- print(f'Loading {len(complex_names_all)} complexes.')
131
 
132
  if self.esm_embeddings_path is not None:
133
  id_to_embeddings = torch.load(self.esm_embeddings_path)
134
  chain_embeddings_dictlist = defaultdict(list)
135
  for key, embedding in id_to_embeddings.items():
136
- key_name = key.split('_')[0]
137
  if key_name in complex_names_all:
138
  chain_embeddings_dictlist[key_name].append(embedding)
139
  lm_embeddings_chains_all = []
@@ -144,58 +238,98 @@ class PDBBind(Dataset):
144
 
145
  if self.num_workers > 1:
146
  # running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
147
- for i in range(len(complex_names_all)//1000+1):
148
- if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
 
 
149
  continue
150
- complex_names = complex_names_all[1000*i:1000*(i+1)]
151
- lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
 
 
152
  complex_graphs, rdkit_ligands = [], []
153
  if self.num_workers > 1:
154
  p = Pool(self.num_workers, maxtasksperchild=1)
155
  p.__enter__()
156
- with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
 
 
 
157
  map_fn = p.imap_unordered if self.num_workers > 1 else map
158
- for t in map_fn(self.get_complex, zip(complex_names, lm_embeddings_chains, [None] * len(complex_names), [None] * len(complex_names))):
 
 
 
 
 
 
 
 
159
  complex_graphs.extend(t[0])
160
  rdkit_ligands.extend(t[1])
161
  pbar.update()
162
- if self.num_workers > 1: p.__exit__(None, None, None)
 
163
 
164
- with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
 
 
165
  pickle.dump((complex_graphs), f)
166
- with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
 
 
167
  pickle.dump((rdkit_ligands), f)
168
 
169
  complex_graphs_all = []
170
- for i in range(len(complex_names_all)//1000+1):
171
- with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
 
 
172
  l = pickle.load(f)
173
  complex_graphs_all.extend(l)
174
- with open(os.path.join(self.full_cache_path, f"heterographs.pkl"), 'wb') as f:
 
 
175
  pickle.dump((complex_graphs_all), f)
176
 
177
  rdkit_ligands_all = []
178
  for i in range(len(complex_names_all) // 1000 + 1):
179
- with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
 
 
180
  l = pickle.load(f)
181
  rdkit_ligands_all.extend(l)
182
- with open(os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
 
 
183
  pickle.dump((rdkit_ligands_all), f)
184
  else:
185
  complex_graphs, rdkit_ligands = [], []
186
- with tqdm(total=len(complex_names_all), desc='loading complexes') as pbar:
187
- for t in map(self.get_complex, zip(complex_names_all, lm_embeddings_chains_all, [None] * len(complex_names_all), [None] * len(complex_names_all))):
 
 
 
 
 
 
 
 
188
  complex_graphs.extend(t[0])
189
  rdkit_ligands.extend(t[1])
190
  pbar.update()
191
- with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'wb') as f:
 
 
192
  pickle.dump((complex_graphs), f)
193
- with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'wb') as f:
 
 
194
  pickle.dump((rdkit_ligands), f)
195
 
196
  def inference_preprocessing(self):
197
  ligands_list = []
198
- print('Reading molecules and generating local structures with RDKit')
199
  for ligand_description in tqdm(self.ligand_descriptions):
200
  mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
201
  if mol is not None:
@@ -211,70 +345,126 @@ class PDBBind(Dataset):
211
  ligands_list.append(mol)
212
 
213
  if self.esm_embeddings_path is not None:
214
- print('Reading language model embeddings.')
215
  lm_embeddings_chains_all = []
216
- if not os.path.exists(self.esm_embeddings_path): raise Exception('ESM embeddings path does not exist: ',self.esm_embeddings_path)
 
 
 
217
  for protein_path in self.protein_path_list:
218
- embeddings_paths = sorted(glob.glob(os.path.join(self.esm_embeddings_path, os.path.basename(protein_path)) + '*'))
 
 
 
 
 
 
 
219
  lm_embeddings_chains = []
220
  for embeddings_path in embeddings_paths:
221
- lm_embeddings_chains.append(torch.load(embeddings_path)['representations'][33])
 
 
222
  lm_embeddings_chains_all.append(lm_embeddings_chains)
223
  else:
224
  lm_embeddings_chains_all = [None] * len(self.protein_path_list)
225
 
226
- print('Generating graphs for ligands and proteins')
227
  if self.num_workers > 1:
228
  # running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
229
- for i in range(len(self.protein_path_list)//1000+1):
230
- if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
 
 
231
  continue
232
- protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
233
- ligand_description_chunk = self.ligand_descriptions[1000*i:1000*(i+1)]
234
- ligands_chunk = ligands_list[1000 * i:1000 * (i + 1)]
235
- lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
 
 
 
 
236
  complex_graphs, rdkit_ligands = [], []
237
  if self.num_workers > 1:
238
  p = Pool(self.num_workers, maxtasksperchild=1)
239
  p.__enter__()
240
- with tqdm(total=len(protein_paths_chunk), desc=f'loading complexes {i}/{len(protein_paths_chunk)//1000+1}') as pbar:
 
 
 
241
  map_fn = p.imap_unordered if self.num_workers > 1 else map
242
- for t in map_fn(self.get_complex, zip(protein_paths_chunk, lm_embeddings_chains, ligands_chunk,ligand_description_chunk)):
 
 
 
 
 
 
 
 
243
  complex_graphs.extend(t[0])
244
  rdkit_ligands.extend(t[1])
245
  pbar.update()
246
- if self.num_workers > 1: p.__exit__(None, None, None)
 
247
 
248
- with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
 
 
249
  pickle.dump((complex_graphs), f)
250
- with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
 
 
251
  pickle.dump((rdkit_ligands), f)
252
 
253
  complex_graphs_all = []
254
- for i in range(len(self.protein_path_list)//1000+1):
255
- with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
 
 
256
  l = pickle.load(f)
257
  complex_graphs_all.extend(l)
258
- with open(os.path.join(self.full_cache_path, f"heterographs.pkl"), 'wb') as f:
 
 
259
  pickle.dump((complex_graphs_all), f)
260
 
261
  rdkit_ligands_all = []
262
  for i in range(len(self.protein_path_list) // 1000 + 1):
263
- with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
 
 
264
  l = pickle.load(f)
265
  rdkit_ligands_all.extend(l)
266
- with open(os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
 
 
267
  pickle.dump((rdkit_ligands_all), f)
268
  else:
269
  complex_graphs, rdkit_ligands = [], []
270
- with tqdm(total=len(self.protein_path_list), desc='loading complexes') as pbar:
271
- for t in map(self.get_complex, zip(self.protein_path_list, lm_embeddings_chains_all, ligands_list, self.ligand_descriptions)):
 
 
 
 
 
 
 
 
 
 
272
  complex_graphs.extend(t[0])
273
  rdkit_ligands.extend(t[1])
274
  pbar.update()
275
- with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'wb') as f:
 
 
276
  pickle.dump((complex_graphs), f)
277
- with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'wb') as f:
 
 
278
  pickle.dump((rdkit_ligands), f)
279
 
280
  def get_complex(self, par):
@@ -285,51 +475,94 @@ class PDBBind(Dataset):
285
 
286
  if ligand is not None:
287
  rec_model = parse_pdb_from_path(name)
288
- name = f'{name}____{ligand_description}'
289
  ligs = [ligand]
290
  else:
291
  try:
292
  rec_model = parse_receptor(name, self.pdbbind_dir)
293
  except Exception as e:
294
- print(f'Skipping {name} because of the error:')
295
  print(e)
296
  return [], []
297
 
298
  ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)
299
  complex_graphs = []
300
  for i, lig in enumerate(ligs):
301
- if self.max_lig_size is not None and lig.GetNumHeavyAtoms() > self.max_lig_size:
302
- print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
 
 
 
 
 
303
  continue
304
  complex_graph = HeteroData()
305
- complex_graph['name'] = name
306
  try:
307
- get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
308
- self.num_conformers, remove_hs=self.remove_hs)
309
- rec, rec_coords, c_alpha_coords, n_coords, c_coords, lm_embeddings = extract_receptor_structure(copy.deepcopy(rec_model), lig, lm_embedding_chains=lm_embedding_chains)
310
- if lm_embeddings is not None and len(c_alpha_coords) != len(lm_embeddings):
311
- print(f'LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  continue
313
 
314
- get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius=self.receptor_radius,
315
- c_alpha_max_neighbors=self.c_alpha_max_neighbors, all_atoms=self.all_atoms,
316
- atom_radius=self.atom_radius, atom_max_neighbors=self.atom_max_neighbors, remove_hs=self.remove_hs, lm_embeddings=lm_embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  except Exception as e:
319
- print(f'Skipping {name} because of the error:')
320
  print(e)
321
  raise e
322
  continue
323
 
324
- protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
325
- complex_graph['receptor'].pos -= protein_center
 
 
326
  if self.all_atoms:
327
- complex_graph['atom'].pos -= protein_center
328
 
329
  if (not self.matching) or self.num_conformers == 1:
330
- complex_graph['ligand'].pos -= protein_center
331
  else:
332
- for p in complex_graph['ligand'].pos:
333
  p -= protein_center
334
 
335
  complex_graph.original_center = protein_center
@@ -341,11 +574,18 @@ def print_statistics(complex_graphs):
341
  statistics = ([], [], [], [])
342
 
343
  for complex_graph in complex_graphs:
344
- lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
345
- radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
 
 
 
 
 
 
346
  molecule_center = torch.mean(lig_pos, dim=0)
347
  radius_molecule = torch.max(
348
- torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1))
 
349
  distance_center = torch.linalg.vector_norm(molecule_center)
350
  statistics[0].append(radius_protein)
351
  statistics[1].append(radius_molecule)
@@ -355,52 +595,111 @@ def print_statistics(complex_graphs):
355
  else:
356
  statistics[3].append(0)
357
 
358
- name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching']
359
- print('Number of complexes: ', len(complex_graphs))
 
 
 
 
 
360
  for i in range(4):
361
  array = np.asarray(statistics[i])
362
- print(f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
 
 
363
 
364
 
365
  def construct_loader(args, t_to_sigma):
366
- transform = NoiseTransform(t_to_sigma=t_to_sigma, no_torsion=args.no_torsion,
367
- all_atom=args.all_atoms)
368
-
369
- common_args = {'transform': transform, 'root': args.data_dir, 'limit_complexes': args.limit_complexes,
370
- 'receptor_radius': args.receptor_radius,
371
- 'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
372
- 'remove_hs': args.remove_hs, 'max_lig_size': args.max_lig_size,
373
- 'matching': not args.no_torsion, 'popsize': args.matching_popsize, 'maxiter': args.matching_maxiter,
374
- 'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
375
- 'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
376
- 'esm_embeddings_path': args.esm_embeddings_path}
377
-
378
- train_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_train, keep_original=True,
379
- num_conformers=args.num_conformers, **common_args)
380
- val_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_val, keep_original=True, **common_args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
383
- train_loader = loader_class(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory)
384
- val_loader = loader_class(dataset=val_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory)
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  return train_loader, val_loader
387
 
388
 
389
  def read_mol(pdbbind_dir, name, remove_hs=False):
390
- lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_ligand.sdf'), remove_hs=remove_hs, sanitize=True)
 
 
 
 
391
  if lig is None: # read mol2 file if sdf file cannot be sanitized
392
- lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_ligand.mol2'), remove_hs=remove_hs, sanitize=True)
 
 
 
 
393
  return lig
394
 
395
 
396
  def read_mols(pdbbind_dir, name, remove_hs=False):
397
  ligs = []
398
  for file in os.listdir(os.path.join(pdbbind_dir, name)):
399
- if file.endswith(".sdf") and 'rdkit' not in file:
400
- lig = read_molecule(os.path.join(pdbbind_dir, name, file), remove_hs=remove_hs, sanitize=True)
401
- if lig is None and os.path.exists(os.path.join(pdbbind_dir, name, file[:-4] + ".mol2")): # read mol2 file if sdf file cannot be sanitized
402
- print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')
403
- lig = read_molecule(os.path.join(pdbbind_dir, name, file[:-4] + ".mol2"), remove_hs=remove_hs, sanitize=True)
 
 
 
 
 
 
 
 
 
 
 
 
404
  if lig is not None:
405
  ligs.append(lig)
406
- return ligs
 
16
  from torch_geometric.transforms import BaseTransform
17
  from tqdm import tqdm
18
 
19
+ from datasets.process_mols import (
20
+ read_molecule,
21
+ get_rec_graph,
22
+ generate_conformer,
23
+ get_lig_graph_with_matching,
24
+ extract_receptor_structure,
25
+ parse_receptor,
26
+ parse_pdb_from_path,
27
+ )
28
  from utils.diffusion_utils import modify_conformer, set_time
29
  from utils.utils import read_strings_from_txt
30
  from utils import so3, torus
 
41
  t_tr, t_rot, t_tor = t, t, t
42
  return self.apply_noise(data, t_tr, t_rot, t_tor)
43
 
44
+ def apply_noise(
45
+ self,
46
+ data,
47
+ t_tr,
48
+ t_rot,
49
+ t_tor,
50
+ tr_update=None,
51
+ rot_update=None,
52
+ torsion_updates=None,
53
+ ):
54
+ if not torch.is_tensor(data["ligand"].pos):
55
+ data["ligand"].pos = random.choice(data["ligand"].pos)
56
 
57
  tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
58
  set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)
59
 
60
+ tr_update = (
61
+ torch.normal(mean=0, std=tr_sigma, size=(1, 3))
62
+ if tr_update is None
63
+ else tr_update
64
+ )
65
  rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
66
+ torsion_updates = (
67
+ np.random.normal(
68
+ loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()
69
+ )
70
+ if torsion_updates is None
71
+ else torsion_updates
72
+ )
73
  torsion_updates = None if self.no_torsion else torsion_updates
74
+ modify_conformer(
75
+ data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates
76
+ )
77
+
78
+ data.tr_score = -tr_update / tr_sigma**2
79
+ data.rot_score = (
80
+ torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma))
81
+ .float()
82
+ .unsqueeze(0)
83
+ )
84
+ data.tor_score = (
85
+ None
86
+ if self.no_torsion
87
+ else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
88
+ )
89
+ data.tor_sigma_edge = (
90
+ None
91
+ if self.no_torsion
92
+ else np.ones(data["ligand"].edge_mask.sum()) * tor_sigma
93
+ )
94
  return data
95
 
96
 
97
  class PDBBind(Dataset):
98
+ def __init__(
99
+ self,
100
+ root,
101
+ transform=None,
102
+ cache_path="data/cache",
103
+ split_path="data/",
104
+ limit_complexes=0,
105
+ receptor_radius=30,
106
+ num_workers=1,
107
+ c_alpha_max_neighbors=None,
108
+ popsize=15,
109
+ maxiter=15,
110
+ matching=True,
111
+ keep_original=False,
112
+ max_lig_size=None,
113
+ remove_hs=False,
114
+ num_conformers=1,
115
+ all_atoms=False,
116
+ atom_radius=5,
117
+ atom_max_neighbors=None,
118
+ esm_embeddings_path=None,
119
+ require_ligand=False,
120
+ ligands_list=None,
121
+ protein_path_list=None,
122
+ ligand_descriptions=None,
123
+ keep_local_structures=False,
124
+ ):
125
 
126
  super(PDBBind, self).__init__(root, transform)
127
  self.pdbbind_dir = root
 
137
  self.protein_path_list = protein_path_list
138
  self.ligand_descriptions = ligand_descriptions
139
  self.keep_local_structures = keep_local_structures
140
+ if (
141
+ matching
142
+ or protein_path_list is not None
143
+ and ligand_descriptions is not None
144
+ ):
145
+ cache_path += "_torsion"
146
  if all_atoms:
147
+ cache_path += "_allatoms"
148
+ self.full_cache_path = os.path.join(
149
+ cache_path,
150
+ f"limit{self.limit_complexes}"
151
+ f"_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}"
152
+ f"_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}"
153
+ f"_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}"
154
+ + (
155
+ ""
156
+ if not all_atoms
157
+ else f"_atomRad{atom_radius}_atomMax{atom_max_neighbors}"
158
+ )
159
+ + ("" if not matching or num_conformers == 1 else f"_confs{num_conformers}")
160
+ + ("" if self.esm_embeddings_path is None else f"_esmEmbeddings")
161
+ + ("" if not keep_local_structures else f"_keptLocalStruct")
162
+ + (
163
+ ""
164
+ if protein_path_list is None or ligand_descriptions is None
165
+ else str(
166
+ binascii.crc32(
167
+ "".join(ligand_descriptions + protein_path_list).encode()
168
+ )
169
+ )
170
+ ),
171
+ )
172
  self.popsize, self.maxiter = popsize, maxiter
173
  self.matching, self.keep_original = matching, keep_original
174
  self.num_conformers = num_conformers
175
  self.all_atoms = all_atoms
176
  self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
177
+ if not os.path.exists(
178
+ os.path.join(self.full_cache_path, "heterographs.pkl")
179
+ ) or (
180
+ require_ligand
181
+ and not os.path.exists(
182
+ os.path.join(self.full_cache_path, "rdkit_ligands.pkl")
183
+ )
184
+ ):
185
  os.makedirs(self.full_cache_path, exist_ok=True)
186
  if protein_path_list is None or ligand_descriptions is None:
187
  self.preprocessing()
188
  else:
189
  self.inference_preprocessing()
190
 
191
+ print(
192
+ "loading data from memory: ",
193
+ os.path.join(self.full_cache_path, "heterographs.pkl"),
194
+ )
195
+ with open(os.path.join(self.full_cache_path, "heterographs.pkl"), "rb") as f:
196
  self.complex_graphs = pickle.load(f)
197
  if require_ligand:
198
+ with open(
199
+ os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), "rb"
200
+ ) as f:
201
  self.rdkit_ligands = pickle.load(f)
202
 
203
  print_statistics(self.complex_graphs)
 
214
  return copy.deepcopy(self.complex_graphs[idx])
215
 
216
  def preprocessing(self):
217
+ print(
218
+ f"Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]"
219
+ )
220
 
221
  complex_names_all = read_strings_from_txt(self.split_path)
222
  if self.limit_complexes is not None and self.limit_complexes != 0:
223
+ complex_names_all = complex_names_all[: self.limit_complexes]
224
+ print(f"Loading {len(complex_names_all)} complexes.")
225
 
226
  if self.esm_embeddings_path is not None:
227
  id_to_embeddings = torch.load(self.esm_embeddings_path)
228
  chain_embeddings_dictlist = defaultdict(list)
229
  for key, embedding in id_to_embeddings.items():
230
+ key_name = key.split("_")[0]
231
  if key_name in complex_names_all:
232
  chain_embeddings_dictlist[key_name].append(embedding)
233
  lm_embeddings_chains_all = []
 
238
 
239
  if self.num_workers > 1:
240
  # running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
241
+ for i in range(len(complex_names_all) // 1000 + 1):
242
+ if os.path.exists(
243
+ os.path.join(self.full_cache_path, f"heterographs{i}.pkl")
244
+ ):
245
  continue
246
+ complex_names = complex_names_all[1000 * i : 1000 * (i + 1)]
247
+ lm_embeddings_chains = lm_embeddings_chains_all[
248
+ 1000 * i : 1000 * (i + 1)
249
+ ]
250
  complex_graphs, rdkit_ligands = [], []
251
  if self.num_workers > 1:
252
  p = Pool(self.num_workers, maxtasksperchild=1)
253
  p.__enter__()
254
+ with tqdm(
255
+ total=len(complex_names),
256
+ desc=f"loading complexes {i}/{len(complex_names_all)//1000+1}",
257
+ ) as pbar:
258
  map_fn = p.imap_unordered if self.num_workers > 1 else map
259
+ for t in map_fn(
260
+ self.get_complex,
261
+ zip(
262
+ complex_names,
263
+ lm_embeddings_chains,
264
+ [None] * len(complex_names),
265
+ [None] * len(complex_names),
266
+ ),
267
+ ):
268
  complex_graphs.extend(t[0])
269
  rdkit_ligands.extend(t[1])
270
  pbar.update()
271
+ if self.num_workers > 1:
272
+ p.__exit__(None, None, None)
273
 
274
+ with open(
275
+ os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "wb"
276
+ ) as f:
277
  pickle.dump((complex_graphs), f)
278
+ with open(
279
+ os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "wb"
280
+ ) as f:
281
  pickle.dump((rdkit_ligands), f)
282
 
283
  complex_graphs_all = []
284
+ for i in range(len(complex_names_all) // 1000 + 1):
285
+ with open(
286
+ os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "rb"
287
+ ) as f:
288
  l = pickle.load(f)
289
  complex_graphs_all.extend(l)
290
+ with open(
291
+ os.path.join(self.full_cache_path, f"heterographs.pkl"), "wb"
292
+ ) as f:
293
  pickle.dump((complex_graphs_all), f)
294
 
295
  rdkit_ligands_all = []
296
  for i in range(len(complex_names_all) // 1000 + 1):
297
+ with open(
298
+ os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "rb"
299
+ ) as f:
300
  l = pickle.load(f)
301
  rdkit_ligands_all.extend(l)
302
+ with open(
303
+ os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), "wb"
304
+ ) as f:
305
  pickle.dump((rdkit_ligands_all), f)
306
  else:
307
  complex_graphs, rdkit_ligands = [], []
308
+ with tqdm(total=len(complex_names_all), desc="loading complexes") as pbar:
309
+ for t in map(
310
+ self.get_complex,
311
+ zip(
312
+ complex_names_all,
313
+ lm_embeddings_chains_all,
314
+ [None] * len(complex_names_all),
315
+ [None] * len(complex_names_all),
316
+ ),
317
+ ):
318
  complex_graphs.extend(t[0])
319
  rdkit_ligands.extend(t[1])
320
  pbar.update()
321
+ with open(
322
+ os.path.join(self.full_cache_path, "heterographs.pkl"), "wb"
323
+ ) as f:
324
  pickle.dump((complex_graphs), f)
325
+ with open(
326
+ os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), "wb"
327
+ ) as f:
328
  pickle.dump((rdkit_ligands), f)
329
 
330
  def inference_preprocessing(self):
331
  ligands_list = []
332
+ print("Reading molecules and generating local structures with RDKit")
333
  for ligand_description in tqdm(self.ligand_descriptions):
334
  mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
335
  if mol is not None:
 
345
  ligands_list.append(mol)
346
 
347
  if self.esm_embeddings_path is not None:
348
+ print("Reading language model embeddings.")
349
  lm_embeddings_chains_all = []
350
+ if not os.path.exists(self.esm_embeddings_path):
351
+ raise Exception(
352
+ "ESM embeddings path does not exist: ", self.esm_embeddings_path
353
+ )
354
  for protein_path in self.protein_path_list:
355
+ embeddings_paths = sorted(
356
+ glob.glob(
357
+ os.path.join(
358
+ self.esm_embeddings_path, os.path.basename(protein_path)
359
+ )
360
+ + "*"
361
+ )
362
+ )
363
  lm_embeddings_chains = []
364
  for embeddings_path in embeddings_paths:
365
+ lm_embeddings_chains.append(
366
+ torch.load(embeddings_path)["representations"][33]
367
+ )
368
  lm_embeddings_chains_all.append(lm_embeddings_chains)
369
  else:
370
  lm_embeddings_chains_all = [None] * len(self.protein_path_list)
371
 
372
+ print("Generating graphs for ligands and proteins")
373
  if self.num_workers > 1:
374
  # running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
375
+ for i in range(len(self.protein_path_list) // 1000 + 1):
376
+ if os.path.exists(
377
+ os.path.join(self.full_cache_path, f"heterographs{i}.pkl")
378
+ ):
379
  continue
380
+ protein_paths_chunk = self.protein_path_list[1000 * i : 1000 * (i + 1)]
381
+ ligand_description_chunk = self.ligand_descriptions[
382
+ 1000 * i : 1000 * (i + 1)
383
+ ]
384
+ ligands_chunk = ligands_list[1000 * i : 1000 * (i + 1)]
385
+ lm_embeddings_chains = lm_embeddings_chains_all[
386
+ 1000 * i : 1000 * (i + 1)
387
+ ]
388
  complex_graphs, rdkit_ligands = [], []
389
  if self.num_workers > 1:
390
  p = Pool(self.num_workers, maxtasksperchild=1)
391
  p.__enter__()
392
+ with tqdm(
393
+ total=len(protein_paths_chunk),
394
+ desc=f"loading complexes {i}/{len(protein_paths_chunk)//1000+1}",
395
+ ) as pbar:
396
  map_fn = p.imap_unordered if self.num_workers > 1 else map
397
+ for t in map_fn(
398
+ self.get_complex,
399
+ zip(
400
+ protein_paths_chunk,
401
+ lm_embeddings_chains,
402
+ ligands_chunk,
403
+ ligand_description_chunk,
404
+ ),
405
+ ):
406
  complex_graphs.extend(t[0])
407
  rdkit_ligands.extend(t[1])
408
  pbar.update()
409
+ if self.num_workers > 1:
410
+ p.__exit__(None, None, None)
411
 
412
+ with open(
413
+ os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "wb"
414
+ ) as f:
415
  pickle.dump((complex_graphs), f)
416
+ with open(
417
+ os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "wb"
418
+ ) as f:
419
  pickle.dump((rdkit_ligands), f)
420
 
421
  complex_graphs_all = []
422
+ for i in range(len(self.protein_path_list) // 1000 + 1):
423
+ with open(
424
+ os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), "rb"
425
+ ) as f:
426
  l = pickle.load(f)
427
  complex_graphs_all.extend(l)
428
+ with open(
429
+ os.path.join(self.full_cache_path, f"heterographs.pkl"), "wb"
430
+ ) as f:
431
  pickle.dump((complex_graphs_all), f)
432
 
433
  rdkit_ligands_all = []
434
  for i in range(len(self.protein_path_list) // 1000 + 1):
435
+ with open(
436
+ os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), "rb"
437
+ ) as f:
438
  l = pickle.load(f)
439
  rdkit_ligands_all.extend(l)
440
+ with open(
441
+ os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), "wb"
442
+ ) as f:
443
  pickle.dump((rdkit_ligands_all), f)
444
  else:
445
  complex_graphs, rdkit_ligands = [], []
446
+ with tqdm(
447
+ total=len(self.protein_path_list), desc="loading complexes"
448
+ ) as pbar:
449
+ for t in map(
450
+ self.get_complex,
451
+ zip(
452
+ self.protein_path_list,
453
+ lm_embeddings_chains_all,
454
+ ligands_list,
455
+ self.ligand_descriptions,
456
+ ),
457
+ ):
458
  complex_graphs.extend(t[0])
459
  rdkit_ligands.extend(t[1])
460
  pbar.update()
461
+ with open(
462
+ os.path.join(self.full_cache_path, "heterographs.pkl"), "wb"
463
+ ) as f:
464
  pickle.dump((complex_graphs), f)
465
+ with open(
466
+ os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), "wb"
467
+ ) as f:
468
  pickle.dump((rdkit_ligands), f)
469
 
470
  def get_complex(self, par):
 
475
 
476
  if ligand is not None:
477
  rec_model = parse_pdb_from_path(name)
478
+ name = f"{name}____{ligand_description}"
479
  ligs = [ligand]
480
  else:
481
  try:
482
  rec_model = parse_receptor(name, self.pdbbind_dir)
483
  except Exception as e:
484
+ print(f"Skipping {name} because of the error:")
485
  print(e)
486
  return [], []
487
 
488
  ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)
489
  complex_graphs = []
490
  for i, lig in enumerate(ligs):
491
+ if (
492
+ self.max_lig_size is not None
493
+ and lig.GetNumHeavyAtoms() > self.max_lig_size
494
+ ):
495
+ print(
496
+ f"Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data."
497
+ )
498
  continue
499
  complex_graph = HeteroData()
500
+ complex_graph["name"] = name
501
  try:
502
+ get_lig_graph_with_matching(
503
+ lig,
504
+ complex_graph,
505
+ self.popsize,
506
+ self.maxiter,
507
+ self.matching,
508
+ self.keep_original,
509
+ self.num_conformers,
510
+ remove_hs=self.remove_hs,
511
+ )
512
+ print(lm_embedding_chains)
513
+ (
514
+ rec,
515
+ rec_coords,
516
+ c_alpha_coords,
517
+ n_coords,
518
+ c_coords,
519
+ lm_embeddings,
520
+ ) = extract_receptor_structure(
521
+ copy.deepcopy(rec_model),
522
+ lig,
523
+ lm_embedding_chains=lm_embedding_chains,
524
+ )
525
+ if lm_embeddings is not None and len(c_alpha_coords) != len(
526
+ lm_embeddings
527
+ ):
528
+ print(
529
+ f"LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}."
530
+ )
531
  continue
532
 
533
+ get_rec_graph(
534
+ rec,
535
+ rec_coords,
536
+ c_alpha_coords,
537
+ n_coords,
538
+ c_coords,
539
+ complex_graph,
540
+ rec_radius=self.receptor_radius,
541
+ c_alpha_max_neighbors=self.c_alpha_max_neighbors,
542
+ all_atoms=self.all_atoms,
543
+ atom_radius=self.atom_radius,
544
+ atom_max_neighbors=self.atom_max_neighbors,
545
+ remove_hs=self.remove_hs,
546
+ lm_embeddings=lm_embeddings,
547
+ )
548
 
549
  except Exception as e:
550
+ print(f"Skipping {name} because of the error:")
551
  print(e)
552
  raise e
553
  continue
554
 
555
+ protein_center = torch.mean(
556
+ complex_graph["receptor"].pos, dim=0, keepdim=True
557
+ )
558
+ complex_graph["receptor"].pos -= protein_center
559
  if self.all_atoms:
560
+ complex_graph["atom"].pos -= protein_center
561
 
562
  if (not self.matching) or self.num_conformers == 1:
563
+ complex_graph["ligand"].pos -= protein_center
564
  else:
565
+ for p in complex_graph["ligand"].pos:
566
  p -= protein_center
567
 
568
  complex_graph.original_center = protein_center
 
574
  statistics = ([], [], [], [])
575
 
576
  for complex_graph in complex_graphs:
577
+ lig_pos = (
578
+ complex_graph["ligand"].pos
579
+ if torch.is_tensor(complex_graph["ligand"].pos)
580
+ else complex_graph["ligand"].pos[0]
581
+ )
582
+ radius_protein = torch.max(
583
+ torch.linalg.vector_norm(complex_graph["receptor"].pos, dim=1)
584
+ )
585
  molecule_center = torch.mean(lig_pos, dim=0)
586
  radius_molecule = torch.max(
587
+ torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1)
588
+ )
589
  distance_center = torch.linalg.vector_norm(molecule_center)
590
  statistics[0].append(radius_protein)
591
  statistics[1].append(radius_molecule)
 
595
  else:
596
  statistics[3].append(0)
597
 
598
+ name = [
599
+ "radius protein",
600
+ "radius molecule",
601
+ "distance protein-mol",
602
+ "rmsd matching",
603
+ ]
604
+ print("Number of complexes: ", len(complex_graphs))
605
  for i in range(4):
606
  array = np.asarray(statistics[i])
607
+ print(
608
+ f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}"
609
+ )
610
 
611
 
612
  def construct_loader(args, t_to_sigma):
613
+ transform = NoiseTransform(
614
+ t_to_sigma=t_to_sigma, no_torsion=args.no_torsion, all_atom=args.all_atoms
615
+ )
616
+
617
+ common_args = {
618
+ "transform": transform,
619
+ "root": args.data_dir,
620
+ "limit_complexes": args.limit_complexes,
621
+ "receptor_radius": args.receptor_radius,
622
+ "c_alpha_max_neighbors": args.c_alpha_max_neighbors,
623
+ "remove_hs": args.remove_hs,
624
+ "max_lig_size": args.max_lig_size,
625
+ "matching": not args.no_torsion,
626
+ "popsize": args.matching_popsize,
627
+ "maxiter": args.matching_maxiter,
628
+ "num_workers": args.num_workers,
629
+ "all_atoms": args.all_atoms,
630
+ "atom_radius": args.atom_radius,
631
+ "atom_max_neighbors": args.atom_max_neighbors,
632
+ "esm_embeddings_path": args.esm_embeddings_path,
633
+ }
634
+
635
+ train_dataset = PDBBind(
636
+ cache_path=args.cache_path,
637
+ split_path=args.split_train,
638
+ keep_original=True,
639
+ num_conformers=args.num_conformers,
640
+ **common_args,
641
+ )
642
+ val_dataset = PDBBind(
643
+ cache_path=args.cache_path,
644
+ split_path=args.split_val,
645
+ keep_original=True,
646
+ **common_args,
647
+ )
648
 
649
  loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
650
+ train_loader = loader_class(
651
+ dataset=train_dataset,
652
+ batch_size=args.batch_size,
653
+ num_workers=args.num_dataloader_workers,
654
+ shuffle=True,
655
+ pin_memory=args.pin_memory,
656
+ )
657
+ val_loader = loader_class(
658
+ dataset=val_dataset,
659
+ batch_size=args.batch_size,
660
+ num_workers=args.num_dataloader_workers,
661
+ shuffle=True,
662
+ pin_memory=args.pin_memory,
663
+ )
664
 
665
  return train_loader, val_loader
666
 
667
 
668
  def read_mol(pdbbind_dir, name, remove_hs=False):
669
+ lig = read_molecule(
670
+ os.path.join(pdbbind_dir, name, f"{name}_ligand.sdf"),
671
+ remove_hs=remove_hs,
672
+ sanitize=True,
673
+ )
674
  if lig is None: # read mol2 file if sdf file cannot be sanitized
675
+ lig = read_molecule(
676
+ os.path.join(pdbbind_dir, name, f"{name}_ligand.mol2"),
677
+ remove_hs=remove_hs,
678
+ sanitize=True,
679
+ )
680
  return lig
681
 
682
 
683
  def read_mols(pdbbind_dir, name, remove_hs=False):
684
  ligs = []
685
  for file in os.listdir(os.path.join(pdbbind_dir, name)):
686
+ if file.endswith(".sdf") and "rdkit" not in file:
687
+ lig = read_molecule(
688
+ os.path.join(pdbbind_dir, name, file),
689
+ remove_hs=remove_hs,
690
+ sanitize=True,
691
+ )
692
+ if lig is None and os.path.exists(
693
+ os.path.join(pdbbind_dir, name, file[:-4] + ".mol2")
694
+ ): # read mol2 file if sdf file cannot be sanitized
695
+ print(
696
+ "Using the .sdf file failed. We found a .mol2 file instead and are trying to use that."
697
+ )
698
+ lig = read_molecule(
699
+ os.path.join(pdbbind_dir, name, file[:-4] + ".mol2"),
700
+ remove_hs=remove_hs,
701
+ sanitize=True,
702
+ )
703
  if lig is not None:
704
  ligs.append(lig)
705
+ return ligs
datasets/process_mols.py CHANGED
@@ -490,8 +490,10 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
490
  if molecule_file.endswith('.mol2'):
491
  mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
492
  elif molecule_file.endswith('.sdf'):
 
493
  supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
494
  mol = supplier[0]
 
495
  elif molecule_file.endswith('.pdbqt'):
496
  with open(molecule_file) as file:
497
  pdbqt_data = file.readlines()
@@ -505,6 +507,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
505
  return ValueError('Expect the format of the molecule_file to be '
506
  'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
507
 
 
 
508
  try:
509
  if sanitize or calc_charges:
510
  Chem.SanitizeMol(mol)
@@ -518,7 +522,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
518
 
519
  if remove_hs:
520
  mol = Chem.RemoveHs(mol, sanitize=sanitize)
521
- except:
 
522
  return None
523
 
524
  return mol
 
490
  if molecule_file.endswith('.mol2'):
491
  mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
492
  elif molecule_file.endswith('.sdf'):
493
+ print(molecule_file)
494
  supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
495
  mol = supplier[0]
496
+ print(mol)
497
  elif molecule_file.endswith('.pdbqt'):
498
  with open(molecule_file) as file:
499
  pdbqt_data = file.readlines()
 
507
  return ValueError('Expect the format of the molecule_file to be '
508
  'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
509
 
510
+ print(sanitize, calc_charges, remove_hs)
511
+
512
  try:
513
  if sanitize or calc_charges:
514
  Chem.SanitizeMol(mol)
 
522
 
523
  if remove_hs:
524
  mol = Chem.RemoveHs(mol, sanitize=sanitize)
525
+ except Exception as e:
526
+ print(e)
527
  return None
528
 
529
  return mol
examples/1a46_ligand.sdf ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1a46_ligand
2
+ -I-interpret-
3
+
4
+ 85 88 0 0 0 0 0 0 0 0999 V2000
5
+ 17.8330 -13.0420 21.6620 C 0 0 0 0 0
6
+ 18.8870 -13.0710 20.5870 C 0 0 0 0 0
7
+ 19.8510 -14.2200 21.1170 C 0 0 0 0 0
8
+ 19.3270 -16.4440 22.1560 C 0 0 0 0 0
9
+ 18.1340 -17.2300 22.7620 C 0 0 0 0 0
10
+ 17.2230 -16.3290 23.5970 C 0 0 0 0 0
11
+ 17.0320 -14.9230 23.0460 C 0 0 0 0 0
12
+ 18.8520 -15.2420 21.4440 N 0 3 0 0 0
13
+ 17.7750 -14.5090 22.0480 N 0 0 0 0 0
14
+ 15.9850 -14.2900 23.3800 O 0 0 0 0 0
15
+ 16.6380 -13.0610 20.7550 C 0 0 0 0 0
16
+ 16.4620 -13.9620 19.8370 O 0 0 0 0 0
17
+ 15.8090 -16.7300 23.6610 N 0 3 0 0 0
18
+ 17.4150 -16.4170 25.1230 C 0 0 0 0 0
19
+ 18.7640 -15.9840 25.5820 C 0 0 0 0 0
20
+ 19.0510 -14.6340 25.7600 C 0 0 0 0 0
21
+ 20.3910 -14.2520 26.0760 C 0 0 0 0 0
22
+ 21.4290 -15.1780 26.2150 C 0 0 0 0 0
23
+ 21.0990 -16.5480 26.0980 C 0 0 0 0 0
24
+ 19.7890 -16.9510 25.7560 C 0 0 0 0 0
25
+ 15.6470 -12.0890 20.7690 N 0 0 0 0 0
26
+ 14.4940 -11.8920 19.9090 C 0 0 0 0 0
27
+ 14.4960 -10.9450 18.7130 C 0 0 0 0 0
28
+ 13.3800 -10.6840 18.0770 O 0 0 0 0 0
29
+ 13.1950 -11.6150 20.6280 C 0 0 0 0 0
30
+ 12.8670 -12.5040 21.7570 C 0 0 0 0 0
31
+ 11.5610 -12.2200 22.4370 C 0 0 0 0 0
32
+ 11.1700 -13.3510 23.3530 C 0 0 0 0 0
33
+ 10.0380 -13.1110 24.2350 N 0 3 0 0 0
34
+ 14.8040 -11.9210 16.4570 N 0 0 0 0 0
35
+ 15.3450 -11.4350 17.5510 C 0 0 0 0 0
36
+ 16.4740 -11.0890 17.7310 O 0 0 0 0 0
37
+ 15.6510 -12.3330 15.3350 C 0 0 0 0 0
38
+ 16.0390 -13.7960 15.2500 C 0 0 0 0 0
39
+ 14.9560 -14.6030 14.5390 C 0 0 0 0 0
40
+ 14.5990 -13.9990 13.1800 C 0 0 0 0 0
41
+ 14.1680 -12.5610 13.3540 C 0 0 0 0 0
42
+ 15.2770 -11.7400 13.9980 C 0 0 0 0 0
43
+ 17.9332 -12.2994 22.4536 H 0 0 0 0 0
44
+ 19.3882 -12.1140 20.4420 H 0 0 0 0 0
45
+ 18.4882 -13.2617 19.5906 H 0 0 0 0 0
46
+ 20.4926 -13.9283 21.9484 H 0 0 0 0 0
47
+ 20.6127 -14.5392 20.4056 H 0 0 0 0 0
48
+ 19.8508 -17.0880 21.4496 H 0 0 0 0 0
49
+ 19.9921 -16.1358 22.9627 H 0 0 0 0 0
50
+ 18.5327 -18.0092 23.4116 H 0 0 0 0 0
51
+ 17.5467 -17.6450 21.9429 H 0 0 0 0 0
52
+ 18.5389 -15.7277 20.6035 H 0 0 0 0 0
53
+ 15.7428 -17.6818 24.0216 H 0 0 0 0 0
54
+ 15.3044 -16.0949 24.2794 H 0 0 0 0 0
55
+ 15.4029 -16.6903 22.7262 H 0 0 0 0 0
56
+ 17.2937 -17.4623 25.4072 H 0 0 0 0 0
57
+ 16.6848 -15.7509 25.5825 H 0 0 0 0 0
58
+ 18.2682 -13.8821 25.6602 H 0 0 0 0 0
59
+ 20.6133 -13.1939 26.2145 H 0 0 0 0 0
60
+ 22.4528 -14.8565 26.4061 H 0 0 0 0 0
61
+ 21.8654 -17.3029 26.2740 H 0 0 0 0 0
62
+ 19.5640 -18.0094 25.6250 H 0 0 0 0 0
63
+ 15.7457 -11.3948 21.5098 H 0 0 0 0 0
64
+ 14.5905 -12.8910 19.4839 H 0 0 0 0 0
65
+ 14.8425 -10.0689 19.2612 H 0 0 0 0 0
66
+ 13.5584 -10.0751 17.3566 H 0 0 0 0 0
67
+ 12.4050 -11.7585 19.8909 H 0 0 0 0 0
68
+ 13.2901 -10.6141 21.0491 H 0 0 0 0 0
69
+ 13.6465 -12.3595 22.5050 H 0 0 0 0 0
70
+ 12.7942 -13.5124 21.3496 H 0 0 0 0 0
71
+ 10.7892 -12.1043 21.6761 H 0 0 0 0 0
72
+ 11.6663 -11.3113 23.0296 H 0 0 0 0 0
73
+ 12.0278 -13.5229 24.0031 H 0 0 0 0 0
74
+ 10.8774 -14.1769 22.7046 H 0 0 0 0 0
75
+ 9.8690 -13.9413 24.8029 H 0 0 0 0 0
76
+ 10.2441 -12.3181 24.8427 H 0 0 0 0 0
77
+ 9.2101 -12.9059 23.6756 H 0 0 0 0 0
78
+ 13.7904 -12.0118 16.3885 H 0 0 0 0 0
79
+ 16.5871 -11.8550 15.6237 H 0 0 0 0 0
80
+ 16.1623 -14.1864 16.2602 H 0 0 0 0 0
81
+ 16.9681 -13.8812 14.6864 H 0 0 0 0 0
82
+ 14.0613 -14.5994 15.1616 H 0 0 0 0 0
83
+ 15.3317 -15.6133 14.3772 H 0 0 0 0 0
84
+ 13.7819 -14.5683 12.7368 H 0 0 0 0 0
85
+ 15.4725 -14.0364 12.5291 H 0 0 0 0 0
86
+ 13.2893 -12.5323 13.9983 H 0 0 0 0 0
87
+ 13.9420 -12.1402 12.3742 H 0 0 0 0 0
88
+ 16.1510 -11.7459 13.3467 H 0 0 0 0 0
89
+ 14.9268 -10.7183 14.1449 H 0 0 0 0 0
90
+ 2 1 1 0 0 0
91
+ 1 9 1 0 0 0
92
+ 1 11 1 0 0 0
93
+ 3 2 1 0 0 0
94
+ 8 3 1 0 0 0
95
+ 4 5 1 0 0 0
96
+ 4 8 1 0 0 0
97
+ 5 6 1 0 0 0
98
+ 6 7 1 0 0 0
99
+ 6 13 1 0 0 0
100
+ 6 14 1 0 0 0
101
+ 7 9 1 0 0 0
102
+ 7 10 2 0 0 0
103
+ 8 9 1 0 0 0
104
+ 11 12 2 0 0 0
105
+ 11 21 1 0 0 0
106
+ 14 15 1 0 0 0
107
+ 15 16 4 0 0 0
108
+ 15 20 4 0 0 0
109
+ 16 17 4 0 0 0
110
+ 17 18 4 0 0 0
111
+ 18 19 4 0 0 0
112
+ 19 20 4 0 0 0
113
+ 21 22 1 0 0 0
114
+ 22 23 1 0 0 0
115
+ 22 25 1 0 0 0
116
+ 23 24 1 0 0 0
117
+ 23 31 1 0 0 0
118
+ 25 26 1 0 0 0
119
+ 26 27 1 0 0 0
120
+ 27 28 1 0 0 0
121
+ 28 29 1 0 0 0
122
+ 31 30 1 0 0 0
123
+ 30 33 1 0 0 0
124
+ 31 32 2 0 0 0
125
+ 33 34 1 0 0 0
126
+ 33 38 1 0 0 0
127
+ 34 35 1 0 0 0
128
+ 35 36 1 0 0 0
129
+ 36 37 1 0 0 0
130
+ 37 38 1 0 0 0
131
+ 1 39 1 0 0 0
132
+ 2 40 1 0 0 0
133
+ 2 41 1 0 0 0
134
+ 3 42 1 0 0 0
135
+ 3 43 1 0 0 0
136
+ 4 44 1 0 0 0
137
+ 4 45 1 0 0 0
138
+ 5 46 1 0 0 0
139
+ 5 47 1 0 0 0
140
+ 8 48 1 0 0 0
141
+ 13 49 1 0 0 0
142
+ 13 50 1 0 0 0
143
+ 13 51 1 0 0 0
144
+ 14 52 1 0 0 0
145
+ 14 53 1 0 0 0
146
+ 16 54 1 0 0 0
147
+ 17 55 1 0 0 0
148
+ 18 56 1 0 0 0
149
+ 19 57 1 0 0 0
150
+ 20 58 1 0 0 0
151
+ 21 59 1 0 0 0
152
+ 22 60 1 0 0 0
153
+ 23 61 1 0 0 0
154
+ 24 62 1 0 0 0
155
+ 25 63 1 0 0 0
156
+ 25 64 1 0 0 0
157
+ 26 65 1 0 0 0
158
+ 26 66 1 0 0 0
159
+ 27 67 1 0 0 0
160
+ 27 68 1 0 0 0
161
+ 28 69 1 0 0 0
162
+ 28 70 1 0 0 0
163
+ 29 71 1 0 0 0
164
+ 29 72 1 0 0 0
165
+ 29 73 1 0 0 0
166
+ 30 74 1 0 0 0
167
+ 33 75 1 0 0 0
168
+ 34 76 1 0 0 0
169
+ 34 77 1 0 0 0
170
+ 35 78 1 0 0 0
171
+ 35 79 1 0 0 0
172
+ 36 80 1 0 0 0
173
+ 36 81 1 0 0 0
174
+ 37 82 1 0 0 0
175
+ 37 83 1 0 0 0
176
+ 38 84 1 0 0 0
177
+ 38 85 1 0 0 0
178
+ M END
179
+ $$$$
examples/1a46_protein_processed.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/1cbr_ligand.sdf ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1cbr_ligand
2
+
3
+ Created by X-TOOL on Fri Nov 18 12:01:53 2016
4
+ 49 49 0 0 0 0 0 0 0 0999 V2000
5
+ 5.0920 2.4270 -10.7940 C 0 0 0 1 0 4
6
+ 6.0790 1.2390 -10.8790 C 0 0 0 3 0 4
7
+ 7.4570 1.5880 -11.3400 C 0 0 0 3 0 4
8
+ 8.1090 2.6160 -10.4790 C 0 0 0 3 0 4
9
+ 7.1710 3.7700 -10.1040 C 0 0 0 1 0 3
10
+ 5.8090 3.6640 -10.1590 C 0 0 0 1 0 3
11
+ 4.8670 4.7410 -9.7870 C 0 0 0 2 0 3
12
+ 5.0090 5.6850 -8.8490 C 0 0 0 2 0 3
13
+ 4.0490 6.7120 -8.5120 C 0 0 0 1 0 3
14
+ 4.3830 7.6020 -7.5550 C 0 0 0 2 0 3
15
+ 3.5130 8.6700 -7.1050 C 0 0 0 2 0 3
16
+ 3.9620 9.5090 -6.1670 C 0 0 0 2 0 3
17
+ 3.1640 10.5920 -5.6370 C 0 0 0 1 0 3
18
+ 3.7030 11.3990 -4.6890 C 0 0 0 2 0 3
19
+ 3.0710 12.5430 -4.0160 C 0 5 0 1 0 3
20
+ 3.9070 2.0000 -9.9190 C 0 0 0 4 0 4
21
+ 4.5820 2.7980 -12.2130 C 0 0 0 4 0 4
22
+ 7.9800 4.9390 -9.5360 C 0 0 0 4 0 4
23
+ 2.7160 6.8010 -9.2660 C 0 0 0 4 0 4
24
+ 1.7300 10.7780 -6.1620 C 0 0 0 4 0 4
25
+ 2.5240 13.4330 -4.7040 O 0 0 0 1 0 1
26
+ 3.0900 12.6020 -2.7660 O 0 0 0 1 0 1
27
+ 5.6628 0.5003 -11.5797 H 0 0 0 1 0 1
28
+ 6.1586 0.7905 -9.8778 H 0 0 0 1 0 1
29
+ 7.3965 1.9765 -12.3673 H 0 0 0 1 0 1
30
+ 8.0733 0.6769 -11.3282 H 0 0 0 1 0 1
31
+ 8.9730 3.0290 -11.0202 H 0 0 0 1 0 1
32
+ 8.4536 2.1305 -9.5541 H 0 0 0 1 0 1
33
+ 3.9353 4.7700 -10.3501 H 0 0 0 1 0 1
34
+ 5.9398 5.6789 -8.2837 H 0 0 0 1 0 1
35
+ 5.3651 7.5126 -7.0930 H 0 0 0 1 0 1
36
+ 2.5140 8.7864 -7.5226 H 0 0 0 1 0 1
37
+ 4.9725 9.3712 -5.7852 H 0 0 0 1 0 1
38
+ 4.7256 11.1723 -4.3911 H 0 0 0 1 0 1
39
+ 3.1893 2.8302 -9.8432 H 0 0 0 1 0 1
40
+ 4.2693 1.7357 -8.9146 H 0 0 0 1 0 1
41
+ 3.4124 1.1280 -10.3717 H 0 0 0 1 0 1
42
+ 5.4325 3.1041 -12.8399 H 0 0 0 1 0 1
43
+ 3.8636 3.6277 -12.1392 H 0 0 0 1 0 1
44
+ 4.0887 1.9250 -12.6652 H 0 0 0 1 0 1
45
+ 7.2992 5.7611 -9.2702 H 0 0 0 1 0 1
46
+ 8.6994 5.2884 -10.2913 H 0 0 0 1 0 1
47
+ 8.5226 4.6076 -8.6384 H 0 0 0 1 0 1
48
+ 2.6523 5.9808 -9.9962 H 0 0 0 1 0 1
49
+ 2.6558 7.7653 -9.7917 H 0 0 0 1 0 1
50
+ 1.8841 6.7206 -8.5508 H 0 0 0 1 0 1
51
+ 1.5151 10.0113 -6.9209 H 0 0 0 1 0 1
52
+ 1.6308 11.7769 -6.6117 H 0 0 0 1 0 1
53
+ 1.0187 10.6787 -5.3288 H 0 0 0 1 0 1
54
+ 1 2 1 0 0 1
55
+ 1 6 1 0 0 1
56
+ 1 16 1 0 0 2
57
+ 1 17 1 0 0 2
58
+ 2 3 1 0 0 1
59
+ 3 4 1 0 0 1
60
+ 4 5 1 0 0 1
61
+ 5 6 2 0 0 1
62
+ 5 18 1 0 0 2
63
+ 6 7 1 0 0 2
64
+ 7 8 2 0 0 2
65
+ 8 9 1 0 0 2
66
+ 9 10 2 0 0 2
67
+ 9 19 1 0 0 2
68
+ 10 11 1 0 0 2
69
+ 11 12 2 0 0 2
70
+ 12 13 1 0 0 2
71
+ 13 14 2 0 0 2
72
+ 13 20 1 0 0 2
73
+ 14 15 1 0 0 2
74
+ 15 21 2 0 0 2
75
+ 15 22 2 0 0 2
76
+ 2 23 1 0 0 2
77
+ 2 24 1 0 0 2
78
+ 3 25 1 0 0 2
79
+ 3 26 1 0 0 2
80
+ 4 27 1 0 0 2
81
+ 4 28 1 0 0 2
82
+ 7 29 1 0 0 2
83
+ 8 30 1 0 0 2
84
+ 10 31 1 0 0 2
85
+ 11 32 1 0 0 2
86
+ 12 33 1 0 0 2
87
+ 14 34 1 0 0 2
88
+ 16 35 1 0 0 2
89
+ 16 36 1 0 0 2
90
+ 16 37 1 0 0 2
91
+ 17 38 1 0 0 2
92
+ 17 39 1 0 0 2
93
+ 17 40 1 0 0 2
94
+ 18 41 1 0 0 2
95
+ 18 42 1 0 0 2
96
+ 18 43 1 0 0 2
97
+ 19 44 1 0 0 2
98
+ 19 45 1 0 0 2
99
+ 19 46 1 0 0 2
100
+ 20 47 1 0 0 2
101
+ 20 48 1 0 0 2
102
+ 20 49 1 0 0 2
103
+ M END
104
+ > <MOLECULAR_FORMULA>
105
+ C20H27O2
106
+
107
+ > <MOLECULAR_WEIGHT>
108
+ 299.2
109
+
110
+ > <NUM_HB_ATOMS>
111
+ 2
112
+
113
+ > <NUM_ROTOR>
114
+ 0
115
+
116
+ > <XLOGP2>
117
+ 3.40
118
+
119
+ $$$$
examples/1cbr_protein.pdb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ biopandas==0.4.1
2
+ biopython==1.79
3
+ e3nn==0.5.0
4
+ jinja2==3.1.2
5
+ joblib==1.2.0
6
+ markupsafe==2.1.1
7
+ mpmath==1.2.1
8
+ networkx==2.8.7
9
+ opt-einsum==3.3.0
10
+ opt-einsum-fx==0.1.4
11
+ packaging==21.3
12
+ pandas==1.5.0
13
+ scikit-learn==1.1.2
14
+ scipy==1.9.1
15
+ spyrmsd==0.5.2
16
+ sympy==1.11.1
17
+ spyrmsd==0.5.2
18
+ sympy==1.11.1
19
+ pytorch==1.12.1
20
+ numpy==1.23.1
21
+ torchaudio=0.12.1
22
+ torchvision=0.13.1
23
+ rdkit-pypi==2022.3.5
24
+ torch-scatter
25
+ torch-sparse
26
+ torch-cluster
27
+ torch-spline-conv
28
+ torch-geometric
29
+ -f https://data.pyg.org/whl/torch-1.12.0+cu102.html