igashov committed on
Commit
88b37fb
1 Parent(s): c2d5999

updated code

Browse files
app.py CHANGED
@@ -35,12 +35,16 @@ MODELS_METADATA = {
35
  'path': 'models/geom_difflinker_given_anchors.ckpt',
36
  },
37
  'pockets_difflinker': {
38
- 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
39
- 'path': 'models/pockets_difflinker.ckpt',
 
 
40
  },
41
  'pockets_difflinker_given_anchors': {
42
- 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
43
- 'path': 'models/pockets_difflinker_given_anchors.ckpt',
 
 
44
  },
45
  }
46
 
 
35
  'path': 'models/geom_difflinker_given_anchors.ckpt',
36
  },
37
  'pockets_difflinker': {
38
+ # 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
39
+ # 'path': 'models/pockets_difflinker.ckpt',
40
+ 'link': 'https://zenodo.org/records/10988017/files/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt?download=1',
41
+ 'path': 'models/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt',
42
  },
43
  'pockets_difflinker_given_anchors': {
44
+ # 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
45
+ # 'path': 'models/pockets_difflinker_given_anchors.ckpt',
46
+ 'link': 'https://zenodo.org/records/10988017/files/pockets_difflinker_full_fc_pdb_excluded.ckpt?download=1',
47
+ 'path': 'models/pockets_difflinker_full_fc_pdb_excluded.ckpt',
48
  },
49
  }
50
 
src/datasets.py CHANGED
@@ -148,6 +148,15 @@ class MOADDataset(Dataset):
148
  total=len(table)
149
  )
150
  for (_, row), fragments, linker, pocket_data in generator:
 
 
 
 
 
 
 
 
 
151
  uuid = row['uuid']
152
  name = row['molecule']
153
  frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
@@ -212,16 +221,112 @@ class MOADDataset(Dataset):
212
 
213
  return data
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  @staticmethod
216
- def create_edges(positions, fragment_mask_only, linker_mask_only):
217
- ligand_mask = fragment_mask_only.astype(bool) | linker_mask_only.astype(bool)
218
- ligand_adj = ligand_mask[:, None] & ligand_mask[None, :]
219
- proximity_adj = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1) <= 6
220
- full_adj = ligand_adj | proximity_adj
221
- full_adj &= ~np.eye(len(positions)).astype(bool)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- curr_rows, curr_cols = np.where(full_adj)
224
- return [curr_rows, curr_cols]
 
 
225
 
226
 
227
  def collate(batch):
@@ -231,7 +336,7 @@ def collate(batch):
231
  # if 'pocket_mask' not in batch[0].keys():
232
  # batch = [data for data in batch if data['num_atoms'] <= 50]
233
  # else:
234
- # batch = [data for data in batch if data['num_atoms'] <= 1000]
235
 
236
  for i, data in enumerate(batch):
237
  for key, value in data.items():
 
148
  total=len(table)
149
  )
150
  for (_, row), fragments, linker, pocket_data in generator:
151
+ pdb = row['molecule_name'].split('_')[0]
152
+ if pdb in {
153
+ '5ou2', '5ou3', '6hay',
154
+ '5mo8', '5mo5', '5mo7', '5ctp', '5cu2', '5cu4', '5mmr', '5mmf',
155
+ '5moe', '3iw7', '4i9n', '3fi2', '3fi3',
156
+ }:
157
+ print(f'Skipping pdb={pdb}')
158
+ continue
159
+
160
  uuid = row['uuid']
161
  name = row['molecule']
162
  frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
 
221
 
222
  return data
223
 
