glenn-jocher pre-commit-ci[bot] commited on
Commit
05cf0d1
1 Parent(s): 779efbb

Export single output only (#7259)

Browse files

* Update

* Update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (2) hide show
  1. export.py +1 -0
  2. models/yolo.py +2 -1
export.py CHANGED
@@ -477,6 +477,7 @@ def run(
477
  if isinstance(m, Detect):
478
  m.inplace = inplace
479
  m.onnx_dynamic = dynamic
 
480
  if hasattr(m, 'forward_export'):
481
  m.forward = m.forward_export # assign custom forward (optional)
482
 
 
477
  if isinstance(m, Detect):
478
  m.inplace = inplace
479
  m.onnx_dynamic = dynamic
480
+ m.export = True
481
  if hasattr(m, 'forward_export'):
482
  m.forward = m.forward_export # assign custom forward (optional)
483
 
models/yolo.py CHANGED
@@ -37,6 +37,7 @@ except ImportError:
37
  class Detect(nn.Module):
38
  stride = None # strides computed during build
39
  onnx_dynamic = False # ONNX export parameter
 
40
 
41
  def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
42
  super().__init__()
@@ -72,7 +73,7 @@ class Detect(nn.Module):
72
  y = torch.cat((xy, wh, conf), 4)
73
  z.append(y.view(bs, -1, self.no))
74
 
75
- return x if self.training else (torch.cat(z, 1), x)
76
 
77
  def _make_grid(self, nx=20, ny=20, i=0):
78
  d = self.anchors[i].device
 
37
  class Detect(nn.Module):
38
  stride = None # strides computed during build
39
  onnx_dynamic = False # ONNX export parameter
40
+ export = False # export mode
41
 
42
  def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
43
  super().__init__()
 
73
  y = torch.cat((xy, wh, conf), 4)
74
  z.append(y.view(bs, -1, self.no))
75
 
76
+ return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
77
 
78
  def _make_grid(self, nx=20, ny=20, i=0):
79
  d = self.anchors[i].device