zldrobit glenn-jocher commited on
Commit
2c2ef25
1 Parent(s): ce7fa81

TensorFlow.js export enhancements (#4905)

Browse files

* Add arguments to TensorFlow NMS call

* Add regex substitution to reorder Identity_*

* Delete reorder in docstring

* Cleanup

* Cleanup2

* Removed `+ \` on string ends (not needed)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (2) hide show
  1. export.py +27 -2
  2. models/tf.py +1 -1
export.py CHANGED
@@ -14,7 +14,6 @@ Inference:
14
  yolov5s.tflite
15
 
16
  TensorFlow.js:
17
- $ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order
18
  $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
19
  $ npm install
20
  $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
@@ -213,16 +212,32 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
213
  # YOLOv5 TensorFlow.js export
214
  try:
215
  check_requirements(('tensorflowjs',))
 
216
  import tensorflowjs as tfjs
217
 
218
  print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
219
  f = str(file).replace('.pt', '_web_model') # js dir
220
  f_pb = file.with_suffix('.pb') # *.pb path
 
221
 
222
  cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
223
  f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
224
  subprocess.run(cmd, shell=True)
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
227
  except Exception as e:
228
  print(f'\n{prefix} export failure: {e}')
@@ -243,6 +258,10 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
243
  dynamic=False, # ONNX/TF: dynamic axes
244
  simplify=False, # ONNX: simplify model
245
  opset=12, # ONNX: opset version
 
 
 
 
246
  ):
247
  t = time.time()
248
  include = [x.lower() for x in include]
@@ -290,7 +309,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
290
  if any(tf_exports):
291
  pb, tflite, tfjs = tf_exports[1:]
292
  assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
293
- model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs) # keras model
 
 
294
  if pb or tfjs: # pb prerequisite to tfjs
295
  export_pb(model, im, file)
296
  if tflite:
@@ -319,6 +340,10 @@ def parse_opt():
319
  parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
320
  parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
321
  parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
 
 
 
 
322
  parser.add_argument('--include', nargs='+',
323
  default=['torchscript', 'onnx'],
324
  help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
 
14
  yolov5s.tflite
15
 
16
  TensorFlow.js:
 
17
  $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
18
  $ npm install
19
  $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
 
212
  # YOLOv5 TensorFlow.js export
213
  try:
214
  check_requirements(('tensorflowjs',))
215
+ import re
216
  import tensorflowjs as tfjs
217
 
218
  print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
219
  f = str(file).replace('.pt', '_web_model') # js dir
220
  f_pb = file.with_suffix('.pb') # *.pb path
221
+ f_json = f + '/model.json' # *.json path
222
 
223
  cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
224
  f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
225
  subprocess.run(cmd, shell=True)
226
 
227
+ json = open(f_json).read()
228
+ with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
229
+ subst = re.sub(
230
+ r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
231
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
232
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
233
+ r'"Identity.?.?": {"name": "Identity.?.?"}}}',
234
+ r'{"outputs": {"Identity": {"name": "Identity"}, '
235
+ r'"Identity_1": {"name": "Identity_1"}, '
236
+ r'"Identity_2": {"name": "Identity_2"}, '
237
+ r'"Identity_3": {"name": "Identity_3"}}}',
238
+ json)
239
+ j.write(subst)
240
+
241
  print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
242
  except Exception as e:
243
  print(f'\n{prefix} export failure: {e}')
 
258
  dynamic=False, # ONNX/TF: dynamic axes
259
  simplify=False, # ONNX: simplify model
260
  opset=12, # ONNX: opset version
261
+ topk_per_class=100, # TF.js NMS: topk per class to keep
262
+ topk_all=100, # TF.js NMS: topk for all classes to keep
263
+ iou_thres=0.45, # TF.js NMS: IoU threshold
264
+ conf_thres=0.25 # TF.js NMS: confidence threshold
265
  ):
266
  t = time.time()
267
  include = [x.lower() for x in include]
 
309
  if any(tf_exports):
310
  pb, tflite, tfjs = tf_exports[1:]
311
  assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
312
+ model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs,
313
+ topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres,
314
+ iou_thres=iou_thres) # keras model
315
  if pb or tfjs: # pb prerequisite to tfjs
316
  export_pb(model, im, file)
317
  if tflite:
 
340
  parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
341
  parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
342
  parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
343
+ parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
344
+ parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
345
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
346
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
347
  parser.add_argument('--include', nargs='+',
348
  default=['torchscript', 'onnx'],
349
  help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
models/tf.py CHANGED
@@ -367,7 +367,7 @@ class AgnosticNMS(keras.layers.Layer):
367
  # TF Agnostic NMS
368
  def call(self, input, topk_all, iou_thres, conf_thres):
369
  # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
370
- return tf.map_fn(self._nms, input,
371
  fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
372
  name='agnostic_nms')
373
 
 
367
  # TF Agnostic NMS
368
  def call(self, input, topk_all, iou_thres, conf_thres):
369
  # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
370
+ return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
371
  fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
372
  name='agnostic_nms')
373