#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train the deeplabv3p model for your own dataset.
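
Example usage (illustrative; adjust the script name, paths and options to your setup):
    python train.py --model_type=mobilenetv2_lite \
        --dataset_path=VOC2012/ \
        --dataset_file=VOC2012/ImageSets/Segmentation/trainval.txt \
        --classes_path=configs/voc_classes.txt \
        --model_input_shape=512x512 --batch_size=16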
"""
import os, argparse, time
import warnings
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TerminateOnNaN

from deeplabv3p.model import get_deeplabv3p_model
from unet.model import get_unet_model
from fast_scnn.model import get_fast_scnn_model
from deeplabv3p.data import SegmentationGenerator
from deeplabv3p.loss import sparse_crossentropy, softmax_focal_loss, WeightedSparseCategoricalCrossEntropy
from deeplabv3p.metrics import Jaccard
from common.utils import get_classes, get_data_list, optimize_tf_gpu, calculate_weigths_labels, load_class_weights
from common.model_utils import get_optimizer
from common.callbacks import EvalCallBack

# Try to enable Auto Mixed Precision on TF 2.0
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
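# NOTE: these must be set before the first 'import tensorflow' below to take effect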

import tensorflow as tf
optimize_tf_gpu(tf, K)


def main(args):
    log_dir = 'logs/000/'
    # get class info, add background class to match model & GT
    class_names = get_classes(args.classes_path)
    assert len(class_names) < 254, 'PNG-encoded label images support fewer than 254 classes.'
    class_names = ['background'] + class_names
    num_classes = len(class_names)

    # callbacks for training process
    monitor = 'Jaccard'
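    # Jaccard index (mean IoU) is the segmentation quality signal here: the
    # checkpoint, ReduceLROnPlateau and EarlyStopping callbacks below all monitor
    # its validation value, hence mode='max'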

    tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=False, write_images=False, update_freq='batch')
    checkpoint = ModelCheckpoint(os.path.join(log_dir, 'ep{epoch:03d}-loss{loss:.3f}-Jaccard{Jaccard:.3f}-val_loss{val_loss:.3f}-val_Jaccard{val_Jaccard:.3f}.h5'),
        monitor='val_{}'.format(monitor),
        mode='max',
        verbose=1,
        save_weights_only=False,
        save_best_only=True,
        period=1)

    reduce_lr = ReduceLROnPlateau(monitor='val_{}'.format(monitor), factor=0.5, mode='max',
                patience=5, verbose=1, cooldown=0, min_lr=1e-6)
    early_stopping = EarlyStopping(monitor='val_{}'.format(monitor), min_delta=0, patience=100, verbose=1, mode='max')
    terminate_on_nan = TerminateOnNaN()

    callbacks=[tensorboard, checkpoint, reduce_lr, early_stopping, terminate_on_nan]


    # get train&val dataset
    dataset = get_data_list(args.dataset_file)
    if args.val_dataset_file:
        val_dataset = get_data_list(args.val_dataset_file)
        num_train = len(dataset)
        num_val = len(val_dataset)
        dataset.extend(val_dataset)
    else:
        val_split = args.val_split
        num_val = int(len(dataset)*val_split)
        num_train = len(dataset) - num_val

    # prepare train&val data generator
    train_generator = SegmentationGenerator(args.dataset_path, dataset[:num_train],
                                            args.batch_size,
                                            num_classes,
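                                            # model_input_shape is (height, width); target_size is
                                            # presumably expected as (width, height), hence [::-1]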
                                            target_size=args.model_input_shape[::-1],
                                            weighted_type=args.weighted_type,
                                            is_eval=False,
                                            augment=True)

    valid_generator = SegmentationGenerator(args.dataset_path, dataset[num_train:],
                                            args.batch_size,
                                            num_classes,
                                            target_size=args.model_input_shape[::-1],
                                            weighted_type=args.weighted_type,
                                            is_eval=False,
                                            augment=False)

    # prepare online evaluation callback
    if args.eval_online:
        eval_callback = EvalCallBack(args.dataset_path, dataset[num_train:], class_names, args.model_input_shape, args.model_pruning, log_dir, eval_epoch_interval=args.eval_epoch_interval, save_eval_checkpoint=args.save_eval_checkpoint)
        callbacks.append(eval_callback)

    # prepare optimizer
    optimizer = get_optimizer(args.optimizer, args.learning_rate, average_type=None, decay_type=None)

    # prepare loss according to loss type & weighted type
    if args.weighted_type == 'balanced':
        classes_weights_path = os.path.join(args.dataset_path, 'classes_weights.txt')
        if os.path.isfile(classes_weights_path):
            weights = load_class_weights(classes_weights_path)
        else:
            weights = calculate_weigths_labels(train_generator, num_classes, save_path=args.dataset_path)
        losses = WeightedSparseCategoricalCrossEntropy(weights)
        sample_weight_mode = None
    elif args.weighted_type == 'adaptive':
        losses = sparse_crossentropy
        sample_weight_mode = 'temporal'
    elif args.weighted_type is None:
        losses = sparse_crossentropy
        sample_weight_mode = None
    else:
        raise ValueError('invalid weighted_type {}'.format(args.weighted_type))
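    # Summary of the weighting modes:
    #   'balanced' -> per-class weights folded into the loss itself
    #   'adaptive' -> per-pixel sample weights from the generator ('temporal' mode)
    #   None       -> plain sparse crossentropy, no weighting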

    if args.loss == 'focal':
        warnings.warn("Focal loss doesn't support class balance weighting, so the related config will be ignored")
        losses = softmax_focal_loss
        sample_weight_mode = None
    elif args.loss == 'crossentropy':
        # using crossentropy keeps the weighted type setting
        pass
    else:
        raise ValueError('invalid loss type {}'.format(args.loss))

    # prepare metric
    metrics = {'pred_mask' : Jaccard}
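    # 'pred_mask' is presumably the name of the models' segmentation output layer,
    # so the Jaccard metric attaches only to that output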

    # support multi-gpu training
    if args.gpu_num >= 2:
        # devices_list=["/gpu:0", "/gpu:1"]
        devices_list=["/gpu:{}".format(n) for n in range(args.gpu_num)]
        strategy = tf.distribute.MirroredStrategy(devices=devices_list)
        print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
        with strategy.scope():
            # get multi-gpu train model
            if args.model_type.startswith('unet_'):
                model = get_unet_model(args.model_type, num_classes, args.model_input_shape, args.freeze_level, weights_path=args.weights_path)
            elif args.model_type.startswith('fast_scnn'):
                model = get_fast_scnn_model(args.model_type, num_classes, args.model_input_shape, weights_path=args.weights_path)
            else:
                model = get_deeplabv3p_model(args.model_type, num_classes, args.model_input_shape, args.output_stride, args.freeze_level, weights_path=args.weights_path)
            # compile model
            model.compile(optimizer=optimizer, sample_weight_mode=sample_weight_mode,
                          loss = losses, metrics = metrics)
    else:
        # get normal train model
        if args.model_type.startswith('unet_'):
            model = get_unet_model(args.model_type, num_classes, args.model_input_shape, args.freeze_level, weights_path=args.weights_path)
        elif args.model_type.startswith('fast_scnn'):
            model = get_fast_scnn_model(args.model_type, num_classes, args.model_input_shape, weights_path=args.weights_path)
        else:
            model = get_deeplabv3p_model(args.model_type, num_classes, args.model_input_shape, args.output_stride, args.freeze_level, weights_path=args.weights_path)
        # compile model
        model.compile(optimizer=optimizer, sample_weight_mode=sample_weight_mode,
                      loss = losses, metrics = metrics)
    model.summary()

    # Train some epochs with frozen layers first (transfer stage), to get a stable loss.
    initial_epoch = args.init_epoch
    epochs = initial_epoch + args.transfer_epoch
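    # Stage 1 (transfer) runs epochs [init_epoch, init_epoch + transfer_epoch) with
    # layers frozen per --freeze_level; the unfreeze stage below continues from there
    # up to --total_epoch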
    print("Transfer training stage")
    print('Train on {} samples, val on {} samples, with batch size {}, input_shape {}.'.format(num_train, num_val, args.batch_size, args.model_input_shape))
    model.fit_generator(generator=train_generator,
                        steps_per_epoch=len(train_generator),
                        validation_data=valid_generator,
                        validation_steps=len(valid_generator),
                        epochs=epochs,
                        initial_epoch=initial_epoch,
                        verbose=1,
                        workers=1,
                        use_multiprocessing=False,
                        max_queue_size=10,
                        callbacks = callbacks)

    # Wait 2 seconds for next stage
    time.sleep(2)

    if args.decay_type or args.average_type:
        # rebuild the optimizer to apply learning rate decay or a weights averager,
        # but only after all layers have been unfrozen
        if args.decay_type:
            callbacks.remove(reduce_lr)

        if args.average_type == 'ema' or args.average_type == 'swa':
            # the weights averager needs tensorflow-addons, which requires TF 2.x
            # and has version-compatibility constraints
            import tensorflow_addons as tfa
            callbacks.remove(checkpoint)
            avg_checkpoint = tfa.callbacks.AverageModelCheckpoint(filepath=os.path.join(log_dir, 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'),
                                                                  update_weights=True,
                                                                  monitor='val_loss',
                                                                  mode='min',
                                                                  verbose=1,
                                                                  save_weights_only=False,
                                                                  save_best_only=True,
                                                                  period=1)
            callbacks.append(avg_checkpoint)

        steps_per_epoch = max(1, len(train_generator))
        decay_steps = steps_per_epoch * (args.total_epoch - args.init_epoch - args.transfer_epoch)
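        # e.g. with the defaults (total_epoch=150, init_epoch=0, transfer_epoch=5),
        # the decay schedule spans the remaining 145 fine-tune epochs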
        optimizer = get_optimizer(args.optimizer, args.learning_rate, average_type=args.average_type, decay_type=args.decay_type, decay_steps=decay_steps)

    # Unfreeze the whole network for further tuning
    # NOTE: more GPU memory is required after unfreezing the body
    print("Unfreeze and continue training, to fine-tune.")
    if args.gpu_num >= 2:
        with strategy.scope():
            for i in range(len(model.layers)):
                model.layers[i].trainable = True
            model.compile(optimizer=optimizer, sample_weight_mode=sample_weight_mode,
                          loss = losses, metrics = metrics) # recompile to apply the change

    else:
        for i in range(len(model.layers)):
            model.layers[i].trainable = True
        model.compile(optimizer=optimizer, sample_weight_mode=sample_weight_mode,
                      loss = losses, metrics = metrics) # recompile to apply the change

    print('Train on {} samples, val on {} samples, with batch size {}, input_shape {}.'.format(num_train, num_val, args.batch_size, args.model_input_shape))
    model.fit_generator(generator=train_generator,
                        steps_per_epoch=len(train_generator),
                        validation_data=valid_generator,
                        validation_steps=len(valid_generator),
                        epochs=args.total_epoch,
                        initial_epoch=epochs,
                        verbose=1,
                        workers=1,
                        use_multiprocessing=False,
                        max_queue_size=10,
                        callbacks = callbacks)

    # Finally store model
    model.save(os.path.join(log_dir, 'trained_final.h5'))
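    # NOTE: reloading the saved .h5 later needs the custom objects used at compile
    # time, e.g. load_model(..., custom_objects={'Jaccard': Jaccard, ...})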



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Model definition options
    parser.add_argument('--model_type', type=str, required=False, default='mobilenetv2_lite',
        help='DeepLabv3+ model type (mobilenetv2/mobilenetv2_lite/resnet50); unet_* and fast_scnn* model names are also supported, default=%(default)s')
    parser.add_argument('--weights_path', type=str, required=False, default=None,
        help = "Pretrained model/weights file for fine tune")
    parser.add_argument('--model_input_shape', type=str, required=False, default='512x512',
        help = "model image input shape as <height>x<width>, default=%(default)s")
    parser.add_argument('--output_stride', type=int, required=False, default=16, choices=[8, 16, 32],
        help = "model output stride, default=%(default)s")

    # Data options
    parser.add_argument('--dataset_path', type=str, required=False, default='VOC2012/',
        help='dataset path containing images and label png file, default=%(default)s')
    parser.add_argument('--dataset_file', type=str, required=False, default='VOC2012/ImageSets/Segmentation/trainval.txt',
        help='train samples txt file, default=%(default)s')
    parser.add_argument('--val_dataset_file', type=str, required=False, default=None,
        help='val samples txt file, default=%(default)s')
    parser.add_argument('--val_split', type=float, required=False, default=0.1,
        help = "validation data persentage in dataset if no val dataset provide, default=%(default)s")
    parser.add_argument('--classes_path', type=str, required=False, default='configs/voc_classes.txt',
        help='path to class definitions, default=%(default)s')

    # Training options
    parser.add_argument("--batch_size", type=int, required=False, default=16,
        help='batch size for training, default=%(default)s')
    parser.add_argument('--optimizer', type=str, required=False, default='sgd', choices=['adam', 'rmsprop', 'sgd'],
        help = "optimizer for training (adam/rmsprop/sgd), default=%(default)s")
    parser.add_argument('--loss', type=str, required=False, default='crossentropy', choices=['crossentropy', 'focal'],
        help = "loss type for training (crossentropy/focal), default=%(default)s")
    parser.add_argument('--weighted_type', type=str, required=False, default=None, choices=[None, 'adaptive', 'balanced'],
        help = "class balance weighted type, default=%(default)s")
    parser.add_argument('--learning_rate', type=float, required=False, default=1e-2,
        help = "Initial learning rate, default=%(default)s")
    parser.add_argument('--average_type', type=str, required=False, default=None, choices=[None, 'ema', 'swa', 'lookahead'],
        help = "weights average type, default=%(default)s")
    parser.add_argument('--decay_type', type=str, required=False, default=None, choices=[None, 'cosine', 'exponential', 'polynomial', 'piecewise_constant'],
        help = "Learning rate decay type, default=%(default)s")
    parser.add_argument('--transfer_epoch', type=int, required=False, default=5,
        help = "Transfer training stage epochs, default=%(default)s")
    parser.add_argument('--freeze_level', type=int, required=False, default=1, choices=[0, 1, 2],
        help = "Freeze level of the model in transfer training stage. 0:NA/1:backbone/2:only open prediction layer")

    parser.add_argument("--init_epoch", type=int, required=False, default=0,
        help="initial training epochs for fine tune training, default=%(default)s")
    parser.add_argument("--total_epoch", type=int, required=False, default=150,
        help="total training epochs, default=%(default)s")
    parser.add_argument('--gpu_num', type=int, required=False, default=1,
        help='Number of GPU to use, default=%(default)s')
    parser.add_argument('--model_pruning', default=False, action="store_true",
        help='Use model pruning for optimization, only for TF 1.x')

    # Evaluation options
    parser.add_argument('--eval_online', default=False, action="store_true",
        help='Whether to do evaluation on validation dataset during training')
    parser.add_argument('--eval_epoch_interval', type=int, required=False, default=10,
        help = "Number of iteration(epochs) interval to do evaluation, default=%(default)s")
    parser.add_argument('--save_eval_checkpoint', default=False, action="store_true",
        help='Whether to save checkpoint with best evaluation result')

    args = parser.parse_args()
    height, width = args.model_input_shape.split('x')
    args.model_input_shape = (int(height), int(width))
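    # e.g. '512x512' -> (512, 512), interpreted as (height, width)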

    main(args)