Mrinal Jain commited on
Commit
4effd06
1 Parent(s): 7c6a335

Consistent saved_model output format (#7032)

Browse files
Files changed (2) hide show
  1. export.py +1 -1
  2. models/common.py +1 -1
export.py CHANGED
@@ -275,7 +275,7 @@ def export_saved_model(model, im, file, dynamic,
275
  m = m.get_concrete_function(spec)
276
  frozen_func = convert_variables_to_constants_v2(m)
277
  tfm = tf.Module()
278
- tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
279
  tfm.__call__(im)
280
  tf.saved_model.save(
281
  tfm,
 
275
  m = m.get_concrete_function(spec)
276
  frozen_func = convert_variables_to_constants_v2(m)
277
  tfm = tf.Module()
278
+ tfm.__call__ = tf.function(lambda x: frozen_func(x)[0], [spec])
279
  tfm.__call__(im)
280
  tf.saved_model.save(
281
  tfm,
models/common.py CHANGED
@@ -441,7 +441,7 @@ class DetectMultiBackend(nn.Module):
441
  else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
442
  im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
443
  if self.saved_model: # SavedModel
444
- y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
445
  elif self.pb: # GraphDef
446
  y = self.frozen_func(x=self.tf.constant(im)).numpy()
447
  else: # Lite or Edge TPU
 
441
  else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
442
  im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
443
  if self.saved_model: # SavedModel
444
+ y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
445
  elif self.pb: # GraphDef
446
  y = self.frozen_func(x=self.tf.constant(im)).numpy()
447
  else: # Lite or Edge TPU