Louis Combaldieu commited on
Commit
b2adc7c
·
unverified ·
1 Parent(s): cea994b

Fix export for 1-channel images (#6780)

Browse files

Export failed for 1-channel input shape, 1-liner fix

Files changed (1) hide show
  1. export.py +2 -2
export.py CHANGED
@@ -260,9 +260,9 @@ def export_saved_model(model, im, file, dynamic,
260
  batch_size, ch, *imgsz = list(im.shape) # BCHW
261
 
262
  tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
263
- im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
264
  _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
265
- inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
266
  outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
267
  keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
268
  keras_model.trainable = False
 
260
  batch_size, ch, *imgsz = list(im.shape) # BCHW
261
 
262
  tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
263
+ im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
264
  _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
265
+ inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
266
  outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
267
  keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
268
  keras_model.trainable = False