# aero-recognize/core/model.py -- author: chiyoi, commit 139dd3e ("update"), 1.47 kB.
import tensorflow as tf
from tensorflow import keras
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model
# Module-level configuration read by the model-building helpers in this file.
model_id: str = 'a1'  # variant passed to movinet.Movinet(model_id=...)
num_classes: int = 6  # output size of the MovinetClassifier head
num_frames: int = 8  # axis 1 of the shape passed to model.build(...)
resolution: int = 224  # used for both axes 2 and 3 of the build shape
batch_size: int = 32  # axis 0 of the build shape for the pretrained-backbone classifier
learning_rate: float = 1e-3  # Adam learning rate used by compile_classifier
backbone_trainable: bool = True  # assigned to backbone.trainable before restoring
def build_classifier_with_pretrained_weights(checkpoint_dir: str):
    """Build a ``num_classes``-way MoViNet classifier on a checkpoint-restored backbone.

    Steps:
    1. Create a ``movinet.Movinet`` backbone for ``model_id`` and set its
       trainability from ``backbone_trainable``.
    2. Wrap the backbone in a 600-class ``MovinetClassifier``.
    3. Restore the latest checkpoint in ``checkpoint_dir`` through that wrapper.
    4. Wrap the same backbone object in a new classifier with ``num_classes``
       outputs.
    5. Build it with input shape
       ``[batch_size, num_frames, resolution, resolution, 3]``.

    The returned model's head comes from the new classifier, not from the
    checkpointed 600-class wrapper.

    Args:
        checkpoint_dir: Directory searched by ``tf.train.latest_checkpoint``.

    Returns:
        The newly built ``num_classes``-way classifier.

    Raises:
        AssertionError: Propagated from ``assert_existing_objects_matched``
            when the restore does not match the existing objects.
    """
    shared_backbone = movinet.Movinet(model_id=model_id)
    shared_backbone.trainable = backbone_trainable

    # Temporary wrapper whose object graph is used to restore the checkpoint.
    # 600 presumably matches the head the checkpoint was saved with
    # (e.g. Kinetics-600) -- TODO confirm.
    pretrained_wrapper = movinet_model.MovinetClassifier(
        backbone=shared_backbone,
        num_classes=600,
    )
    # NOTE(review): latest_checkpoint returns None when no checkpoint exists;
    # the assertion below is then expected to fail -- confirm.
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    ckpt = tf.train.Checkpoint(model=pretrained_wrapper)
    restore_status = ckpt.restore(latest_ckpt)
    restore_status.assert_existing_objects_matched()

    # Reuse the (restored) backbone under a fresh head sized for this task.
    classifier = movinet_model.MovinetClassifier(
        backbone=shared_backbone,
        num_classes=num_classes,
    )
    classifier.build([batch_size, num_frames, resolution, resolution, 3])
    return classifier
def load_classifier(weights_path: str):
    """Recreate the MoViNet classifier architecture and load saved weights into it.

    The model uses a ``movinet.Movinet`` backbone for ``model_id`` and a
    ``num_classes``-way ``MovinetClassifier`` head. It is built with a
    batch dimension of 1 before the weights are loaded.

    Args:
        weights_path: Path handed directly to ``model.load_weights``.

    Returns:
        The classifier with weights loaded from ``weights_path``.
    """
    classifier = movinet_model.MovinetClassifier(
        backbone=movinet.Movinet(model_id=model_id),
        num_classes=num_classes,
    )
    single_clip_shape = [1, num_frames, resolution, resolution, 3]
    classifier.build(single_clip_shape)
    classifier.load_weights(weights_path)
    return classifier
def compile_classifier(model):
    """Compile ``model`` with Adam and sparse categorical cross-entropy.

    The loss is created with ``from_logits=True``, so model outputs are
    treated as unnormalized logits. ``'accuracy'`` is tracked as a metric,
    and the optimizer uses the module-level ``learning_rate``.

    Args:
        model: A Keras model exposing ``compile``.

    Returns:
        The same ``model`` object, compiled in place.
    """
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=['accuracy'],
    )
    return model