"""Build, load and compile a MoViNet video classifier.

A pretrained MoViNet backbone (TF Model Garden ``official.projects.movinet``)
is reused under a new classification head with ``num_classes`` outputs.
"""

import tensorflow as tf
from tensorflow import keras

from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model

# MoViNet variant passed to ``movinet.Movinet``; must match the checkpoint.
model_id = 'a1'
# Number of output classes of the fine-tuned classification head.
num_classes = 6
# Input clip geometry: models are built for
# [batch, num_frames, resolution, resolution, 3] inputs.
num_frames = 8
resolution = 224
batch_size = 32
learning_rate = 0.001
# Whether the backbone's weights are updated during fine-tuning.
backbone_trainable = True

# Head size of the pretrained checkpoint. A classifier with this head is needed
# only so the checkpoint's variables line up during restore; it is discarded.
# NOTE(review): presumably the Kinetics-600 label set — confirm.
_PRETRAINED_NUM_CLASSES = 600


def build_classifier_with_pretrained_weights(checkpoint_dir: str) -> keras.Model:
    """Build a ``num_classes``-way classifier on a pretrained MoViNet backbone.

    The latest checkpoint in ``checkpoint_dir`` is restored into a temporary
    classifier that has the checkpoint's original head size. The same,
    now-initialised backbone instance is then wrapped in a fresh classifier
    with a ``num_classes`` head.

    Args:
        checkpoint_dir: Directory containing the pretrained MoViNet checkpoint.

    Returns:
        A built ``MovinetClassifier`` for inputs of shape
        ``[batch_size, num_frames, resolution, resolution, 3]``.

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` contains no checkpoint.
        AssertionError: If the checkpoint does not match the model's variables.
    """
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        raise FileNotFoundError(f'No checkpoint found in {checkpoint_dir!r}')

    backbone = movinet.Movinet(model_id=model_id)
    backbone.trainable = backbone_trainable

    pretrained = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=_PRETRAINED_NUM_CLASSES,
    )
    # Create every variable before restoring. On an unbuilt model the restore is
    # deferred and assert_existing_objects_matched() has (almost) nothing to
    # check, so a bad checkpoint would go unnoticed.
    pretrained.build([None, None, None, None, 3])
    checkpoint = tf.train.Checkpoint(model=pretrained)
    status = checkpoint.restore(checkpoint_path)
    status.assert_existing_objects_matched()

    # Reuse the restored backbone under a freshly initialised head.
    model = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=num_classes,
    )
    model.build([batch_size, num_frames, resolution, resolution, 3])
    return model


def load_classifier(weights_path: str) -> keras.Model:
    """Rebuild the ``num_classes``-way classifier and load saved weights.

    Args:
        weights_path: Path accepted by ``keras.Model.load_weights``.

    Returns:
        A ``MovinetClassifier`` built for a batch of one clip of shape
        ``[num_frames, resolution, resolution, 3]`` with the weights loaded.
    """
    backbone = movinet.Movinet(model_id=model_id)
    model = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=num_classes,
    )
    # Variables must exist before load_weights can assign them.
    model.build([1, num_frames, resolution, resolution, 3])
    model.load_weights(weights_path)
    return model


def compile_classifier(model):
    """Compile ``model`` for training with Adam and sparse cross-entropy.

    The loss uses ``from_logits=True``, so the model's outputs are treated as
    unnormalised logits and labels as integer class ids.

    Args:
        model: The Keras model to compile (compiled in place).

    Returns:
        The same ``model`` object, for call chaining.
    """
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
    return model