Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Update pages/19_Graphs3.py
Browse files- pages/19_Graphs3.py +10 -1
pages/19_Graphs3.py
CHANGED
@@ -252,11 +252,20 @@ move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
|
|
252 |
|
253 |
feature_processors = [process_features, add_readout, move_label_to_readout]
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
# Run training
|
256 |
st.write("Training the model...")
|
257 |
runner.run(
|
258 |
task=task,
|
259 |
-
model_fn=
|
260 |
trainer=trainer,
|
261 |
optimizer_fn=optimizer_fn,
|
262 |
epochs=epochs,
|
|
|
252 |
|
253 |
feature_processors = [process_features, add_readout, move_label_to_readout]
|
254 |
|
255 |
+
# Function to create the full model with feature processing
|
256 |
+
def create_full_model(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
257 |
+
input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
258 |
+
graph = input_graph
|
259 |
+
for processor in feature_processors:
|
260 |
+
graph = processor(graph)
|
261 |
+
output_graph = model_fn(graph.spec)(graph)
|
262 |
+
return tf.keras.Model(input_graph, output_graph)
|
263 |
+
|
264 |
# Run training
|
265 |
st.write("Training the model...")
|
266 |
runner.run(
|
267 |
task=task,
|
268 |
+
model_fn=create_full_model,
|
269 |
trainer=trainer,
|
270 |
optimizer_fn=optimizer_fn,
|
271 |
epochs=epochs,
|