YOLOv5 AWS Inferentia Inplace compatibility updates (#2953)
* Added flag to enable/disable all inplace and assignment operations
* Removed shape print statements
* Scope Detect/Model import to avoid circular dependency
* PEP8
* create _descale_pred()
* replace lost space
* replace list with tuple
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
- models/experimental.py +5 -3
- models/yolo.py +42 -18
models/experimental.py
CHANGED
```diff
@@ -110,7 +110,9 @@ class Ensemble(nn.ModuleList):
         return y, None  # inference, train output
 
 
-def attempt_load(weights, map_location=None):
+def attempt_load(weights, map_location=None, inplace=True):
+    from models.yolo import Detect, Model
+
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
@@ -120,8 +122,8 @@ def attempt_load(weights, map_location=None):
 
     # Compatibility updates
     for m in model.modules():
-        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
-            m.inplace = True  # pytorch 1.7.0 compatibility
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
+            m.inplace = inplace  # pytorch 1.7.0 compatibility
         elif type(m) is Conv:
             m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
```
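With `attempt_load(..., inplace=False)`, all in-place ops (now including those in `Detect` and `Model`) are switched off at load time, which is what a static-graph tracer needs. A minimal usage sketch; the weights path and the commented `torch.neuron` call are illustrative, not part of this diff:

```python
import torch
from models.experimental import attempt_load

# Load with in-place ops disabled so the traced graph contains no slice
# assignments (assumes a local yolov5s.pt checkpoint).
model = attempt_load('yolov5s.pt', map_location=torch.device('cpu'), inplace=False)
model.eval()

# On an AWS inf1 instance the model could then be compiled, e.g.:
# import torch.neuron
# neuron_model = torch.neuron.trace(model, example_inputs=[torch.zeros(1, 3, 640, 640)])
```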
models/yolo.py
CHANGED
```diff
@@ -26,7 +26,7 @@ class Detect(nn.Module):
     stride = None  # strides computed during build
     export = False  # onnx export
 
-    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
         super(Detect, self).__init__()
         self.nc = nc  # number of classes
         self.no = nc + 5  # number of outputs per anchor
```
```diff
@@ -37,6 +37,7 @@ class Detect(nn.Module):
         self.register_buffer('anchors', a)  # shape(nl,na,2)
         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+        self.inplace = inplace  # use in-place ops (e.g. slice assignment)
 
     def forward(self, x):
         # x = x.copy()  # for profiling
```
```diff
@@ -52,8 +53,13 @@ class Detect(nn.Module):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                if self.inplace:
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                 z.append(y.view(bs, -1, self.no))
 
         return x if self.training else (torch.cat(z, 1), x)
```
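The `else` branch exists because static-graph compilers such as `torch.neuron` for AWS Inferentia cannot trace in-place slice assignment; it rebuilds `y` functionally with `torch.cat` instead. A self-contained sketch with hypothetical shapes showing the two paths are numerically identical:

```python
import torch

y = torch.rand(1, 3, 20, 20, 85)         # hypothetical (bs, na, ny, nx, no)
grid = torch.zeros(1, 3, 20, 20, 2)      # stand-in grid offsets
anchor_grid = torch.ones(1, 3, 1, 1, 2)  # stand-in anchor sizes
stride = 8.

# inplace=True path: slice assignment mutates y
y1 = y.clone()
y1[..., 0:2] = (y1[..., 0:2] * 2. - 0.5 + grid) * stride  # xy
y1[..., 2:4] = (y1[..., 2:4] * 2) ** 2 * anchor_grid      # wh

# inplace=False path: build a new tensor with torch.cat
xy = (y[..., 0:2] * 2. - 0.5 + grid) * stride
wh = (y[..., 2:4] * 2) ** 2 * anchor_grid
y2 = torch.cat((xy, wh, y[..., 4:]), -1)

assert torch.allclose(y1, y2)  # same result, but y2 was built without mutation
```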
```diff
@@ -85,12 +91,14 @@ class Model(nn.Module):
             self.yaml['anchors'] = round(anchors)  # override yaml value
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
         self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
+        self.inplace = self.yaml.get('inplace', True)
         # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
         # Build strides, anchors
         m = self.model[-1]  # Detect()
         if isinstance(m, Detect):
             s = 256  # 2x min stride
+            m.inplace = self.inplace
             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
             m.anchors /= m.stride.view(-1, 1, 1)
             check_anchor_order(m)
```
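`Model` reads an optional `inplace` key from its yaml (defaulting to `True`) and mirrors it onto the `Detect` head before the stride-building forward pass, so a single config entry governs all in-place behavior. A short propagation sketch; the config is the stock one, which carries no `inplace` key:

```python
from models.yolo import Model

model = Model('models/yolov5s.yaml', ch=3, nc=80)
assert model.inplace is True             # self.yaml.get('inplace', True) default
assert model.model[-1].inplace is True   # Detect head mirrors the Model flag
```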
```diff
@@ -105,24 +113,23 @@ class Model(nn.Module):
 
     def forward(self, x, augment=False, profile=False):
         if augment:
-            img_size = x.shape[-2:]  # height, width
-            s = [1, 0.83, 0.67]  # scales
-            f = [None, 3, None]  # flips (2-ud, 3-lr)
-            y = []  # outputs
-            for si, fi in zip(s, f):
-                xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
-                yi = self.forward_once(xi)[0]  # forward
-                # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
-                yi[..., :4] /= si  # de-scale
-                if fi == 2:
-                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
-                elif fi == 3:
-                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
-                y.append(yi)
-            return torch.cat(y, 1), None  # augmented inference, train
+            return self.forward_augment(x)  # augmented inference, None
         else:
             return self.forward_once(x, profile)  # single-scale inference, train
 
+    def forward_augment(self, x):
+        img_size = x.shape[-2:]  # height, width
+        s = [1, 0.83, 0.67]  # scales
+        f = [None, 3, None]  # flips (2-ud, 3-lr)
+        y = []  # outputs
+        for si, fi in zip(s, f):
+            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+            yi = self.forward_once(xi)[0]  # forward
+            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
+            yi = self._descale_pred(yi, fi, si, img_size)
+            y.append(yi)
+        return torch.cat(y, 1), None  # augmented inference, train
+
     def forward_once(self, x, profile=False):
         y, dt = [], []  # outputs
         for m in self.model:
```
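Test-time augmentation now lives in its own `forward_augment()` method rather than inline in `forward()`. A usage sketch with an untrained model and a dummy batch (paths and sizes illustrative):

```python
import torch
from models.yolo import Model

model = Model('models/yolov5s.yaml', ch=3, nc=80).eval()
img = torch.zeros(1, 3, 640, 640)           # dummy input batch
pred, train_out = model(img, augment=True)  # routes through forward_augment()
print(pred.shape, train_out)                # detections merged from 3 scales/flips; None
```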
```diff
@@ -146,6 +153,23 @@ class Model(nn.Module):
             logger.info('%.1fms total' % sum(dt))
         return x
 
+    def _descale_pred(self, p, flips, scale, img_size):
+        # de-scale predictions following augmented inference (inverse operation)
+        if self.inplace:
+            p[..., :4] /= scale  # de-scale
+            if flips == 2:
+                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
+            elif flips == 3:
+                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
+        else:
+            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
+            if flips == 2:
+                y = img_size[0] - y  # de-flip ud
+            elif flips == 3:
+                x = img_size[1] - x  # de-flip lr
+            p = torch.cat((x, y, wh, p[..., 4:]), -1)
+        return p
+
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
         # https://arxiv.org/abs/1708.02002 section 3.3
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
```
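`_descale_pred()` factors the inverse-augmentation math out of the old loop and, like `Detect.forward`, gains an Inferentia-safe functional branch. A self-contained sketch checking both branches agree for an lr-flipped, 0.83-scaled pass (random stand-in predictions):

```python
import torch

p = torch.rand(1, 100, 85)           # hypothetical (bs, boxes, outputs)
scale, img_size = 0.83, (640, 640)   # lr flip case (flips == 3)

q = p.clone()                        # in-place branch
q[..., :4] /= scale                  # de-scale
q[..., 0] = img_size[1] - q[..., 0]  # de-flip lr

x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale
x = img_size[1] - x                  # out-of-place branch, de-flip lr
r = torch.cat((x, y, wh, p[..., 4:]), -1)

assert torch.allclose(q, r)          # identical de-scaled predictions
```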