Louis Combaldieu
commited on
Fix export for 1-channel images (#6780)
Browse filesExport failed for 1-channel input shape, 1-liner fix
export.py
CHANGED
@@ -260,9 +260,9 @@ def export_saved_model(model, im, file, dynamic,
|
|
260 |
batch_size, ch, *imgsz = list(im.shape) # BCHW
|
261 |
|
262 |
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
|
263 |
-
im = tf.zeros((batch_size, *imgsz,
|
264 |
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
265 |
-
inputs = tf.keras.Input(shape=(*imgsz,
|
266 |
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
267 |
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
268 |
keras_model.trainable = False
|
|
|
260 |
batch_size, ch, *imgsz = list(im.shape) # BCHW
|
261 |
|
262 |
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
|
263 |
+
im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
|
264 |
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
265 |
+
inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
|
266 |
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
267 |
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
268 |
keras_model.trainable = False
|