eaglelandsonce commited on
Commit
dc52ee1
·
verified ·
1 Parent(s): 16fad40

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +10 -1
pages/19_Graphs3.py CHANGED
@@ -252,11 +252,20 @@ move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
252
 
253
  feature_processors = [process_features, add_readout, move_label_to_readout]
254
 
 
 
 
 
 
 
 
 
 
255
  # Run training
256
  st.write("Training the model...")
257
  runner.run(
258
  task=task,
259
- model_fn=model_fn,
260
  trainer=trainer,
261
  optimizer_fn=optimizer_fn,
262
  epochs=epochs,
 
252
 
253
  feature_processors = [process_features, add_readout, move_label_to_readout]
254
 
255
+ # Function to create the full model with feature processing
256
+ def create_full_model(graph_tensor_spec: tfgnn.GraphTensorSpec):
257
+ input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)
258
+ graph = input_graph
259
+ for processor in feature_processors:
260
+ graph = processor(graph)
261
+ output_graph = model_fn(graph.spec)(graph)
262
+ return tf.keras.Model(input_graph, output_graph)
263
+
264
  # Run training
265
  st.write("Training the model...")
266
  runner.run(
267
  task=task,
268
+ model_fn=create_full_model,
269
  trainer=trainer,
270
  optimizer_fn=optimizer_fn,
271
  epochs=epochs,