224
+
225
class OptimisedMOADDataset(MOADDataset):
    """MOAD dataset that keeps per-fragmentation and per-molecule tensors apart.

    Items are assembled on access by merging one entry of
    ``fragmentation_level_data`` with the ``protein_level_data`` entry that
    shares its ``name``.
    """
    # TODO: finish testing
    # NOTE(review): ``self.data`` is presumably the dict returned by
    # ``preprocess`` and set up by the parent class -- confirm.

    def __len__(self):
        """Return the number of fragmentation-level entries."""
        per_fragmentation = self.data['fragmentation_level_data']
        return len(per_fragmentation)

    def __getitem__(self, item):
        """Return the merged fragmentation-level and molecule-level dict for ``item``."""
        sample = self.data['fragmentation_level_data'][item]
        shared = self.data['protein_level_data'][sample['name']]
        # Molecule-level keys are written last, so they win on a key clash.
        merged = dict(sample)
        merged.update(shared)
        return merged

    @staticmethod
    def preprocess(data_path, prefix, pocket_mode, device):
        """Read the raw table/SDF/pocket files and build the two-level data dict.

        :param data_path: directory that contains the ``{prefix}_*`` files
        :param prefix: dataset prefix; a prefix containing ``'multifrag'``
            switches anchor parsing to the ``anchors`` column
        :param pocket_mode: selects the ``{pocket_mode}_coord`` and
            ``{pocket_mode}_types`` entries of every pocket record
        :param device: device on which all tensors are created
        :return: dict with keys ``'fragmentation_level_data'`` (list of dicts)
            and ``'protein_level_data'`` (dict keyed by molecule name)
        """
        print('Preprocessing optimised version of the dataset')

        def to_tensor(array):
            return torch.tensor(array, dtype=const.TORCH_FLOAT, device=device)

        def segment_mask(segments, flags):
            # One value per atom: ones for the selected segments, zeros elsewhere.
            return np.concatenate([
                np.ones_like(segment) if selected else np.zeros_like(segment)
                for segment, selected in zip(segments, flags)
            ])

        is_geom = True
        is_multifrag = 'multifrag' in prefix

        pockets_path = os.path.join(data_path, f'{prefix}_pockets.pkl')
        linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')
        fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
        table_path = os.path.join(data_path, f'{prefix}_table.csv')

        per_fragmentation = []
        per_molecule = {}

        with open(pockets_path, 'rb') as f:
            pockets = pickle.load(f)

        table = pd.read_csv(table_path)
        records = zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path), pockets)
        for (_, row), fragments, linker, pocket_data in tqdm(records, total=len(table)):
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            # Parsing pocket data (single pass over the atom types)
            pocket_pos = pocket_data[f'{pocket_mode}_coord']
            encoded = [
                (get_one_hot(atom_type, const.GEOM_ATOM2IDX), const.GEOM_CHARGES[atom_type])
                for atom_type in pocket_data[f'{pocket_mode}_types']
            ]
            pocket_one_hot = np.array([one_hot_row for one_hot_row, _ in encoded])
            pocket_charges = np.array([charge for _, charge in encoded])

            # Atom order everywhere below: fragments, pocket, linker
            positions = np.concatenate([frag_pos, pocket_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, pocket_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, pocket_charges, link_charges], axis=0)

            anchors = np.zeros_like(charges)
            if is_multifrag:
                anchor_indices = [int(token) for token in row['anchors'].split('-')]
            else:
                anchor_indices = [row['anchor_1'], row['anchor_2']]
            for anchor_idx in anchor_indices:
                anchors[anchor_idx] = 1

            segments = (frag_charges, pocket_charges, link_charges)
            fragment_only_mask = segment_mask(segments, (True, False, False))
            pocket_mask = segment_mask(segments, (False, True, False))
            linker_mask = segment_mask(segments, (False, False, True))
            fragment_mask = segment_mask(segments, (True, True, False))

            per_fragmentation.append({
                'uuid': uuid,
                'name': name,
                'anchors': to_tensor(anchors),
                'fragment_only_mask': to_tensor(fragment_only_mask),
                'pocket_mask': to_tensor(pocket_mask),
                'fragment_mask': to_tensor(fragment_mask),
                'linker_mask': to_tensor(linker_mask),
            })
            # NOTE(review): keyed by molecule name only, so a later row with the
            # same name overwrites this entry -- confirm that fragment/linker
            # atom order is identical across fragmentations of one molecule.
            per_molecule[name] = {
                'positions': to_tensor(positions),
                'one_hot': to_tensor(one_hot),
                'charges': to_tensor(charges),
                'num_atoms': len(positions),
            }

        return {
            'fragmentation_level_data': per_fragmentation,
            'protein_level_data': per_molecule,
        }
