import tensorflow as tf from absl import app, flags, logging from absl.flags import FLAGS from core.yolov4 import YOLO, decode, filter_boxes import core.utils as utils from core.config import cfg flags.DEFINE_string('weights', './data/yolov4.weights', 'path to weights file') flags.DEFINE_string('output', './checkpoints/yolov4-416', 'path to output') flags.DEFINE_boolean('tiny', False, 'is yolo-tiny or not') flags.DEFINE_integer('input_size', 416, 'define input size of export model') flags.DEFINE_float('score_thres', 0.2, 'define score threshold') flags.DEFINE_string('framework', 'tf', 'define what framework do you want to convert (tf, trt, tflite)') flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4') def save_tf(): STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS) input_layer = tf.keras.layers.Input([FLAGS.input_size, FLAGS.input_size, 3]) feature_maps = YOLO(input_layer, NUM_CLASS, FLAGS.model, FLAGS.tiny) bbox_tensors = [] prob_tensors = [] if FLAGS.tiny: for i, fm in enumerate(feature_maps): if i == 0: output_tensors = decode(fm, FLAGS.input_size // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework) else: output_tensors = decode(fm, FLAGS.input_size // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework) bbox_tensors.append(output_tensors[0]) prob_tensors.append(output_tensors[1]) else: for i, fm in enumerate(feature_maps): if i == 0: output_tensors = decode(fm, FLAGS.input_size // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework) elif i == 1: output_tensors = decode(fm, FLAGS.input_size // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework) else: output_tensors = decode(fm, FLAGS.input_size // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework) bbox_tensors.append(output_tensors[0]) prob_tensors.append(output_tensors[1]) pred_bbox = tf.concat(bbox_tensors, axis=1) pred_prob = tf.concat(prob_tensors, axis=1) if FLAGS.framework == 'tflite': pred = (pred_bbox, pred_prob) else: boxes, pred_conf = filter_boxes(pred_bbox, pred_prob, score_threshold=FLAGS.score_thres, input_shape=tf.constant([FLAGS.input_size, FLAGS.input_size])) pred = tf.concat([boxes, pred_conf], axis=-1) model = tf.keras.Model(input_layer, pred) utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny) model.summary() model.save(FLAGS.output) def main(_argv): save_tf() if __name__ == '__main__': try: app.run(main) except SystemExit: pass