chiyoi commited on
Commit
c331f57
1 Parent(s): 804b63c
movinet/model.py DELETED
@@ -1,9 +0,0 @@
1
- from site_packages.models.official.projects.movinet.modeling import movinet_model
2
-
3
- def build_classifier(batch_size: int, num_frames: int, resolution: int, backbone, num_classes: int):
4
- model = movinet_model.MovinetClassifier(
5
- backbone=backbone,
6
- num_classes=num_classes,
7
- )
8
- model.build([batch_size, num_frames, resolution, resolution, 3])
9
- return model
 
 
 
 
 
 
 
 
 
 
movinet/scripts/train.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- from pathlib import Path
3
-
4
- import tensorflow as tf
5
- import tf_keras as keras
6
-
7
- from site_packages.models.official.projects.movinet.modeling import movinet
8
-
9
- from movinet.data import frame_generator, total_steps
10
- from movinet.model import build_classifier
11
-
12
- model_id = 'a0'
13
- resolution = 256
14
- batch_size = 8
15
- num_frames = 8
16
- num_classes = 6
17
- model_save_path = "out/aero-recognize-classifier.keras"
18
- num_epochs = 2
19
-
20
- print('Load data.')
21
- data_dir = Path('assets/datasets/Aero')
22
- output_signature = (
23
- tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),
24
- tf.TensorSpec(shape=(), dtype=tf.int16),
25
- )
26
- training_data = tf.data.Dataset.from_generator(
27
- frame_generator(data_dir, num_frames, 'training'),
28
- output_signature=output_signature,
29
- )
30
- training_data = training_data.batch(batch_size)
31
- validation_data = tf.data.Dataset.from_generator(
32
- frame_generator(data_dir, num_frames, 'validation'),
33
- output_signature=output_signature,
34
- )
35
- validation_data = validation_data.batch(batch_size)
36
-
37
- print('Build model.')
38
- backbone = movinet.Movinet(model_id=model_id)
39
- backbone.trainable = True
40
- model = build_classifier(batch_size, num_frames, resolution, backbone, 6)
41
-
42
- print('Start training.')
43
- model_dir = os.path.dirname(model_save_path)
44
- save_model = keras.callbacks.ModelCheckpoint(filepath=model_save_path)
45
- loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
46
- # optimizer = keras.optimizers.legacy.Adam(learning_rate=0.001)
47
- model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
48
- train_steps, validation_steps = total_steps(data_dir)
49
- results = model.fit(
50
- training_data,
51
- steps_per_epoch=train_steps,
52
- validation_data=validation_data,
53
- validation_steps=validation_steps,
54
- epochs=num_epochs,
55
- validation_freq=1,
56
- verbose=1,
57
- callbacks=[save_model],
58
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
site_packages/models DELETED
@@ -1 +0,0 @@
1
- Subproject commit d14cf43b09cc29d68900bb9f766de19b01acde40