# aero-recognize / core/model.py
# Author: chiyoi
# Commit 160ded7: "Refactor code structure and import configurations"
# (Hugging Face file-view residue: raw | history | blame | No virus | 1.37 kB)
import tensorflow as tf
from tensorflow import keras
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model
from configurations import *
def load_backbone():
    """Construct a MoViNet backbone using the library's default arguments.

    No weights are loaded here. The callers in this module
    (``build_classifier`` / ``load_classifier``) restore or load the weights
    after wrapping the backbone in a classifier.

    Returns:
        movinet.Movinet: a freshly constructed backbone instance.
    """
    backbone = movinet.Movinet()
    return backbone
def build_classifier():
    """Build a trainable classifier on top of a pretrained MoViNet checkpoint.

    Steps:
      1. Wrap a fresh backbone in a ``MovinetClassifier`` with a 600-way head.
      2. Restore the latest checkpoint found in ``checkpoint_dir`` into it and
         assert that the existing objects matched the checkpoint.
      3. Build it for inputs of shape
         ``[batch_size, num_frames, resolution, resolution, 3]``.
      4. Stack a ``Dense(num_classes)`` layer on top of its 600 outputs.

    Returns:
        keras.Sequential: ``[MovinetClassifier(num_classes=600), Dense(num_classes)]``.
        The final layer emits raw (un-activated) values.
    """
    # NOTE(review): the 600-way head presumably matches the pretrained
    # checkpoint's label set (e.g. Kinetics-600) — TODO confirm.
    pretrained = movinet_model.MovinetClassifier(
        backbone=load_backbone(),
        num_classes=600,
    )

    latest = tf.train.latest_checkpoint(checkpoint_dir)
    restore_status = tf.train.Checkpoint(model=pretrained).restore(latest)
    restore_status.assert_existing_objects_matched()

    video_shape = [batch_size, num_frames, resolution, resolution, 3]
    pretrained.build(video_shape)

    head = keras.layers.Dense(num_classes)
    return keras.Sequential(layers=[pretrained, head])
def load_classifier():
    """Rebuild the fine-tuned classifier and load its trained weights.

    The layer stack must be identical to the one ``build_classifier``
    produces: ``Sequential([MovinetClassifier(num_classes=600),
    Dense(num_classes)])``. Otherwise ``load_weights`` cannot map the saved
    variables onto this model.

    Returns:
        keras.Sequential: the classifier with weights loaded from
        ``model_save_path``. The final layer emits raw (un-activated) values.
    """
    backbone = load_backbone()
    # The inner head is 600-way, exactly as in build_classifier, so the saved
    # variable shapes line up. A num_classes-way inner head would only match
    # when num_classes == 600.
    # output_states is left at its default (False). With output_states=True,
    # MovinetClassifier returns a (logits, states) pair, which cannot be fed to
    # the single-tensor Dense layer that follows inside a Sequential.
    model = movinet_model.MovinetClassifier(
        backbone=backbone,
        num_classes=600)
    model.build([batch_size, num_frames, resolution, resolution, 3])
    output = keras.layers.Dense(num_classes)
    model = keras.Sequential(layers=[model, output])
    model.load_weights(model_save_path)
    return model
def compile_classifier(model):
    """Compile *model* in place for multi-class training with integer labels.

    Configuration:
      - loss: sparse categorical cross-entropy on raw logits (``from_logits=True``);
      - optimizer: Adam at the configured ``learning_rate``;
      - metrics: accuracy.

    Args:
        model: a Keras model whose outputs are unnormalised logits.

    Returns:
        None. ``model.compile`` configures the model in place.
    """
    objective = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=objective,
        metrics=['accuracy'],
    )