Sa-m committed on
Commit
2f36488
1 Parent(s): e3d794c

Update models/yolo.py

Browse files
Files changed (1) hide show
  1. models/yolo.py +110 -0
models/yolo.py CHANGED
@@ -307,6 +307,93 @@ class IKeypoint(nn.Module):
307
  yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
308
  return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
  class IAuxDetect(nn.Module):
312
  stride = None # strides computed during build
@@ -572,6 +659,16 @@ class Model(nn.Module):
572
  self.stride = m.stride
573
  self._initialize_biases_kpt() # only run once
574
  # print('Strides: %s' % m.stride.tolist())
 
 
 
 
 
 
 
 
 
 
575
 
576
  # Init weights, biases
577
  initialize_weights(self)
@@ -793,10 +890,23 @@ def parse_model(d, ch): # model_dict, input_channels(3)
793
  args[1] = [list(range(args[1] * 2))] * len(f)
794
  elif m is ReOrg:
795
  c2 = ch[f] * 4
 
 
 
 
 
 
 
 
 
 
796
  elif m is Contract:
797
  c2 = ch[f] * args[0] ** 2
798
  elif m is Expand:
799
  c2 = ch[f] // args[0] ** 2
 
 
 
800
  else:
801
  c2 = ch[f]
802
 
 
307
  yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
308
  return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
309
 
