vumichien commited on
Commit
e553806
1 Parent(s): 57bd5c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -136,8 +136,36 @@ def graphs_from_smiles(smiles_list):
136
  tf.ragged.constant(bond_features_list, dtype=tf.float32),
137
  tf.ragged.constant(pair_indices_list, dtype=tf.int64),
138
  )
139
-
140
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def MPNNDataset(X, y, batch_size=32, shuffle=False):
142
  dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
143
  if shuffle:
 
136
  tf.ragged.constant(bond_features_list, dtype=tf.float32),
137
  tf.ragged.constant(pair_indices_list, dtype=tf.int64),
138
  )
139
+
140
+
141
+ def prepare_batch(x_batch, y_batch):
142
+ """Merges (sub)graphs of batch into a single global (disconnected) graph
143
+ """
144
+
145
+ atom_features, bond_features, pair_indices = x_batch
146
+
147
+ # Obtain number of atoms and bonds for each graph (molecule)
148
+ num_atoms = atom_features.row_lengths()
149
+ num_bonds = bond_features.row_lengths()
150
+
151
+ # Obtain partition indices (molecule_indicator), which will be used to
152
+ # gather (sub)graphs from global graph in model later on
153
+ molecule_indices = tf.range(len(num_atoms))
154
+ molecule_indicator = tf.repeat(molecule_indices, num_atoms)
155
+
156
+ # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
157
+ # 'pair_indices' (and merging ragged tensors) actualizes the global graph
158
+ gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
159
+ increment = tf.cumsum(num_atoms[:-1])
160
+ increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
161
+ pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
162
+ pair_indices = pair_indices + increment[:, tf.newaxis]
163
+ atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
164
+ bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
165
+
166
+ return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch
167
+
168
+
169
  def MPNNDataset(X, y, batch_size=32, shuffle=False):
170
  dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
171
  if shuffle: