Spaces:
Runtime error
Runtime error
Update models/yolo.py
Browse files- models/yolo.py +110 -0
models/yolo.py
CHANGED
@@ -307,6 +307,93 @@ class IKeypoint(nn.Module):
|
|
307 |
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
308 |
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
|
311 |
class IAuxDetect(nn.Module):
|
312 |
stride = None # strides computed during build
|
@@ -572,6 +659,16 @@ class Model(nn.Module):
|
|
572 |
self.stride = m.stride
|
573 |
self._initialize_biases_kpt() # only run once
|
574 |
# print('Strides: %s' % m.stride.tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
|
576 |
# Init weights, biases
|
577 |
initialize_weights(self)
|
@@ -793,10 +890,23 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|
793 |
args[1] = [list(range(args[1] * 2))] * len(f)
|
794 |
elif m is ReOrg:
|
795 |
c2 = ch[f] * 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
796 |
elif m is Contract:
|
797 |
c2 = ch[f] * args[0] ** 2
|
798 |
elif m is Expand:
|
799 |
c2 = ch[f] // args[0] ** 2
|
|
|
|
|
|
|
800 |
else:
|
801 |
c2 = ch[f]
|
802 |
|
|
|
307 |
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
308 |
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
309 |
|
310 |
+
class MT(nn.Module):
    """Multi-task detection head (boxes/classes + mask bases).

    Takes a 3-element input ``x = [feature_maps, bases, sem]`` where
    ``x[0]`` is a list of ``nl`` per-level feature maps, ``x[1]`` are the
    prototype mask bases and ``x[2]`` a semantic map (both passed through
    untouched — assumed produced upstream; TODO confirm against caller).
    Returns a dict with keys ``bbox_and_cls``, ``bases``, ``sem``,
    ``mask_iou`` and, when enabled, ``attn`` (per-anchor mask coefficients),
    plus decoded predictions under ``test`` at inference time.
    """

    stride = None  # strides computed during model build (set externally)
    export = False  # onnx export flag; forces training-style raw outputs

    def __init__(self, nc=80, anchors=(), attn=None, mask_iou=False, ch=()):  # detection layer
        """Build output convolutions for each detection level.

        nc: number of classes; anchors: per-level (w, h) pairs;
        attn: number of attention/mask coefficients per anchor (None disables);
        mask_iou: add a per-anchor mask-IoU prediction branch;
        ch: ch[0] is the list of input channel counts, one per level.
        """
        super(MT, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor (box4 + obj1 + classes)
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors per layer
        self.grid = [torch.zeros(1)] * self.nl  # init grid; rebuilt lazily per level
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.original_anchors = anchors
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[0])  # output conv
        if mask_iou:
            self.m_iou = nn.ModuleList(nn.Conv2d(x, self.na, 1) for x in ch[0])  # mask-IoU output conv
        self.mask_iou = mask_iou
        self.attn = attn
        if attn is not None:
            self.attn_m = nn.ModuleList(nn.Conv2d(x, attn * self.na, 1) for x in ch[0])  # attention-coefficient conv

    def forward(self, x):
        """Run the head; returns a dict (see class docstring)."""
        z = []  # inference output (decoded boxes)
        za = []  # inference attention coefficients
        zi = []  # inference mask-IoU scores
        attn = [None] * self.nl
        iou = [None] * self.nl
        self.training |= self.export  # export path uses raw (training-style) outputs
        output = dict()
        for i in range(self.nl):
            if self.attn is not None:
                attn[i] = self.attn_m[i](x[0][i])  # conv
                bs, _, ny, nx = attn[i].shape  # x(bs,2352,20,20) to x(bs,3,20,20,784)
                attn[i] = attn[i].view(bs, self.na, self.attn, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.mask_iou:
                iou[i] = self.m_iou[i](x[0][i])
            x[0][i] = self.m[i](x[0][i])  # conv

            bs, _, ny, nx = x[0][i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[0][i] = x[0][i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.mask_iou:
                iou[i] = iou[i].view(bs, self.na, ny, nx).contiguous()

            if not self.training:  # inference
                # BUG FIX: the original appended attn[i] unconditionally, which
                # raised AttributeError on None when the head was built with attn=None.
                if self.attn is not None:
                    za.append(attn[i].view(bs, -1, self.attn))
                if self.mask_iou:
                    zi.append(iou[i].view(bs, -1))
                if self.grid[i].shape[2:4] != x[0][i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[0][i].device)

                y = x[0][i].sigmoid()
                # NOTE: this variant decodes xy with a *3 scale (wider offset range
                # than the usual *2 - 0.5 YOLO formula) — kept as in the original.
                y[..., 0:2] = (y[..., 0:2] * 3. - 1.0 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))
        output["mask_iou"] = None
        if not self.training:
            output["test"] = torch.cat(z, 1)
            if self.attn is not None:
                output["attn"] = torch.cat(za, 1)
            if self.mask_iou:
                output["mask_iou"] = torch.cat(zi, 1).sigmoid()
        else:
            if self.attn is not None:
                output["attn"] = attn
            if self.mask_iou:
                output["mask_iou"] = iou
        output["bbox_and_cls"] = x[0]
        output["bases"] = x[1]
        output["sem"] = x[2]

        return output

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a (1, 1, ny, nx, 2) float grid of (x, y) cell coordinates."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
396 |
+
|
397 |
|
398 |
class IAuxDetect(nn.Module):
|
399 |
stride = None # strides computed during build
|
|
|
659 |
self.stride = m.stride
|
660 |
self._initialize_biases_kpt() # only run once
|
661 |
# print('Strides: %s' % m.stride.tolist())
|
662 |
+
if isinstance(m, MT):
|
663 |
+
s = 256 # 2x min stride
|
664 |
+
temp = self.forward(torch.zeros(1, ch, s, s))
|
665 |
+
if isinstance(temp, list):
|
666 |
+
temp = temp[0]
|
667 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in temp["bbox_and_cls"]]) # forward
|
668 |
+
check_anchor_order(m)
|
669 |
+
m.anchors /= m.stride.view(-1, 1, 1)
|
670 |
+
self.stride = m.stride
|
671 |
+
self._initialize_biases()
|
672 |
|
673 |
# Init weights, biases
|
674 |
initialize_weights(self)
|
|
|
890 |
args[1] = [list(range(args[1] * 2))] * len(f)
|
891 |
elif m is ReOrg:
|
892 |
c2 = ch[f] * 4
|
893 |
+
elif m in [Merge]:
|
894 |
+
c2 = args[0]
|
895 |
+
elif m in [MT]:
|
896 |
+
if len(args) == 3:
|
897 |
+
args.append(False)
|
898 |
+
#print(f)
|
899 |
+
#print(len(ch))
|
900 |
+
#for x in f:
|
901 |
+
# print(ch[x])
|
902 |
+
args.append([ch[x] for x in f])
|
903 |
elif m is Contract:
|
904 |
c2 = ch[f] * args[0] ** 2
|
905 |
elif m is Expand:
|
906 |
c2 = ch[f] // args[0] ** 2
|
907 |
+
elif m is Refine:
|
908 |
+
args.append([ch[x] for x in f])
|
909 |
+
c2 = args[0]
|
910 |
else:
|
911 |
c2 = ch[f]
|
912 |
|