330
 
331
 
332
  def collate(batch):
 
336
  # if 'pocket_mask' not in batch[0].keys():
337
  # batch = [data for data in batch if data['num_atoms'] <= 50]
338
  # else:
339
+ # batch = [data for data in batch if data['num_atoms'] <= 1000]
340
 
341
  for i, data in enumerate(batch):
342
  for key, value in data.items():
src/egnn.py CHANGED
@@ -315,7 +315,7 @@ class Dynamics(nn.Module):
315
  self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
316
  n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
317
  sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
318
- normalization=None, centering=False,
319
  ):
320
  super().__init__()
321
  self.device = device
@@ -324,6 +324,7 @@ class Dynamics(nn.Module):
324
  self.condition_time = condition_time
325
  self.model = model
326
  self.centering = centering
 
327
 
328
  in_node_nf = in_node_nf + context_node_nf + condition_time
329
  if self.model == 'egnn_dynamics':
@@ -369,6 +370,8 @@ class Dynamics(nn.Module):
369
  - context: (B, N, C)
370
  """
371
 
 
 
372
  bs, n_nodes = xh.shape[0], xh.shape[1]
373
  edges = self.get_edges(n_nodes, bs) # (2, B*N)
374
  node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
@@ -421,16 +424,6 @@ class Dynamics(nn.Module):
421
  if self.condition_time:
422
  h_final = h_final[:, :-1]
423
 
424
- if torch.any(torch.isnan(vel)):
425
- print('Found NaN values in velocities')
426
- nan_mask = torch.isnan(vel).float()
427
- vel = x * nan_mask + torch.nan_to_num(vel) * (1 - nan_mask)
428
-
429
- if torch.any(torch.isnan(h_final)):
430
- print('Found NaN values in features')
431
- nan_mask = torch.isnan(h_final).float()
432
- h_final = h[:, :h_final.shape[1]] * nan_mask + torch.nan_to_num(h_final) * (1 - nan_mask)
433
-
434
  vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
435
  h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
436
  node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
@@ -477,12 +470,21 @@ class DynamicsWithPockets(Dynamics):
477
  if linker_mask is not None:
478
  linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
479
 
 
 
 
 
480
  # Reshaping node features & adding time feature
481
  xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
482
  x = xh[:, :self.n_dims].clone() # (B*N, 3)
483
  h = xh[:, self.n_dims:].clone() # (B*N, nf)
484
 
485
- edges = self.get_dist_edges(x, node_mask, edge_mask)
 
 
 
 
 
486
  if self.condition_time:
487
  if np.prod(t.size()) == 1:
488
  # t is the same for all elements in batch.
@@ -537,7 +539,7 @@ class DynamicsWithPockets(Dynamics):
537
  return torch.cat([vel, h_final], dim=2)
538
 
539
  @staticmethod
540
- def get_dist_edges(x, node_mask, batch_mask):
541
  node_mask = node_mask.squeeze().bool()
542
  batch_adj = (batch_mask[:, None] == batch_mask[None, :])
543
  nodes_adj = (node_mask[:, None] & node_mask[None, :])
@@ -546,3 +548,36 @@ class DynamicsWithPockets(Dynamics):
546
  adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
547
  edges = torch.stack(torch.where(adj))
548
  return edges
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
316
  n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
317
  sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
318
+ normalization=None, centering=False, graph_type='FC',
319
  ):
320
  super().__init__()
321
  self.device = device
 
324
  self.condition_time = condition_time
325
  self.model = model
326
  self.centering = centering
327
+ self.graph_type = graph_type
328
 
329
  in_node_nf = in_node_nf + context_node_nf + condition_time
330
  if self.model == 'egnn_dynamics':
 
370
  - context: (B, N, C)
371
  """
372
 
373
+ assert self.graph_type == 'FC'
374
+
375
  bs, n_nodes = xh.shape[0], xh.shape[1]
376
  edges = self.get_edges(n_nodes, bs) # (2, B*N)
377
  node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
 
424
  if self.condition_time:
425
  h_final = h_final[:, :-1]
426
 
 
 
 
 
 
 
 
 
 
 
427
  vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
