import functools
import os

# TF-GNN needs legacy Keras (Keras 2); this must be set before TensorFlow is imported.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

from typing import Mapping

import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
# Set Streamlit title
st.title("Solving OGBN-MAG end-to-end with TF-GNN")
# Configure logging and report versions (dependencies are installed via requirements, not here)
st.write("Setting up the environment...")
tf.get_logger().setLevel('ERROR')
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
# Load the graph schema and graph tensor
st.write("Loading graph schema and tensor...")
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
serialized_ogbn_mag_graph_tensor_string)
st.write("Graph tensor loaded successfully.")
# Define sampling sizes
train_sampling_sizes = {
"cites": 8,
"rev_writes": 8,
"writes": 8,
"affiliated_with": 8,
"has_topic": 8,
}
validation_sample_sizes = train_sampling_sizes.copy()
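# Each entry caps the number of neighbors sampled per node along that edge set
# (e.g. at most 8 cited papers per seed paper); validation reuses the same fan-out.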
# Create sampling model
def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]) -> tf.keras.Model:
def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
edge_set_name = sampling_op.edge_set_name
sample_size = sizes[edge_set_name]
return sampler.InMemUniformEdgesSampler.from_graph_tensor(
full_graph_tensor, edge_set_name, sample_size=sample_size
)
def get_features(node_set_name: tfgnn.NodeSetName):
return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
full_graph_tensor, node_set_name
)
  # Spell out the multi-hop sampling procedure in Python.
sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
seed = sampling_spec_builder.seed("paper")
papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic")
sampling_spec = sampling_spec_builder.build()
model = sampler.create_sampling_model_from_spec(
graph_schema, sampling_spec, edge_sampler, get_features,
seed_node_dtype=tf.int64)
return model
# Create the sampling model
st.write("Creating sampling model...")
sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
st.write("Sampling model created successfully.")
# Define seed dataset function
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
"""Seed dataset as indices of papers within split years."""
if split_name == "train":
mask = years <= 2017 # 629,571 examples
elif split_name == "validation":
mask = years == 2018 # 64,879 examples
elif split_name == "test":
mask = years == 2019 # 41,939 examples
else:
raise ValueError(f"Unknown split_name: '{split_name}'")
seed_indices = tf.squeeze(tf.where(mask), axis=-1)
return tf.data.Dataset.from_tensor_slices(seed_indices)
# Define SubgraphDatasetProvider
class SubgraphDatasetProvider(runner.DatasetProvider):
"""Dataset Provider based on Sampler V2."""
def __init__(self, full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int], split_name: str):
super().__init__()
self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
self._split_name = split_name
self.input_graph_spec = self._sampling_model.output.spec
def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
"""Creates TF dataset."""
self._seed_dataset = seed_dataset(self._years, self._split_name)
ds = self._seed_dataset.shard(
num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
if self._split_name == "train":
ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
ds = ds.batch(128)
ds = ds.map(
functools.partial(self.sample),
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False,
)
return ds.unbatch().prefetch(tf.data.AUTOTUNE)
def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
seeds = tf.cast(seeds, tf.int64)
batch_size = tf.size(seeds)
seeds_ragged = tf.RaggedTensor.from_row_lengths(
seeds, tf.ones([batch_size], tf.int64),
)
return self._sampling_model(seeds_ragged)
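# Each call to sample() turns a batch of seed paper indices into a batch of sampled
# subgraphs rooted at those papers; get_dataset() then unbatches them into individual
# GraphTensor examples for training.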
# Create dataset providers
st.write("Creating dataset providers...")
train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
st.write("Dataset providers created successfully.")
# Define the model function
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
graph = tfgnn.keras.layers.MapFeatures(
node_sets_fn=set_initial_node_states)(graph)
for i in range(num_graph_updates):
graph = mt_albis.MtAlbisGraphUpdate(
units=node_state_dim,
message_dim=message_dim,
receiver_tag=tfgnn.SOURCE,
node_set_names=None if i < num_graph_updates-1 else ["paper"],
simple_conv_reduce_type="mean|sum",
state_dropout_rate=state_dropout_rate,
l2_regularization=l2_regularization,
normalization_type="layer",
next_state_type="residual",
)(graph)
return tf.keras.Model(inputs, graph)
# Check for TPU/GPU and set the distribution strategy
st.write("Setting up strategy for distributed training...")
if tf.config.list_physical_devices("TPU"):
st.write("Using TPUStrategy")
strategy = runner.TPUStrategy("local")
train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
elif tf.config.list_physical_devices("GPU"):
st.write("Using MirroredStrategy for GPUs")
strategy = tf.distribute.MirroredStrategy()
train_padding = None
valid_padding = None
else:
st.write("Using default strategy")
strategy = tf.distribute.get_strategy()
train_padding = None
valid_padding = None
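# TPUs compile to fixed shapes, so graph tensors must be padded to static sizes;
# on GPU/CPU the model can train on variable-size inputs without padding.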
st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
# Define task
st.write("Defining the task...")
task = runner.NodeMulticlassClassification(
num_classes=349,
label_feature_name="paper_venue")
# Set hyperparameters
st.write("Setting hyperparameters...")
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
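# With 629,571 training seeds and a global batch size of 128, this works out to
# roughly 4,918 training steps and 506 validation steps per epoch.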
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
initial_learning_rate, steps_per_epoch*epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
# Define trainer
st.write("Setting up the trainer...")
trainer = runner.KerasTrainer(
strategy=strategy,
model_dir="/tmp/gnn_model/",
callbacks=None,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
restore_best_weights=False,
checkpoint_every_n_steps="never",
summarize_every_n_steps="never",
backup_and_restore=False,
)
# Define feature processors
st.write("Defining feature processors...")
def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
if node_set_name == "field_of_study":
return {"hashed_id": tf.keras.layers.Hashing(50_000)(node_set["#id"])}
if node_set_name == "institution":
return {"hashed_id": tf.keras.layers.Hashing(6_500)(node_set["#id"])}
if node_set_name == "paper":
return {"feat": node_set["feat"], "label": node_set["label"]}
if node_set_name == "author":
return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
def drop_all_features(_, **unused_kwargs):
return {}
process_features = tfgnn.keras.layers.MapFeatures(
context_fn=drop_all_features,
node_sets_fn=process_node_features,
edge_sets_fn=drop_all_features)
add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode("seed", node_set_name="paper")
move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
"seed", feature_name="label", new_feature_name="paper_venue", remove_input_feature=True)
feature_processors = [process_features, add_readout, move_label_to_readout]
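# The processors run in the input pipeline, in order: map raw features to model-ready
# ones, add a "seed" readout pointing at each subgraph's first (root) paper, then move
# that paper's label out of the graph into the readout feature "paper_venue", which the
# task reads as its classification label.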
# Run training
st.write("Training the model...")
runner.run(
task=task,
model_fn=model_fn,
trainer=trainer,
optimizer_fn=optimizer_fn,
epochs=epochs,
global_batch_size=global_batch_size,
train_ds_provider=train_ds_provider,
valid_ds_provider=valid_ds_provider,
gtspec=example_input_graph_spec,
)
st.write("Training completed successfully.")