aero-recognize / core /model.py
chiyoi's picture
fix
d78535a
raw
history blame
No virus
1.36 kB
import tensorflow as tf
from tensorflow import keras
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model
from configurations import *
def load_backbone():
    """Construct a MoViNet backbone using the library's default arguments.

    No weights are loaded here; callers restore or load weights afterwards.

    Returns:
        A new ``movinet.Movinet`` instance.
    """
    backbone = movinet.Movinet()
    return backbone
# Width of the MoViNet classification head that the checkpoint in
# ``checkpoint_dir`` was trained with (600 outputs -- presumably Kinetics-600;
# confirm). The pre-trained weights only fit a head of exactly this size, which
# is why the task-specific ``num_classes`` is handled by a separate Dense layer.
_PRETRAINED_NUM_CLASSES = 600


def _build_movinet_base():
    """Create the 600-way MoViNet classifier trunk and create its variables.

    ``build`` is called with the configured
    ``[batch_size, num_frames, resolution, resolution, 3]`` input shape so the
    variables exist before any checkpoint restore or weight load.
    """
    base = movinet_model.MovinetClassifier(
        backbone=load_backbone(),
        num_classes=_PRETRAINED_NUM_CLASSES)
    base.build([batch_size, num_frames, resolution, resolution, 3])
    return base


def _attach_head(base):
    """Stack a ``num_classes``-wide Dense layer (no activation) on *base*."""
    return keras.Sequential(layers=[base, keras.layers.Dense(num_classes)])


def build_classifier():
    """Build a classifier initialised from the pre-trained MoViNet checkpoint.

    Restores the latest checkpoint found in ``checkpoint_dir`` into a 600-way
    MoViNet classifier, then stacks a fresh ``num_classes``-wide Dense layer
    on its outputs.

    Returns:
        A ``keras.Sequential`` of ``[MovinetClassifier, Dense(num_classes)]``.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` contains no checkpoint.
        AssertionError: if existing model objects are not matched by the
            checkpoint.
    """
    # Build before restoring: assert_existing_objects_matched() does not check
    # layers that have not created their variables yet, so restoring into an
    # unbuilt model would let a mismatched checkpoint pass unnoticed.
    base = _build_movinet_base()

    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        # latest_checkpoint returns None instead of raising; without this the
        # failure surfaces later as a generic "save_path=None" assertion.
        raise FileNotFoundError(
            f'No checkpoint found in checkpoint_dir={checkpoint_dir!r}')

    status = tf.train.Checkpoint(model=base).restore(checkpoint_path)
    status.assert_existing_objects_matched()
    return _attach_head(base)


def load_classifier(classifier_path):
    """Recreate the classifier architecture and load saved weights into it.

    Builds the same ``[MovinetClassifier, Dense(num_classes)]`` stack as
    ``build_classifier`` (without restoring the pre-trained checkpoint) so that
    weights saved from such a model -- presumably one produced by
    ``build_classifier`` and then trained -- line up with these layers.

    Args:
        classifier_path: Path handed to ``keras.Model.load_weights``.

    Returns:
        The ``keras.Sequential`` model with weights loaded from
        ``classifier_path``.
    """
    classifier = _attach_head(_build_movinet_base())
    classifier.load_weights(classifier_path)
    return classifier
def compile_classifier(model):
    """Compile *model* in place for training on integer class labels.

    The loss is sparse categorical cross-entropy with ``from_logits=True``, so
    *model* is expected to output unnormalised scores (no softmax). Adam is used
    with the configured ``learning_rate``, and accuracy is tracked as a metric.

    Args:
        model: The Keras model to compile.
    """
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=['accuracy'],
    )