Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,11 @@ import numpy as np
|
|
6 |
import tensorflow as tf
|
7 |
from tensorflow import keras
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
RDLogger.DisableLog("rdApp.*")
|
10 |
|
11 |
def graph_to_molecule(graph):
|
@@ -50,7 +55,7 @@ generator = from_pretrained_keras("keras-io/wgan-molecular-graphs")
|
|
50 |
|
51 |
def predict(num_mol):
|
52 |
samples = num_mol*2
|
53 |
-
z = tf.random.normal((samples,
|
54 |
graph = generator.predict(z)
|
55 |
# obtain one-hot encoded adjacency tensor
|
56 |
adjacency = tf.argmax(graph[0], axis=1)
|
@@ -59,7 +64,7 @@ def predict(num_mol):
|
|
59 |
adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
|
60 |
# obtain one-hot encoded feature tensor
|
61 |
features = tf.argmax(graph[1], axis=2)
|
62 |
-
features = tf.one_hot(features, depth=
|
63 |
molecules = [
|
64 |
graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
|
65 |
for i in range(samples)
|
|
|
6 |
import tensorflow as tf
|
7 |
from tensorflow import keras
|
8 |
|
9 |
+
# Config
|
10 |
+
NUM_ATOMS = 9 # Maximum number of atoms
|
11 |
+
ATOM_DIM = 4 + 1 # Number of atom types
|
12 |
+
BOND_DIM = 4 + 1 # Number of bond types
|
13 |
+
LATENT_DIM = 64 # Size of the latent space
|
14 |
RDLogger.DisableLog("rdApp.*")
|
15 |
|
16 |
def graph_to_molecule(graph):
|
|
|
55 |
|
56 |
def predict(num_mol):
|
57 |
samples = num_mol*2
|
58 |
+
z = tf.random.normal((samples, LATENT_DIM))
|
59 |
graph = generator.predict(z)
|
60 |
# obtain one-hot encoded adjacency tensor
|
61 |
adjacency = tf.argmax(graph[0], axis=1)
|
|
|
64 |
adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
|
65 |
# obtain one-hot encoded feature tensor
|
66 |
features = tf.argmax(graph[1], axis=2)
|
67 |
+
features = tf.one_hot(features, depth=ATOM_DIM, axis=2)
|
68 |
molecules = [
|
69 |
graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
|
70 |
for i in range(samples)
|