310
class MT(nn.Module):
    """Detection head producing box/class predictions plus optional extra maps.

    For each of the ``nl`` detection layers a 1x1 convolution yields
    ``na * no`` channels (``no = nc + 5``).  Two optional, independently
    enabled 1x1 convolution branches read the same input feature map:

    * ``attn``: ``attn * na`` channels per layer (enabled when ``attn`` is not
      ``None``).
    * ``mask_iou``: ``na`` channels per layer (enabled when ``mask_iou`` is
      truthy).

    ``stride`` is a class-level placeholder (``None``); it must be assigned a
    per-layer indexable value before running in inference mode, because the
    box decoding multiplies by ``self.stride[i]``.
    """

    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), attn=None, mask_iou=False, ch=()):  # detection layer
        """Build the output convolutions.

        Args:
            nc: Number of classes.
            anchors: One sequence per detection layer, each a flat list of
                width/height pairs.  Must be non-empty (``anchors[0]`` is read).
            attn: Number of attention values per anchor, or ``None`` to disable
                the attention branch.
            mask_iou: Whether to build the mask-IoU branch.
            ch: Channel specification; only ``ch[0]`` is used here and it is
                iterated to get the input channel count of each layer's conv.
        """
        super(MT, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.original_anchors = anchors
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[0])  # output conv
        if mask_iou:
            self.m_iou = nn.ModuleList(nn.Conv2d(x, self.na, 1) for x in ch[0])  # output conv
        self.mask_iou = mask_iou
        self.attn = attn
        if attn is not None:
            self.attn_m = nn.ModuleList(nn.Conv2d(x, attn * self.na, 1) for x in ch[0])  # output conv

    def forward(self, x):
        """Run the head and collect its outputs in a dict.

        Args:
            x: Indexable with three entries.  ``x[0]`` is a mutable sequence of
                ``nl`` feature maps of shape ``(bs, C_i, ny, nx)``; each entry
                is overwritten in place with the reshaped prediction of shape
                ``(bs, na, ny, nx, no)``.  ``x[1]`` and ``x[2]`` are passed
                through unchanged.

        Returns:
            dict with keys:

            * ``"bbox_and_cls"``: ``x[0]`` after the in-place overwrite.
            * ``"bases"``: ``x[1]``.
            * ``"sem"``: ``x[2]``.
            * ``"mask_iou"``: ``None`` when the branch is disabled; otherwise
              the per-layer list of raw maps ``(bs, na, ny, nx)`` in training
              mode, or a sigmoid-activated ``(bs, N)`` tensor in inference.
            * ``"attn"`` (only when the attention branch is enabled): per-layer
              list of ``(bs, na, ny, nx, attn)`` tensors in training mode, or a
              concatenated ``(bs, N, attn)`` tensor in inference.
            * ``"test"`` (inference only): decoded predictions ``(bs, N, no)``.
        """
        z = []  # decoded box/class predictions, one entry per layer
        za = []  # flattened attention maps, one entry per layer
        zi = []  # flattened mask-IoU maps, one entry per layer
        attn = [None] * self.nl
        iou = [None] * self.nl
        self.training |= self.export  # export always takes the raw (training) path
        output = dict()
        for i in range(self.nl):
            # The auxiliary branches must read the feature map before it is
            # overwritten by the main output conv below.
            if self.attn is not None:
                attn[i] = self.attn_m[i](x[0][i])  # conv
                bs, _, ny, nx = attn[i].shape  # (bs,na*attn,ny,nx) to (bs,na,ny,nx,attn)
                attn[i] = attn[i].view(bs, self.na, self.attn, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.mask_iou:
                iou[i] = self.m_iou[i](x[0][i])
            x[0][i] = self.m[i](x[0][i])  # conv

            bs, _, ny, nx = x[0][i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[0][i] = x[0][i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.mask_iou:
                iou[i] = iou[i].view(bs, self.na, ny, nx).contiguous()

            if not self.training:  # inference
                # Guarded: attn[i] is None when the attention branch is
                # disabled, and calling .view on it would raise.
                if self.attn is not None:
                    za.append(attn[i].view(bs, -1, self.attn))
                if self.mask_iou:
                    zi.append(iou[i].view(bs, -1))
                # Rebuild the cached cell-offset grid when the map size changes.
                if self.grid[i].shape[2:4] != x[0][i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[0][i].device)

                y = x[0][i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 3. - 1.0 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        output["mask_iou"] = None
        if not self.training:
            output["test"] = torch.cat(z, 1)
            if self.attn is not None:
                output["attn"] = torch.cat(za, 1)
            if self.mask_iou:
                output["mask_iou"] = torch.cat(zi, 1).sigmoid()
        else:
            if self.attn is not None:
                output["attn"] = attn
            if self.mask_iou:
                output["mask_iou"] = iou
        output["bbox_and_cls"] = x[0]
        output["bases"] = x[1]
        output["sem"] = x[2]

        return output

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a float tensor of shape (1, 1, ny, nx, 2) holding (x, y) cell indices."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
396
+
397
 
398
  class IAuxDetect(nn.Module):
399
  stride = None # strides computed during build
 
659
  self.stride = m.stride
660
  self._initialize_biases_kpt() # only run once
661
  # print('Strides: %s' % m.stride.tolist())
662
+ if isinstance(m, MT):
663
+ s = 256 # 2x min stride
664
+ temp = self.forward(torch.zeros(1, ch, s, s))
665
+ if isinstance(temp, list):
666
+ temp = temp[0]
667
+ m.stride = torch.tensor([s / x.shape[-2] for x in temp["bbox_and_cls"]]) # forward
668
+ check_anchor_order(m)
669
+ m.anchors /= m.stride.view(-1, 1, 1)
670
+ self.stride = m.stride
671
+ self._initialize_biases()
672
 
673
  # Init weights, biases
674
  initialize_weights(self)
 
890
  args[1] = [list(range(args[1] * 2))] * len(f)
891
  elif m is ReOrg:
892
  c2 = ch[f] * 4
893
+ elif m in [Merge]:
894
+ c2 = args[0]
895
+ elif m in [MT]:
896
+ if len(args) == 3:
897
+ args.append(False)
898
+ #print(f)
899
+ #print(len(ch))
900
+ #for x in f:
901
+ # print(ch[x])
902
+ args.append([ch[x] for x in f])
903
  elif m is Contract:
904
  c2 = ch[f] * args[0] ** 2
905
  elif m is Expand:
906
  c2 = ch[f] // args[0] ** 2
907
+ elif m is Refine:
908
+ args.append([ch[x] for x in f])
909
+ c2 = args[0]
910
  else:
911
  c2 = ch[f]
912