Spaces:
Sleeping
Sleeping
File size: 1,465 Bytes
139dd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import tensorflow as tf
from tensorflow import keras
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model
# ---------------------------------------------------------------------------
# Model / training configuration shared by the builder functions below.
# ---------------------------------------------------------------------------
# MoViNet variant, passed as `model_id` to movinet.Movinet.
model_id: str = "a1"
# Output size of the fine-tuned classifier head.
num_classes: int = 6
# Second axis of the input shape passed to model.build (presumably frames per clip).
num_frames: int = 8
# Height and width entries of the input shape passed to model.build.
resolution: int = 224
# Leading (batch) entry of the shape used to build the fine-tuning classifier.
batch_size: int = 32
# Step size handed to the Adam optimizer when compiling.
learning_rate: float = 0.001
# Assigned to backbone.trainable; False would freeze the backbone's weights.
backbone_trainable: bool = True
def build_classifier_with_pretrained_weights(checkpoint_dir: str):
    """Build a MoViNet classifier whose backbone is restored from a checkpoint.

    The pretrained checkpoint carries a 600-way classification head, so a
    temporary 600-class ``MovinetClassifier`` is built around a fresh backbone
    and the checkpoint is restored into it. The restored backbone is then
    wrapped in a new ``MovinetClassifier`` with ``num_classes`` outputs (a new
    head) and built for ``[batch_size, num_frames, resolution, resolution, 3]``
    inputs.

    Args:
        checkpoint_dir: Directory holding the pretrained checkpoint; the most
            recent one found by ``tf.train.latest_checkpoint`` is restored.

    Returns:
        The built, uncompiled fine-tuning classifier.

    Raises:
        FileNotFoundError: If no checkpoint is found in ``checkpoint_dir``.
        AssertionError: If existing model variables are not matched by the
            checkpoint (raised by ``assert_existing_objects_matched``).
    """
    backbone = movinet.Movinet(model_id=model_id)
    backbone.trainable = backbone_trainable

    # The checkpoint was saved with a 600-class head, so the restore target
    # must use the same head size for the variable shapes to line up.
    pretrained = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)
    # Build BEFORE restoring so all variables exist: the restore is applied
    # immediately and assert_existing_objects_matched() covers the backbone
    # and head weights. On an unbuilt model few or no variables exist yet,
    # so the check would verify little.
    pretrained.build([None, None, None, None, 3])

    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        # latest_checkpoint returns None instead of raising when nothing is found.
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
    checkpoint = tf.train.Checkpoint(model=pretrained)
    status = checkpoint.restore(checkpoint_path)
    status.assert_existing_objects_matched()

    # Reuse the restored backbone under a new head sized for this task.
    model = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=num_classes,
    )
    model.build([batch_size, num_frames, resolution, resolution, 3])
    return model
def load_classifier(weights_path: str):
    """Recreate the fine-tuned classifier and load saved weights into it.

    A ``MovinetClassifier`` with ``num_classes`` outputs is placed on a new
    ``model_id`` backbone. It is built for a batch of one,
    ``[1, num_frames, resolution, resolution, 3]``, and the weights at
    ``weights_path`` are then loaded.

    Args:
        weights_path: Location accepted by ``keras.Model.load_weights``.

    Returns:
        The built classifier with weights loaded (not compiled).
    """
    classifier = movinet_model.MovinetClassifier(
        backbone=movinet.Movinet(model_id=model_id),
        num_classes=num_classes,
    )
    # Build first so the variables exist when load_weights runs.
    single_clip_shape = [1, num_frames, resolution, resolution, 3]
    classifier.build(single_clip_shape)
    classifier.load_weights(weights_path)
    return classifier
def compile_classifier(model):
    """Compile ``model`` for integer-label classification and return it.

    Configures sparse categorical cross-entropy computed on logits
    (``from_logits=True``), Adam with step size ``learning_rate``, and
    accuracy as the reported metric. The model is compiled in place; the
    same object is returned for chaining.
    """
    # Keyword arguments are evaluated left to right, so the loss object is
    # still created before the optimizer, as in the original ordering.
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=['accuracy'],
    )
    return model
|