"""Build, restore, and compile MoViNet-based video classifiers.

Hyperparameters used below (``checkpoint_dir``, ``batch_size``,
``num_frames``, ``resolution``, ``num_classes``, ``model_save_path``,
``learning_rate``) are expected to come from the ``configurations`` module
via its star import.
"""
import tensorflow as tf
from tensorflow import keras

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model

from configurations import *

# Size of the classification head stored in the pretrained checkpoint. The
# inner MovinetClassifier must be created with exactly this size so the
# checkpoint restores cleanly; the task-specific ``num_classes`` projection
# is stacked on top as a separate Dense layer.
# NOTE(review): the published MoViNet checkpoints use a 600-way
# (Kinetics-600) head -- confirm this matches the checkpoint in
# ``checkpoint_dir``.
_PRETRAINED_NUM_CLASSES = 600


def load_backbone():
    """Return a new MoViNet backbone constructed with default arguments."""
    return movinet.Movinet()


def _build_base_classifier():
    """Create and build a MovinetClassifier shaped like the pretrained model.

    Both ``build_classifier`` and ``load_classifier`` go through this helper,
    so the models they produce share one architecture and weights saved from
    one can be loaded into the other.
    """
    model = movinet_model.MovinetClassifier(
        backbone=load_backbone(), num_classes=_PRETRAINED_NUM_CLASSES)
    # Trailing dimension 3 = channels; assumes (batch, frames, H, W, C)
    # video input -- TODO confirm against the data pipeline.
    model.build([batch_size, num_frames, resolution, resolution, 3])
    return model


def _add_task_head(base):
    """Return ``base`` followed by a ``num_classes``-way Dense logits layer."""
    # No activation: the layer emits raw logits, matching
    # ``from_logits=True`` in ``compile_classifier``.
    return keras.Sequential(layers=[base, keras.layers.Dense(num_classes)])


def build_classifier():
    """Build a classifier initialized from the latest pretrained checkpoint.

    Restores the newest checkpoint in ``checkpoint_dir`` into a
    MovinetClassifier, then stacks a fresh ``num_classes``-way Dense layer on
    top for the target task.

    Returns:
        A ``keras.Sequential`` of the restored classifier plus the new head.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` contains no checkpoint.
        AssertionError: if the checkpoint does not match the model's
            variables (from ``assert_existing_objects_matched``).
    """
    base = _build_base_classifier()

    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    # latest_checkpoint returns None when nothing is found; restoring None
    # silently restores nothing, so fail loudly with a clear message instead.
    if checkpoint_path is None:
        raise FileNotFoundError(
            f"No checkpoint found in checkpoint_dir={checkpoint_dir!r}")

    status = tf.train.Checkpoint(model=base).restore(checkpoint_path)
    status.assert_existing_objects_matched()

    return _add_task_head(base)


def load_classifier():
    """Rebuild the fine-tuned classifier and load its weights from disk.

    The architecture is identical to ``build_classifier``'s (same inner
    classifier size, same Dense head), so weights saved from that model
    load without shape mismatches. ``output_states`` is left at its default:
    a classifier that also returns its states yields a (logits, states)
    tuple, which cannot feed the Dense layer inside a Sequential.

    Returns:
        A ``keras.Sequential`` with weights loaded from ``model_save_path``.
    """
    model = _add_task_head(_build_base_classifier())
    model.load_weights(model_save_path)
    return model


def compile_classifier(model):
    """Compile ``model`` in place with Adam and sparse categorical loss.

    Args:
        model: a Keras model whose final layer emits raw logits; labels are
            expected as integer class indices (sparse).
    """
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])