File size: 1,465 Bytes
139dd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import tensorflow as tf
from tensorflow import keras
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model

# Variant identifier passed to `movinet.Movinet(model_id=...)` for every backbone
# created in this file.
model_id = 'a1'
# Size of the classifier head for the models returned by the builders below.
num_classes = 6
# Second entry of the input shape handed to `model.build` (presumably the number
# of frames per clip -- confirm against the data pipeline).
num_frames = 8
# Used for both the third and fourth entries of the build shape (presumably the
# square frame height/width in pixels -- confirm against the data pipeline).
resolution = 224

# Leading entry of the build shape in build_classifier_with_pretrained_weights.
batch_size = 32
# Learning rate given to the Adam optimizer in compile_classifier.
learning_rate = 0.001
# Value assigned to `backbone.trainable` in
# build_classifier_with_pretrained_weights.
backbone_trainable = True

def build_classifier_with_pretrained_weights(checkpoint_dir: str):
  """Build a MoViNet classifier on a backbone restored from a checkpoint.

  A temporary classifier with a 600-way head is wrapped around a fresh
  backbone and used as the restore target for the most recent checkpoint
  found under `checkpoint_dir`. The same backbone object is then placed under
  a new head with `num_classes` outputs, and that second classifier is built
  and returned.

  Args:
    checkpoint_dir: Directory passed to `tf.train.latest_checkpoint`.

  Returns:
    A `MovinetClassifier` with a `num_classes`-way head, built for inputs of
    shape `[batch_size, num_frames, resolution, resolution, 3]`.
  """
  shared_backbone = movinet.Movinet(model_id=model_id)
  shared_backbone.trainable = backbone_trainable

  # Restore target only: this classifier's head is not part of the returned
  # model, but the backbone it wraps is.
  # NOTE(review): 600 presumably matches the head size stored in the
  # checkpoint -- confirm against the checkpoint being loaded.
  restore_target = movinet_model.MovinetClassifier(
      backbone=shared_backbone, num_classes=600)

  latest_path = tf.train.latest_checkpoint(checkpoint_dir)
  restore_status = tf.train.Checkpoint(model=restore_target).restore(latest_path)
  # Fail here rather than continue with a restore that did not line up with
  # the objects already created in `restore_target`.
  restore_status.assert_existing_objects_matched()

  input_shape = [batch_size, num_frames, resolution, resolution, 3]
  classifier = movinet_model.MovinetClassifier(
      backbone=shared_backbone,
      num_classes=num_classes,
  )
  classifier.build(input_shape)
  return classifier

def load_classifier(weights_path: str):
  """Recreate the `num_classes`-way classifier and load saved weights into it.

  A new backbone and classifier are constructed from the module-level
  settings, the classifier is built with a leading dimension of 1, and
  `weights_path` is then loaded into it.

  Args:
    weights_path: Location passed unchanged to `model.load_weights`.

  Returns:
    The built `MovinetClassifier` after `load_weights` has been called.
  """
  classifier = movinet_model.MovinetClassifier(
      backbone=movinet.Movinet(model_id=model_id),
      num_classes=num_classes,
  )
  # Build before loading, with a leading (batch) dimension of 1.
  single_item_shape = [1, num_frames, resolution, resolution, 3]
  classifier.build(single_item_shape)
  classifier.load_weights(weights_path)
  return classifier

def compile_classifier(model):
  """Compile `model` in place and return it.

  The model is compiled with sparse categorical cross-entropy configured with
  `from_logits=True`, an Adam optimizer using the module-level
  `learning_rate`, and the 'accuracy' metric.

  Args:
    model: Object exposing a Keras-style `compile` method.

  Returns:
    The same `model` object that was passed in.
  """
  model.compile(
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
      metrics=['accuracy'],
  )
  return model