428
  h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
429
  node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
 
470
  if linker_mask is not None:
471
  linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
472
 
473
+ fragment_only_mask = context[..., -2].view(bs * n_nodes, 1) # (B*N, 1)
474
+ pocket_only_mask = context[..., -1].view(bs * n_nodes, 1) # (B*N, 1)
475
+ assert torch.all(fragment_only_mask.bool() | pocket_only_mask.bool() | linker_mask.bool() == node_mask.bool())
476
+
477
  # Reshaping node features & adding time feature
478
  xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
479
  x = xh[:, :self.n_dims].clone() # (B*N, 3)
480
  h = xh[:, self.n_dims:].clone() # (B*N, nf)
481
 
482
+ assert self.graph_type in ['4A', 'FC-4A', 'FC-10A-4A']
483
+ if self.graph_type == '4A' or self.graph_type is None:
484
+ edges = self.get_dist_edges_4A(x, node_mask, edge_mask)
485
+ else:
486
+ edges = self.get_dist_edges(x, node_mask, edge_mask, linker_mask, fragment_only_mask, pocket_only_mask)
487
+
488
  if self.condition_time:
489
  if np.prod(t.size()) == 1:
490
  # t is the same for all elements in batch.
 
539
  return torch.cat([vel, h_final], dim=2)
540
 
541
  @staticmethod
542
+ def get_dist_edges_4A(x, node_mask, batch_mask):
543
  node_mask = node_mask.squeeze().bool()
544
  batch_adj = (batch_mask[:, None] == batch_mask[None, :])
545
  nodes_adj = (node_mask[:, None] & node_mask[None, :])
 
548
  adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
549
  edges = torch.stack(torch.where(adj))
550
  return edges
551
+
552
def get_dist_edges(self, x, node_mask, batch_mask, linker_mask, fragment_only_mask, pocket_only_mask):
    """Build the edge index of a pocket-conditioned graph.

    Three kinds of edges are combined:
      - ligand-ligand (fragments + linker): fully connected,
      - pocket-pocket: pairs at distance <= 4,
      - pocket-ligand: pairs at distance <= 4 if ``self.graph_type == 'FC-4A'``,
        otherwise <= 10.
    Edges never connect nodes with different ``batch_mask`` values, never
    involve nodes that are zero in ``node_mask`` and never form self-loops.

    :param x: node coordinates, one row per node
    :param node_mask: non-zero for real (non-padding) nodes
    :param batch_mask: per-node batch id (assumed 1-D -- confirm against caller)
    :param linker_mask: non-zero for linker atoms
    :param fragment_only_mask: non-zero for fragment atoms
    :param pocket_only_mask: non-zero for pocket atoms
    :return: tensor of shape (2, E) holding source and target node indices
    """
    # Flatten with reshape(-1) instead of squeeze(): squeeze() turns a
    # single-node (1, 1) mask into a 0-d tensor, which breaks ``mask[:, None]``.
    node_mask = node_mask.reshape(-1).bool()
    linker_mask = linker_mask.reshape(-1).bool() & node_mask
    fragment_only_mask = fragment_only_mask.reshape(-1).bool() & node_mask
    pocket_only_mask = pocket_only_mask.reshape(-1).bool() & node_mask
    ligand_mask = linker_mask | fragment_only_mask

    # General constraints: same batch element, both nodes real, no self-loops
    batch_adj = (batch_mask[:, None] == batch_mask[None, :])
    nodes_adj = (node_mask[:, None] & node_mask[None, :])
    rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
    constraints = batch_adj & nodes_adj & rm_self_loops

    # Pairwise distances are needed by both cutoffs below; compute them once
    dists = torch.cdist(x, x)

    # Ligand atoms - fully-connected graph
    ligand_adj = (ligand_mask[:, None] & ligand_mask[None, :])
    ligand_interactions = ligand_adj & constraints

    # Pocket atoms - within 4A
    pocket_adj = (pocket_only_mask[:, None] & pocket_only_mask[None, :])
    pocket_interactions = pocket_adj & (dists <= 4) & constraints

    # Pocket-ligand atoms - within 4A for 'FC-4A', within 10A otherwise
    pocket_ligand_cutoff = 4 if self.graph_type == 'FC-4A' else 10
    pocket_ligand_adj = (ligand_mask[:, None] & pocket_only_mask[None, :])
    pocket_ligand_adj = pocket_ligand_adj | (pocket_only_mask[:, None] & ligand_mask[None, :])
    pocket_ligand_interactions = pocket_ligand_adj & (dists <= pocket_ligand_cutoff) & constraints

    adj = ligand_interactions | pocket_interactions | pocket_ligand_interactions
    edges = torch.stack(torch.where(adj))
    return edges
