Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from tensorflow import keras | |
from official.projects.movinet.modeling import movinet | |
from official.projects.movinet.modeling import movinet_model | |
# Hyperparameters read by the model-building and compilation helpers in this module.

# --- architecture -----------------------------------------------------------
model_id: str = 'a1'  # variant name passed to movinet.Movinet(model_id=...)
num_classes: int = 6  # output size of the fine-tuned classification head

# --- input shape: [batch, num_frames, resolution, resolution, 3] ------------
num_frames: int = 8  # axis 1 of the shape given to model.build
resolution: int = 224  # used for both axis 2 and axis 3 of that shape
batch_size: int = 32  # batch axis when building the fine-tuning model

# --- training ---------------------------------------------------------------
learning_rate: float = 1e-3  # Adam learning rate used by compile_classifier
backbone_trainable: bool = True  # assigned to backbone.trainable before fine-tuning
def build_classifier_with_pretrained_weights(
    checkpoint_dir: str,
    *,
    pretrained_num_classes: int = 600,
):
    """Build a ``num_classes``-way MoViNet classifier on top of a pretrained backbone.

    First, a classifier with ``pretrained_num_classes`` outputs is built around
    a fresh ``model_id`` backbone, so that its variables line up with the
    checkpoint. The latest checkpoint in ``checkpoint_dir`` is then restored
    into it. Finally, the restored backbone object is reused under a new,
    freshly initialized ``num_classes``-way head. That model is built for
    input shape ``[batch_size, num_frames, resolution, resolution, 3]``.

    Args:
        checkpoint_dir: Directory searched with ``tf.train.latest_checkpoint``.
        pretrained_num_classes: Number of classes of the head the checkpoint
            was saved with. Defaults to 600, the value this function has
            always used. This is presumably the head of the released MoViNet
            checkpoints; confirm the value before using other checkpoints.

    Returns:
        The new classifier. Its backbone carries the restored weights and has
        ``trainable`` set to ``backbone_trainable``.

    Raises:
        AssertionError: Raised by ``assert_existing_objects_matched()`` when
            the checkpoint does not match the temporary model's variables.
            Per the TF docs, this also happens when ``checkpoint_dir`` holds
            no checkpoint (``latest_checkpoint`` returns ``None``).
    """
    backbone = movinet.Movinet(model_id=model_id)
    backbone.trainable = backbone_trainable

    # Temporary model whose head size matches the checkpoint; it exists only
    # so the checkpoint can be restored variable-for-variable.
    pretrained = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=pretrained_num_classes,
    )
    # Make sure every variable exists before restoring, so the assertion
    # below verifies them instead of depending on deferred restoration.
    # This mirrors the official MoViNet transfer-learning tutorial.
    pretrained.build([None, None, None, None, 3])

    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    checkpoint = tf.train.Checkpoint(model=pretrained)
    status = checkpoint.restore(checkpoint_path)
    status.assert_existing_objects_matched()

    # The same backbone object (now holding restored weights) is shared with
    # the target model; only the classification head is new.
    model = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=num_classes,
    )
    model.build([batch_size, num_frames, resolution, resolution, 3])
    return model
def load_classifier(weights_path: str):
    """Recreate the fine-tuned MoViNet classifier and load saved weights into it.

    The classifier uses a fresh ``model_id`` backbone and a ``num_classes``-way
    head. It is built for a batch of one clip, i.e. input shape
    ``[1, num_frames, resolution, resolution, 3]``, before the weights are
    loaded.

    Assumes ``weights_path`` was saved from a classifier with the same
    architecture. Confirm this against the code that saves the weights.

    Args:
        weights_path: Path handed to ``model.load_weights``.

    Returns:
        The classifier with the loaded weights.
    """
    classifier = movinet_model.MovinetClassifier(
        backbone=movinet.Movinet(model_id=model_id),
        num_classes=num_classes,
    )
    single_clip_shape = [1, num_frames, resolution, resolution, 3]
    classifier.build(single_clip_shape)
    classifier.load_weights(weights_path)
    return classifier
def compile_classifier(model):
    """Compile ``model`` for training and return it.

    Configuration:
      * loss: sparse categorical cross-entropy with ``from_logits=True``. This
        assumes the model emits unnormalized logits and the labels are
        integer class ids.
      * optimizer: Adam with the module-level ``learning_rate``.
      * metrics: ``'accuracy'``.

    Args:
        model: A Keras model, e.g. one returned by the builders in this module.

    Returns:
        The same ``model`` object, now compiled, so calls can be chained.
    """
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=['accuracy'],
    )
    return model