"""Builds, loads, and compiles MoViNet-based video classifiers."""

import tensorflow as tf
from tensorflow import keras
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model

# Hyperparameters and paths used below (checkpoint_dir, batch_size, num_frames,
# resolution, num_classes, model_save_path, learning_rate) come from configurations.py.
from configurations import *

def load_backbone():
  """Creates a MoViNet backbone with the library's default architecture."""
  return movinet.Movinet()

def build_classifier():
  """Builds a classifier on top of a pretrained Kinetics-600 MoViNet checkpoint.

  The MovinetClassifier keeps its original 600-class head so the pretrained
  checkpoint restores cleanly; a final Dense layer then maps those logits to
  num_classes for the target task.
  """
  backbone = load_backbone()
  model = movinet_model.MovinetClassifier(
    backbone=backbone,
    num_classes=600)
  # Build the model first so its variables exist before the checkpoint restore.
  model.build([batch_size, num_frames, resolution, resolution, 3])
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  checkpoint = tf.train.Checkpoint(model=model)
  status = checkpoint.restore(checkpoint_path)
  status.assert_existing_objects_matched()
  output = keras.layers.Dense(num_classes)
  return keras.Sequential(layers=[model, output])

def load_classifier():
  """Rebuilds the classifier and loads fine-tuned weights from model_save_path.

  The architecture must match the one returned by build_classifier(); otherwise
  the saved weights cannot be restored.
  """
  backbone = load_backbone()
  model = movinet_model.MovinetClassifier(
    backbone=backbone,
    num_classes=600)
  model.build([batch_size, num_frames, resolution, resolution, 3])
  output = keras.layers.Dense(num_classes)
  model = keras.Sequential(layers=[model, output])
  model.load_weights(model_save_path)
  return model

def compile_classifier(model):
  """Compiles the classifier with sparse categorical cross-entropy on logits."""
  loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
  model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
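
# Example usage sketch (not part of the original module): assumes configurations.py
# provides the names imported above and that a tf.data pipeline of (video, label)
# batches is prepared elsewhere.
if __name__ == '__main__':
  model = build_classifier()
  compile_classifier(model)
  # train_ds = ...                      # dataset pipeline not shown here
  # model.fit(train_ds, epochs=...)     # fine-tune on the target classes
  # model.save_weights(model_save_path) # persist weights for load_classifier()
  # classifier = load_classifier()      # later: restore for inference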