src/lightning.py CHANGED
@@ -44,7 +44,7 @@ class DDPM(pl.LightningModule):
44
  normalize_factors, include_charges, model,
45
  data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
46
  normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
47
- center_of_mass='fragments', inpainting=False, anchors_context=True,
48
  ):
49
  super(DDPM, self).__init__()
50
 
@@ -54,7 +54,7 @@ class DDPM(pl.LightningModule):
54
  self.val_data_prefix = val_data_prefix
55
  self.batch_size = batch_size
56
  self.lr = lr
57
- self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
  self.include_charges = include_charges
59
  self.test_epochs = test_epochs
60
  self.n_stability_samples = n_stability_samples
@@ -72,6 +72,9 @@ class DDPM(pl.LightningModule):
72
 
73
  self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
74
 
 
 
 
75
  if type(activation) is str:
76
  activation = get_activation(activation)
77
 
@@ -80,7 +83,7 @@ class DDPM(pl.LightningModule):
80
  in_node_nf=in_node_nf,
81
  n_dims=n_dims,
82
  context_node_nf=context_node_nf,
83
- device=self.torch_device,
84
  hidden_nf=hidden_nf,
85
  activation=activation,
86
  n_layers=n_layers,
@@ -94,6 +97,7 @@ class DDPM(pl.LightningModule):
94
  model=model,
95
  normalization=normalization,
96
  centering=inpainting,
 
97
  )
98
  edm_class = InpaintingEDM if inpainting else EDM
99
  self.edm = edm_class(
@@ -424,7 +428,7 @@ class DDPM(pl.LightningModule):
424
  context = fragment_mask
425
 
426
  # Add information about pocket to the context
427
- if isinstance(self.val_dataset, MOADDataset):
428
  fragment_pocket_mask = fragment_mask
429
  fragment_only_mask = template_data['fragment_only_mask']
430
  pocket_only_mask = fragment_pocket_mask - fragment_only_mask
 
44
  normalize_factors, include_charges, model,
45
  data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
46
  normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
47
+ center_of_mass='fragments', inpainting=False, anchors_context=True, graph_type=None,
48
  ):
49
  super(DDPM, self).__init__()
50
 
 
54
  self.val_data_prefix = val_data_prefix
55
  self.batch_size = batch_size
56
  self.lr = lr
57
+ self.torch_device = torch_device
58
  self.include_charges = include_charges
59
  self.test_epochs = test_epochs
60
  self.n_stability_samples = n_stability_samples
 
72
 
73
  self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
74
 
75
+ if graph_type is None:
76
+ graph_type = '4A' if '.' in train_data_prefix else 'FC'
77
+
78
  if type(activation) is str:
79
  activation = get_activation(activation)
80
 
 
83
  in_node_nf=in_node_nf,
84
  n_dims=n_dims,
85
  context_node_nf=context_node_nf,
86
+ device=torch_device,
87
  hidden_nf=hidden_nf,
88
  activation=activation,
89
  n_layers=n_layers,
 
97
  model=model,
98
  normalization=normalization,
99
  centering=inpainting,
100
+ graph_type=graph_type,
101
  )
102
  edm_class = InpaintingEDM if inpainting else EDM
103
  self.edm = edm_class(
 
428
  context = fragment_mask
429
 
430
  # Add information about pocket to the context
431
+ if '.' in self.train_data_prefix:
432
  fragment_pocket_mask = fragment_mask
433
  fragment_only_mask = template_data['fragment_only_mask']
434
  pocket_only_mask = fragment_pocket_mask - fragment_only_mask
src/linker_size.py CHANGED
@@ -21,10 +21,6 @@ class DistributionNodes:
21
  prob = prob/np.sum(prob)
22
 
23
  self.prob = torch.from_numpy(prob).float()
24
-
25
- entropy = torch.sum(self.prob * torch.log(self.prob + 1e-30))
26
- print("Entropy of n_nodes: H[N]", entropy.item())
27
-
28
  self.m = Categorical(torch.tensor(prob))
29
 
30
  def sample(self, n_samples=1):
 
21
  prob = prob/np.sum(prob)
22
 
23
  self.prob = torch.from_numpy(prob).float()
 
 
 
 
24
  self.m = Categorical(torch.tensor(prob))
25
 
26
  def sample(self, n_samples=1):
src/linker_size_lightning.py CHANGED
@@ -40,6 +40,7 @@ class SizeClassifier(pl.LightningModule):
40
  self.lr = lr
41
  self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
 
43
  self.gnn = SizeGNN(
44
  in_node_nf=in_node_nf,
45
  hidden_nf=hidden_nf,
@@ -79,7 +80,7 @@ class SizeClassifier(pl.LightningModule):
79
  def test_dataloader(self):
80
  return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
81
 
82
- def forward(self, data, return_loss=True, with_pocket=False):
83
  h = data['one_hot']
84
  x = data['positions']
85
  fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
@@ -91,6 +92,10 @@ class SizeClassifier(pl.LightningModule):
91
  x = x * fragment_mask
92
  h = h * fragment_mask
93
 
 
 
 
 
94
  # Reshaping
95
  bs, n_nodes = x.shape[0], x.shape[1]
96
  fragment_mask = fragment_mask.view(bs * n_nodes, 1)
 
40
  self.lr = lr
41
  self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
43
+ self.in_node_nf = in_node_nf
44
  self.gnn = SizeGNN(
45
  in_node_nf=in_node_nf,
46
  hidden_nf=hidden_nf,
 
80
  def test_dataloader(self):
81
  return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
82
 
83
+ def forward(self, data, return_loss=True, with_pocket=False, adjust_shape=False):
84
  h = data['one_hot']
85
  x = data['positions']
86
  fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
 
92
  x = x * fragment_mask
93
  h = h * fragment_mask
94
 
95
+ if h.shape[-1] != self.in_node_nf and adjust_shape:
96
+ assert torch.allclose(h[..., -1], torch.zeros_like(h[..., -1]))
97
+ h = h[..., :-1]
98
+
99
  # Reshaping
100
  bs, n_nodes = x.shape[0], x.shape[1]
101
  fragment_mask = fragment_mask.view(bs * n_nodes, 1)
src/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import sys
 
2
  from datetime import datetime
3
 
4
  import torch
@@ -21,9 +22,11 @@ class Logger(object):
21
  # you might want to specify some extra behavior here.
22
  pass
23
 
 
24
  def log(*args):
25
  print(f'[{datetime.now()}]', *args)
26
 
 
27
  class EMA:
28
  def __init__(self, beta):
29
  super().__init__()
@@ -257,6 +260,17 @@ def disable_rdkit_logging():
257
  rkrb.DisableLog('rdApp.error')
258
 
259
 
 
 
 
 
 
 
 
 
 
 
 
260
  class FoundNaNException(Exception):
261
  def __init__(self, x, h):
262
  x_nan_idx = self.find_nan_idx(x)
 
1
  import sys
2
+ import random
3
  from datetime import datetime
4
 
5
  import torch
 
22
  # you might want to specify some extra behavior here.
23
  pass
24
 
25
+
26
  def log(*args):
27
  print(f'[{datetime.now()}]', *args)
28
 
29
+
30
  class EMA:
31
  def __init__(self, beta):
32
  super().__init__()
 
260
  rkrb.DisableLog('rdApp.error')
261
 
262
 
263
def set_deterministic(seed):
    """Seed the Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    :param seed: integer seed shared by all random number generators
    """
    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only seeded when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
272
+
273
+
274
  class FoundNaNException(Exception):
275
  def __init__(self, x, h):
276
  x_nan_idx = self.find_nan_idx(x)