Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -136,8 +136,36 @@ def graphs_from_smiles(smiles_list):
|
|
136 |
tf.ragged.constant(bond_features_list, dtype=tf.float32),
|
137 |
tf.ragged.constant(pair_indices_list, dtype=tf.int64),
|
138 |
)
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
def MPNNDataset(X, y, batch_size=32, shuffle=False):
|
142 |
dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
|
143 |
if shuffle:
|
|
|
136 |
tf.ragged.constant(bond_features_list, dtype=tf.float32),
|
137 |
tf.ragged.constant(pair_indices_list, dtype=tf.int64),
|
138 |
)
|
139 |
+
|
140 |
+
|
141 |
+
def prepare_batch(x_batch, y_batch):
    """Collapse a batch of per-molecule graphs into one disconnected graph.

    Args:
        x_batch: A 3-tuple ``(atom_features, bond_features, pair_indices)``.
            Each element is used through ``row_lengths()`` / ``merge_dims()``,
            so they are presumably ``tf.RaggedTensor`` objects whose outer
            dimension indexes molecules -- confirm against the caller.
        y_batch: Labels for the batch; passed through untouched.

    Returns:
        ``((atom_features, bond_features, pair_indices, molecule_indicator),
        y_batch)`` where the first three entries have their two outermost
        dimensions merged, ``pair_indices`` is shifted so that it addresses
        rows of the merged ``atom_features``, and ``molecule_indicator`` holds
        one molecule index per merged atom row.
    """
    ragged_atoms, ragged_bonds, ragged_pairs = x_batch

    # Per-molecule row counts of the atom and bond tensors.
    atoms_per_molecule = ragged_atoms.row_lengths()
    bonds_per_molecule = ragged_bonds.row_lengths()

    # One id per molecule, then repeated once per atom; the model can later
    # use this to split the merged graph back into its molecules.
    molecule_ids = tf.range(len(atoms_per_molecule))
    molecule_indicator = tf.repeat(molecule_ids, atoms_per_molecule)

    # Running atom totals: entry k is the number of atoms in molecules 0..k,
    # i.e. the index offset that applies to molecule k + 1.
    atom_offsets = tf.cumsum(atoms_per_molecule[:-1])

    # For every bond of molecules 1..N-1, the position in `atom_offsets`
    # holding that molecule's offset.
    offset_lookup = tf.repeat(molecule_ids[:-1], bonds_per_molecule[1:])
    bond_offsets = tf.gather(atom_offsets, offset_lookup)

    # Bonds of the first molecule need no shift, so left-pad with one zero
    # per bond of molecule 0.
    bond_offsets = tf.pad(bond_offsets, [(bonds_per_molecule[0], 0)])

    # Merge the molecule dimension into the row dimension, then shift both
    # endpoints of every pair by its molecule's offset.
    merged_pairs = ragged_pairs.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    merged_pairs = merged_pairs + bond_offsets[:, tf.newaxis]

    merged_atoms = ragged_atoms.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    merged_bonds = ragged_bonds.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    merged_graph = (merged_atoms, merged_bonds, merged_pairs, molecule_indicator)
    return merged_graph, y_batch
|
167 |
+
|
168 |
+
|
169 |
def MPNNDataset(X, y, batch_size=32, shuffle=False):
|
170 |
dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
|
171 |
if shuffle:
|