owaiskha9654 committed on
Commit
39780d9
1 Parent(s): becd37a
Files changed (3) hide show
  1. models/experimental.py +262 -0
  2. models/export.py +98 -0
  3. models/yolo.py +843 -0
models/experimental.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models.common import Conv, DWConv
7
+ from utils.google_utils import attempt_download
8
+
9
+
10
class CrossConv(nn.Module):
    """Cross Convolution Downsample: a (1, k) convolution followed by a (k, 1) convolution.

    Args (ch_in, ch_out, kernel, stride, groups, expansion, shortcut):
        c1: input channels.
        c2: output channels.
        k: kernel length used along one axis by each of the two convolutions.
        s: stride applied along the same axis as the kernel.
        g: groups passed to the second convolution only.
        e: expansion factor; hidden channels are ``int(c2 * e)``.
        shortcut: request a residual connection (only honoured when c1 == c2).
    """

    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, (1, k), (1, s))
        self.cv2 = Conv(hidden, c2, (k, 1), (s, 1), g=g)
        # The residual add is only shape-safe when channel counts match.
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply both convolutions, adding the input back when the shortcut is active."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
22
+
23
+
24
class Sum(nn.Module):
    """Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070

    Args:
        n: number of inputs that will be passed to ``forward``.
        weight: when true, every input after the first is scaled by a learned
            weight ``2 * sigmoid(w)`` before being added.
    """

    def __init__(self, n, weight=False):
        super().__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # indices of the inputs added onto the first one
        if weight:
            # layer weights, initialised to -0.5, -1.0, ... (one per extra input)
            self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True)

    def forward(self, x):
        """Return x[0] plus the (optionally weighted) remaining inputs."""
        total = x[0]  # the first input is never weighted
        if not self.weight:
            for idx in self.iter:
                total = total + x[idx + 1]
            return total

        scale = torch.sigmoid(self.w) * 2
        for idx in self.iter:
            total = total + x[idx + 1] * scale[idx]
        return total
43
+
44
+
45
class MixConv2d(nn.Module):
    """Mixed Depthwise Conv https://arxiv.org/abs/1907.09595

    Runs one convolution per kernel size in ``k`` on the same input and
    concatenates the results along the channel axis.

    Args:
        c1: input channels.
        c2: total output channels across all branches.
        k: kernel sizes, one branch per entry.
        s: stride shared by every branch.
        equal_ch: split ``c2`` evenly across branches when true; otherwise
            solve for a split giving each branch an equal number of weights.
    """

    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super().__init__()
        groups = len(k)
        if equal_ch:  # equal c_ per group
            bucket = torch.linspace(0, groups - 1E-6, c2).floor()  # branch index of every output channel
            c_ = [(bucket == g).sum() for g in range(groups)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        branches = []
        for width, size in zip(c_, k):
            # padding of size // 2 keeps the spatial size for odd kernels at stride 1
            branches.append(nn.Conv2d(c1, int(width), size, s, size // 2, bias=False))
        self.m = nn.ModuleList(branches)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        """Return x plus the activated, normalised concatenation of all branches."""
        mixed = torch.cat([branch(x) for branch in self.m], 1)
        return x + self.act(self.bn(mixed))
67
+
68
+
69
class Ensemble(nn.ModuleList):
    """Ensemble of models whose first outputs are concatenated for a shared NMS pass."""

    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False):
        """Run every member on ``x`` and join their first outputs along dim 1.

        Returns:
            (inference output, None) -- the second slot is the train output.
        """
        # Only element [0] of each member's result is kept.
        outputs = [member(x, augment)[0] for member in self]
        # nms ensemble (a max or mean over the stacked outputs would be the alternatives)
        return torch.cat(outputs, 1), None
82
+
83
+
84
+
85
+
86
+
87
class ORT_NMS(torch.autograd.Function):
    '''ONNX-Runtime NMS operation

    ``forward`` does not run NMS: it fabricates randomly sized index rows so a
    PyTorch-side call returns something of the right layout, while ``symbolic``
    emits the ONNX ``NonMaxSuppression`` node.
    '''

    @staticmethod
    def forward(ctx,
                boxes,
                scores,
                max_output_boxes_per_class=torch.tensor([100]),
                iou_threshold=torch.tensor([0.45]),
                score_threshold=torch.tensor([0.25])):
        target = boxes.device
        batch = scores.shape[0]
        num_det = random.randint(0, 100)  # random number of fake detections
        # Columns of the result: batch index (sorted), class index (always 0), box index.
        batch_col = torch.randint(0, batch, (num_det,)).sort()[0].to(target)
        box_col = torch.arange(100, 100 + num_det).to(target)
        class_col = torch.zeros((num_det,), dtype=torch.int64).to(target)
        selected_indices = torch.stack([batch_col, class_col, box_col], 1).contiguous()
        return selected_indices.to(torch.int64)

    @staticmethod
    def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
        return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
109
+
110
+
111
class TRT_NMS(torch.autograd.Function):
    '''TensorRT NMS operation

    ``forward`` returns random placeholders with the plugin's output layout;
    ``symbolic`` emits the ``TRT::EfficientNMS_TRT`` node with four outputs.
    '''

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
    ):
        batch_size, num_boxes, num_classes = scores.shape
        per_image = (batch_size, max_output_boxes)
        # Placeholder outputs: only their shapes and dtypes are meaningful.
        num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4)
        det_scores = torch.randn(*per_image)
        det_classes = torch.randint(0, num_classes, per_image, dtype=torch.int32)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version="1",
                 score_activation=0,
                 score_threshold=0.25):
        # Keyword suffixes (_i, _f, _s) select the ONNX attribute type.
        nums, det_boxes, det_scores, det_classes = g.op(
            "TRT::EfficientNMS_TRT",
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4)
        return nums, det_boxes, det_scores, det_classes
157
+
158
+
159
class ONNX_ORT(nn.Module):
    '''onnx module with ONNX-Runtime NMS operation.

    Args:
        max_obj: maximum boxes per class handed to the NMS op.
        iou_thres: IoU threshold handed to the NMS op.
        score_thres: score threshold handed to the NMS op.
        max_wh: per-class coordinate offset; if max_wh != 0 : non-agnostic else : agnostic.
        device: device for the constant tensors; falls back to CPU when falsy.
    '''

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
        super().__init__()
        # Resolve the device once and use the resolved value everywhere, so the
        # thresholds and the conversion matrix always live on the same device
        # (the raw ``device`` argument may be None).
        self.device = device if device else torch.device("cpu")
        self.max_obj = torch.tensor([max_obj]).to(self.device)
        self.iou_threshold = torch.tensor([iou_thres]).to(self.device)
        self.score_threshold = torch.tensor([score_thres]).to(self.device)
        self.max_wh = max_wh  # if max_wh != 0 : non-agnostic else : agnostic
        # Right-multiplying (cx, cy, w, h) by this matrix yields (x1, y1, x2, y2).
        self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                           dtype=torch.float32,
                                           device=self.device)

    def forward(self, x):
        """Select detections from ``x`` via ``ORT_NMS``.

        ``x`` is sliced as [..., :4] boxes, [..., 4:5] confidence and [..., 5:]
        class scores.  Returns one row per selected detection laid out as
        [batch index, 4 box values, category id, score].
        """
        boxes = x[:, :, :4]
        conf = x[:, :, 4:5]
        scores = x[:, :, 5:]
        scores *= conf  # NOTE: in-place on a slice of x, so the caller's tensor is modified
        boxes @= self.convert_matrix
        max_score, category_id = scores.max(2, keepdim=True)
        # Shift each box by class_id * max_wh so boxes of different classes do not suppress each other.
        dis = category_id.float() * self.max_wh
        nmsbox = boxes + dis
        max_score_tp = max_score.transpose(1, 2).contiguous()
        selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
        # Column 0 is the batch index and column 2 the box index of each selection.
        X, Y = selected_indices[:, 0], selected_indices[:, 2]
        selected_boxes = boxes[X, Y, :]
        selected_categories = category_id[X, Y, :].float()
        selected_scores = max_score[X, Y, :]
        X = X.unsqueeze(1).float()
        return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
189
+
190
class ONNX_TRT(nn.Module):
    '''onnx module with TensorRT NMS operation.

    Args:
        max_obj: maximum number of output boxes passed to ``TRT_NMS``.
        iou_thres: IoU threshold passed to ``TRT_NMS``.
        score_thres: score threshold passed to ``TRT_NMS``.
        max_wh: must be None for this variant (asserted).
        device: stored on the module; falls back to CPU when falsy.
    '''

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
        super().__init__()
        assert max_wh is None
        self.device = device if device else torch.device('cpu')
        # Plain ints, matching the TRT_NMS defaults.  A trailing comma here
        # would silently turn these into one-element tuples.
        self.background_class = -1
        self.box_coding = 1
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres

    def forward(self, x):
        """Split ``x`` into boxes and confidence-weighted scores and run ``TRT_NMS``.

        Returns:
            (num_det, det_boxes, det_scores, det_classes) as produced by ``TRT_NMS``.
        """
        boxes = x[:, :, :4]
        conf = x[:, :, 4:5]
        scores = x[:, :, 5:]
        scores *= conf  # NOTE: in-place on a slice of x, so the caller's tensor is modified
        num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class,
                                                                    self.box_coding, self.iou_threshold,
                                                                    self.max_obj, self.plugin_version,
                                                                    self.score_activation, self.score_threshold)
        return num_det, det_boxes, det_scores, det_classes
214
+
215
+
216
class End2End(nn.Module):
    '''export onnx or tensorrt model with NMS operation.

    Wraps ``model`` and appends an NMS module: ``ONNX_TRT`` when ``max_wh`` is
    None, ``ONNX_ORT`` otherwise.
    '''

    def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
        super().__init__()
        if not device:
            device = torch.device('cpu')
        assert max_wh is None or isinstance(max_wh, (int))
        self.model = model.to(device)
        # Switch the last layer of the wrapped model into end-to-end output mode.
        self.model.model[-1].end2end = True
        if max_wh is None:
            self.patch_model = ONNX_TRT
        else:
            self.patch_model = ONNX_ORT
        self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
        self.end2end.eval()

    def forward(self, x):
        """Run the wrapped model, then the NMS module on its output."""
        return self.end2end(self.model(x))
232
+
233
+
234
+
235
+
236
+
237
def attempt_load(weights, map_location=None):
    """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a

    Each checkpoint's 'ema' entry is preferred over 'model' when it is truthy;
    the chosen model is converted to FP32, fused and put in eval mode.

    Returns:
        The single loaded model, or an ``Ensemble`` carrying the ``names`` and
        ``stride`` attributes of its last member when several were loaded.
    """
    model = Ensemble()
    weight_paths = weights if isinstance(weights, list) else [weights]
    for path in weight_paths:
        attempt_download(path)
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(path, map_location=map_location)  # load
        key = 'ema' if ckpt.get('ema') else 'model'
        model.append(ckpt[key].float().fuse().eval())  # FP32 model

    # Compatibility updates (exact type matches only, subclasses are left alone)
    inplace_activations = [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]
    for m in model.modules():
        kind = type(m)
        if kind in inplace_activations:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif kind is nn.Upsample:
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
        elif kind is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model

    print('Ensemble created with %s\n' % weights)
    for attr in ['names', 'stride']:
        setattr(model, attr, getattr(model[-1], attr))
    return model  # return ensemble
261
+
262
+
models/export.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ import time
4
+
5
+ sys.path.append('./') # to run '$ python *.py' files in subdirectories
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import models
11
+ from models.experimental import attempt_load
12
+ from utils.activations import Hardswish, SiLU
13
+ from utils.general import set_logging, check_img_size
14
+ from utils.torch_utils import select_device
15
+
16
if __name__ == '__main__':
    # Export a trained model to TorchScript, ONNX and CoreML.  Each export is
    # attempted independently; a failure is printed and the script continues.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='./yolor-csp-c.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
    parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    opt = parser.parse_args()
    if len(opt.img_size) == 1:
        opt.img_size *= 2  # expand a single value to (height, width)
    print(opt)
    set_logging()
    t = time.time()

    # Load PyTorch model
    device = select_device(opt.device)
    model = attempt_load(opt.weights, map_location=device)  # load FP32 model
    labels = model.names

    # Checks
    gs = int(max(model.stride))  # grid size (max stride)
    opt.img_size = [check_img_size(side, gs) for side in opt.img_size]  # verify img_size are gs-multiples

    # Input
    img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device)  # image size(1,3,320,192) iDetection

    # Update model
    for _, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if not isinstance(m, models.common.Conv):
            continue
        # assign export-friendly activations
        if isinstance(m.act, nn.Hardswish):
            m.act = Hardswish()
        elif isinstance(m.act, nn.SiLU):
            m.act = SiLU()
    model.model[-1].export = not opt.grid  # set Detect() layer grid export
    y = model(img)  # dry run

    # TorchScript export
    try:
        print('\nStarting TorchScript export with torch %s...' % torch.__version__)
        f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
        ts = torch.jit.trace(model, img, strict=False)
        ts.save(f)
        print('TorchScript export success, saved as %s' % f)
    except Exception as e:
        print('TorchScript export failure: %s' % e)

    # ONNX export
    try:
        import onnx

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opt.weights.replace('.pt', '.onnx')  # filename
        output_names = ['classes', 'boxes'] if y is None else ['output']
        if opt.dynamic:
            dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
                            'output': {0: 'batch', 2: 'y', 3: 'x'}}
        else:
            dynamic_axes = None
        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)

    # CoreML export
    try:
        import coremltools as ct

        print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
        # convert model from torchscript and apply pixel scaling as per detect.py
        # (relies on `ts` from the TorchScript step; if that step failed, the
        # resulting NameError is reported below as a CoreML failure)
        image_input = ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])
        model = ct.convert(ts, inputs=[image_input])
        f = opt.weights.replace('.pt', '.mlmodel')  # filename
        model.save(f)
        print('CoreML export success, saved as %s' % f)
    except Exception as e:
        print('CoreML export failure: %s' % e)

    # Finish
    print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
models/yolo.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import sys
4
+ from copy import deepcopy
5
+
6
+ sys.path.append('./') # to run '$ python *.py' files in subdirectories
7
+ logger = logging.getLogger(__name__)
8
+ import torch
9
+ from models.common import *
10
+ from models.experimental import *
11
+ from utils.autoanchor import check_anchor_order
12
+ from utils.general import make_divisible, check_file, set_logging
13
+ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
14
+ select_device, copy_attr
15
+ from utils.loss import SigmoidBin
16
+
17
+ try:
18
+ import thop # for FLOPS computation
19
+ except ImportError:
20
+ thop = None
21
+
22
+
23
class Detect(nn.Module):
    """Detection head: one 1x1 convolution per input feature map.

    In training (or export) mode ``forward`` returns the reshaped raw
    predictions.  Otherwise it also decodes boxes with the grid, the strides
    and the anchors; the class flags below pick the output layout.
    """
    stride = None  # strides computed during build
    export = False  # onnx export
    end2end = False  # return only the concatenated predictions
    include_nms = False  # return ((boxes, scores),) from convert()
    concat = False  # return only the concatenated predictions

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        anchor_pairs = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', anchor_pairs)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', anchor_pairs.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(c, self.no * self.na, 1) for c in ch)  # output conv

    def forward(self, x):
        """Apply the heads to the list ``x`` (modified in place) and decode when not training."""
        decoded = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            raw = self.m[i](x[i])  # conv
            bs, _, ny, nx = raw.shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = raw.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if self.training:
                continue

            # inference: rebuild the cached grid whenever the feature-map size changes
            if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
            y = x[i].sigmoid()
            if torch.onnx.is_in_onnx_export():
                # Same decode written without slice assignment for the exporter.
                xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5))  # new xy
                wh = wh ** 2 * (4 * self.anchor_grid[i].data)  # new wh
                y = torch.cat((xy, wh, conf), 4)
            else:
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            decoded.append(y.view(bs, -1, self.no))

        if self.training:
            return x
        # end2end takes precedence over include_nms, which takes precedence over concat.
        if self.include_nms and not self.end2end:
            return (self.convert(decoded), )
        merged = torch.cat(decoded, 1)
        if self.end2end or self.concat:
            return merged
        return (merged, x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a (1, 1, ny, nx, 2) float grid holding the (x, y) index of every cell."""
        rows, cols = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((cols, rows), 2).view((1, 1, ny, nx, 2)).float()

    def convert(self, z):
        """Concatenate ``z`` and return (corner boxes, class scores scaled by confidence)."""
        merged = torch.cat(z, 1)
        box, conf, score = merged[:, :, :4], merged[:, :, 4:5], merged[:, :, 5:]
        score *= conf
        # Right-multiplying (cx, cy, w, h) by this matrix yields (x1, y1, x2, y2).
        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                      dtype=torch.float32,
                                      device=merged.device)
        box @= convert_matrix
        return (box, score)
95
+
96
+
97
class IDetect(nn.Module):
    """Detection head with implicit layers wrapped around each output convolution.

    ``forward`` applies ``ia[i]`` before and ``im[i]`` after the convolution.
    ``fuse`` folds both into the convolution's weight and bias (treating
    ``ia`` as an additive input offset and ``im`` as a multiplicative output
    scale), after which ``fuseforward`` uses the convolution alone.
    """
    stride = None  # strides computed during build
    export = False  # onnx export
    end2end = False  # fuseforward: return only the concatenated predictions
    include_nms = False  # fuseforward: return ((boxes, scores),) from convert()
    concat = False  # fuseforward: return only the concatenated predictions

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(IDetect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
        self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)

    def forward(self, x):
        """Un-fused forward pass over the list ``x`` (modified in place).

        Returns ``x`` when training/exporting, otherwise
        ``(concatenated decoded predictions, x)``.
        """
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](self.ia[i](x[i]))  # conv
            x[i] = self.im[i](x[i])
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                # rebuild the cached grid whenever the feature-map size changes
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    def fuseforward(self, x):
        """Forward pass using only the convolutions (for use after ``fuse``).

        Unlike ``forward``, the output layout honours the ``end2end``,
        ``include_nms`` and ``concat`` flags.
        """
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    # Same decode written without slice assignment for the exporter.
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5))  # new xy
                    wh = wh ** 2 * (4 * self.anchor_grid[i].data)  # new wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))

        if self.training:
            out = x
        elif self.end2end:
            out = torch.cat(z, 1)
        elif self.include_nms:
            z = self.convert(z)
            out = (z, )
        elif self.concat:
            out = torch.cat(z, 1)
        else:
            out = (torch.cat(z, 1), x)

        return out

    def fuse(self):
        """Fold the implicit layers into the output convolutions, in place.

        Runs under ``torch.no_grad()`` because in-place updates of parameters
        that require grad are otherwise rejected by autograd.
        """
        print("IDetect.fuse")
        with torch.no_grad():
            # fuse ImplicitA and Convolution: conv(x + a) == conv(x) + W @ a
            for i in range(len(self.m)):
                c1, c2, _, _ = self.m[i].weight.shape
                c1_, c2_, _, _ = self.ia[i].implicit.shape
                self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1, c2),
                                               self.ia[i].implicit.reshape(c2_, c1_)).squeeze(1)

            # fuse ImplicitM and Convolution: scale bias and weight per output channel
            for i in range(len(self.m)):
                c1, c2, _, _ = self.im[i].implicit.shape
                self.m[i].bias *= self.im[i].implicit.reshape(c2)
                self.m[i].weight *= self.im[i].implicit.transpose(0, 1)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a (1, 1, ny, nx, 2) float grid holding the (x, y) index of every cell."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

    def convert(self, z):
        """Concatenate ``z`` and return (corner boxes, class scores scaled by confidence)."""
        z = torch.cat(z, 1)
        box = z[:, :, :4]
        conf = z[:, :, 4:5]
        score = z[:, :, 5:]
        score *= conf
        # Right-multiplying (cx, cy, w, h) by this matrix yields (x1, y1, x2, y2).
        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                      dtype=torch.float32,
                                      device=z.device)
        box @= convert_matrix
        return (box, score)
208
+
209
+
210
class IKeypoint(nn.Module):
    """Detection head that predicts boxes/classes plus ``nkpt`` keypoints per anchor.

    Each anchor outputs ``no_det = nc + 5`` detection values followed by
    ``no_kpt = 3 * nkpt`` keypoint values laid out as (x, y, conf) triples.
    """
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), nkpt=17, ch=(), inplace=True, dw_conv_kpt=False):  # detection layer
        super(IKeypoint, self).__init__()
        self.nc = nc  # number of classes
        self.nkpt = nkpt  # number of keypoints
        self.dw_conv_kpt = dw_conv_kpt
        self.no_det = (nc + 5)  # number of outputs per anchor for box and class
        self.no_kpt = 3 * self.nkpt  # number of outputs per anchor for keypoints
        self.no = self.no_det + self.no_kpt
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.flip_test = False
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no_det * self.na, 1) for x in ch)  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
        self.im = nn.ModuleList(ImplicitM(self.no_det * self.na) for _ in ch)

        if self.nkpt is not None:
            if self.dw_conv_kpt:  # keypoint head is slightly more complex
                self.m_kpt = nn.ModuleList(
                    nn.Sequential(DWConv(x, x, k=3), Conv(x, x),
                                  DWConv(x, x, k=3), Conv(x, x),
                                  DWConv(x, x, k=3), Conv(x, x),
                                  DWConv(x, x, k=3), Conv(x, x),
                                  DWConv(x, x, k=3), Conv(x, x),
                                  DWConv(x, x, k=3), nn.Conv2d(x, self.no_kpt * self.na, 1)) for x in ch)
            else:  # keypoint head is a single convolution
                self.m_kpt = nn.ModuleList(nn.Conv2d(x, self.no_kpt * self.na, 1) for x in ch)

        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        """Apply the heads to the list ``x`` (modified in place) and decode when not training.

        Returns ``x`` when training/exporting, otherwise
        ``(concatenated decoded predictions, x)``.
        """
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            feat = x[i]
            det_out = self.im[i](self.m[i](self.ia[i](feat)))  # conv
            if self.nkpt is None or self.nkpt == 0:
                x[i] = det_out
            else:
                # the keypoint branch reads the same input feature map as the detection branch
                x[i] = torch.cat((det_out, self.m_kpt[i](feat)), axis=1)

            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            # Split at the detection width rather than a literal 6, so nc != 1 is sliced correctly.
            x_det = x[i][..., :self.no_det]
            x_kpt = x[i][..., self.no_det:]

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                kpt_grid_x = self.grid[i][..., 0:1]
                kpt_grid_y = self.grid[i][..., 1:2]

                if self.nkpt == 0:
                    y = x[i].sigmoid()
                else:
                    y = x_det.sigmoid()

                if self.inplace:
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                    if self.nkpt != 0:
                        # x_kpt is a view of x[i], so these assignments also update x[i].
                        # The grid is repeated once per keypoint (previously a literal 17).
                        x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5
                                            + kpt_grid_x.repeat(1, 1, 1, 1, self.nkpt)) * self.stride[i]  # x
                        x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5
                                            + kpt_grid_y.repeat(1, 1, 1, 1, self.nkpt)) * self.stride[i]  # y
                        x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid()  # keypoint confidence

                    y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim=-1)

                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    if self.nkpt != 0:
                        # NOTE(review): y holds only the no_det detection channels here, so this
                        # slice is empty and the broadcast looks unusable for nkpt != 0 -- confirm.
                        y[..., self.no_det:] = (y[..., self.no_det:] * 2. - 0.5
                                                + self.grid[i].repeat((1, 1, 1, 1, self.nkpt))) * self.stride[i]  # xy
                    y = torch.cat((xy, wh, y[..., 4:]), -1)

                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a (1, 1, ny, nx, 2) float grid holding the (x, y) index of every cell."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
309
+
310
+
311
+ class IAuxDetect(nn.Module):
312
+ stride = None # strides computed during build
313
+ export = False # onnx export
314
+ end2end = False
315
+ include_nms = False
316
+ concat = False
317
+
318
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
319
+ super(IAuxDetect, self).__init__()
320
+ self.nc = nc # number of classes
321
+ self.no = nc + 5 # number of outputs per anchor
322
+ self.nl = len(anchors) # number of detection layers
323
+ self.na = len(anchors[0]) // 2 # number of anchors
324
+ self.grid = [torch.zeros(1)] * self.nl # init grid
325
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
326
+ self.register_buffer('anchors', a) # shape(nl,na,2)
327
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
328
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[:self.nl]) # output conv
329
+ self.m2 = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[self.nl:]) # output conv
330
+
331
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
332
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch[:self.nl])
333
+
334
def forward(self, x):
    """Run the lead and auxiliary detection heads over the input feature maps.

    ``x`` is indexed as a mutable sequence of ``2 * self.nl`` feature maps:
    entries ``[0, nl)`` go through the lead heads (``ia`` -> ``m`` -> ``im``)
    and entries ``[nl, 2 * nl)`` go through the auxiliary convs ``m2``.
    Every entry is overwritten in place with its reshaped head output.

    Returns ``x`` when ``self.training`` is set (``self.export`` forces it on),
    otherwise ``(predictions, x[:self.nl])`` where ``predictions`` is the
    decoded lead-head output of all layers concatenated along dim 1.
    """
    # x = x.copy()  # for profiling
    decoded = []  # inference output
    self.training |= self.export
    for idx in range(self.nl):
        aux_idx = idx + self.nl

        # lead head: implicit add -> 1x1 conv -> implicit multiply
        lead = self.im[idx](self.m[idx](self.ia[idx](x[idx])))
        bs, _, ny, nx = lead.shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x[idx] = lead.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        # auxiliary head reuses the lead head's batch and spatial sizes
        aux = self.m2[idx](x[aux_idx])
        x[aux_idx] = aux.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        if self.training:
            continue

        # inference: (re)build the cached cell grid when the map size changed
        if self.grid[idx].shape[2:4] != x[idx].shape[2:4]:
            self.grid[idx] = self._make_grid(nx, ny).to(x[idx].device)

        y = x[idx].sigmoid()
        if torch.onnx.is_in_onnx_export():
            # export path avoids in-place slice assignment
            xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
            xy = xy * (2. * self.stride[idx]) + (self.stride[idx] * (self.grid[idx] - 0.5))  # new xy
            wh = wh ** 2 * (4 * self.anchor_grid[idx].data)  # new wh
            y = torch.cat((xy, wh, conf), 4)
        else:
            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[idx]) * self.stride[idx]  # xy
            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[idx]  # wh
        decoded.append(y.view(bs, -1, self.no))

    if self.training:
        return x
    return torch.cat(decoded, 1), x[:self.nl]
363
+
364
def fuseforward(self, x):
    """Forward pass used once the implicit layers are folded into ``self.m``.

    Only the first ``self.nl`` entries of ``x`` are processed (and overwritten
    in place); each goes through its fused conv ``self.m[i]`` alone.

    Returns, in priority order:
      * ``x`` when ``self.training`` is set (``self.export`` forces it on);
      * the concatenated predictions when ``self.end2end`` is set;
      * ``(self.convert(z),)`` when ``self.include_nms`` is set;
      * the concatenated predictions when ``self.concat`` is set;
      * otherwise ``(concatenated predictions, x)``.
    """
    # x = x.copy()  # for profiling
    decoded = []  # inference output
    self.training |= self.export
    for idx in range(self.nl):
        raw = self.m[idx](x[idx])  # conv
        bs, _, ny, nx = raw.shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x[idx] = raw.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        if self.training:
            continue

        # inference: (re)build the cached cell grid when the map size changed
        if self.grid[idx].shape[2:4] != x[idx].shape[2:4]:
            self.grid[idx] = self._make_grid(nx, ny).to(x[idx].device)

        y = x[idx].sigmoid()
        if torch.onnx.is_in_onnx_export():
            # export path builds new tensors instead of assigning into slices
            xy = (y[..., 0:2] * 2. - 0.5 + self.grid[idx]) * self.stride[idx]  # xy
            wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[idx].data  # wh
            y = torch.cat((xy, wh, y[..., 4:]), -1)
        else:
            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[idx]) * self.stride[idx]  # xy
            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[idx]  # wh
        decoded.append(y.view(bs, -1, self.no))

    if self.training:
        return x
    if self.end2end:
        return torch.cat(decoded, 1)
    if self.include_nms:
        return (self.convert(decoded), )
    if self.concat:
        return torch.cat(decoded, 1)
    return (torch.cat(decoded, 1), x)
400
+
401
def fuse(self):
    """Fold the implicit add/multiply layers into the lead 1x1 convs, in place.

    After this call ``self.m[i](x)`` equals ``im[i](m[i](ia[i](x)))`` of the
    un-fused head, given that ``ia[i]`` adds its ``implicit`` tensor to the
    input and ``im[i]`` multiplies the output by its ``implicit`` tensor
    (assumed from their use here; both are defined outside this block).

    The edits modify the conv parameters in place, which autograd forbids on
    leaf tensors that require grad, so they run under ``torch.no_grad()``.
    Calling this twice applies the fusion twice.
    """
    print("IAuxDetect.fuse")
    with torch.no_grad():
        # fuse ImplicitA and Convolution: W(x + a) + b == Wx + (b + W a)
        for conv, implicit_add in zip(self.m, self.ia):
            c_out, c_in, _, _ = conv.weight.shape
            a_rows, a_cols, _, _ = implicit_add.implicit.shape
            shift = torch.matmul(conv.weight.reshape(c_out, c_in),
                                 implicit_add.implicit.reshape(a_cols, a_rows))
            conv.bias += shift.squeeze(1)

        # fuse ImplicitM and Convolution: s * (Wx + b) == (s W)x + (s b)
        for conv, implicit_mul in zip(self.m, self.im):
            _, channels, _, _ = implicit_mul.implicit.shape
            conv.bias *= implicit_mul.implicit.reshape(channels)
            conv.weight *= implicit_mul.implicit.transpose(0, 1)
414
+
415
@staticmethod
def _make_grid(nx=20, ny=20):
    """Return a float tensor of shape (1, 1, ny, nx, 2) holding each cell's (x, y) index."""
    rows = torch.arange(ny)
    cols = torch.arange(nx)
    grid_y, grid_x = torch.meshgrid([rows, cols])
    cells = torch.stack((grid_x, grid_y), 2)  # last dim is (x, y)
    return cells.view((1, 1, ny, nx, 2)).float()
419
+
420
def convert(self, z):
    """Turn per-layer predictions into ``(boxes, scores)``.

    ``z`` is a sequence of tensors concatenated along dim 1. In the last
    dim, columns 0-3 are treated as a centre/size box (cx, cy, w, h),
    column 4 as a confidence and the remaining columns as per-class scores.

    Returns ``(box, score)``: boxes as corner coordinates
    (cx - w/2, cy - h/2, cx + w/2, cy + h/2) and class scores multiplied
    by the confidence.
    """
    merged = torch.cat(z, 1)
    conf = merged[:, :, 4:5]
    score = merged[:, :, 5:] * conf
    # column k of this matrix is the linear combination producing corner k
    to_corners = torch.tensor(
        [[1, 0, 1, 0],
         [0, 1, 0, 1],
         [-0.5, 0, 0.5, 0],
         [0, -0.5, 0, 0.5]],
        dtype=torch.float32,
        device=merged.device)
    box = merged[:, :, :4] @ to_corners
    return (box, score)
431
+
432
+
433
class IBin(nn.Module):
    """Detection head that predicts box width/height through ``SigmoidBin`` modules.

    Per anchor, the output channels are laid out as
    ``x, y | width block | height block | objectness, classes``, where each
    block is ``get_length()`` wide; this is the layout ``self.no`` is sized for.
    """
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=(), bin_count=21):  # detection layer
        """Build one 1x1 output conv plus implicit add/multiply layers per input map.

        Args:
            nc: number of classes.
            anchors: one flat sequence of anchor (w, h) values per detection layer.
            ch: input channel count of each detection layer.
            bin_count: number of bins handed to both ``SigmoidBin`` modules.
        """
        super(IBin, self).__init__()
        self.nc = nc  # number of classes
        self.bin_count = bin_count

        self.w_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
        self.h_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
        # classes, x,y,obj
        self.no = nc + 3 + \
            self.w_bin_sigmoid.get_length() + self.h_bin_sigmoid.get_length()  # w-bce, h-bce

        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
        self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)

    def forward(self, x):
        """Apply the heads to the feature maps in ``x`` (overwritten in place).

        Returns ``x`` when ``self.training`` is set (``self.export`` forces it
        on), otherwise ``(predictions, x)``. Each prediction row holds
        ``x, y, w, h`` followed by the objectness and class channels.
        """
        self.w_bin_sigmoid.use_fw_regression = True
        self.h_bin_sigmoid.use_fw_regression = True

        # Channel boundaries derived from the bin modules instead of the
        # literals 2:24 / 24:46 / 46:, which were only valid for bin_count=21.
        w_start = 2
        h_start = w_start + self.w_bin_sigmoid.get_length()
        obj_start = h_start + self.h_bin_sigmoid.get_length()

        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](self.ia[i](x[i]))  # conv
            x[i] = self.im[i](x[i])
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy

                # decode both sizes before overwriting channels 2 and 3,
                # because the width block itself starts at channel 2
                pw = self.w_bin_sigmoid.forward(y[..., w_start:h_start]) * self.anchor_grid[i][..., 0]
                ph = self.h_bin_sigmoid.forward(y[..., h_start:obj_start]) * self.anchor_grid[i][..., 1]

                y[..., 2] = pw
                y[..., 3] = ph

                # drop the raw bin channels, keep box + objectness + classes
                y = torch.cat((y[..., 0:4], y[..., obj_start:]), dim=-1)

                z.append(y.view(bs, -1, y.shape[-1]))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a float tensor of shape (1, 1, ny, nx, 2) holding each cell's (x, y) index."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
506
+
507
+
508
class Model(nn.Module):
    """Detection network assembled from a YAML file or dict by ``parse_model``.

    The last module of ``self.model`` is treated as the detection head; its
    strides are measured with a dummy forward pass and its anchors and
    output-conv biases are initialised accordingly.
    """

    def __init__(self, cfg='yolor-csp-c.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        """Parse the configuration, build the layers and initialise the head.

        Args:
            cfg: model description, either a dict or a path to a YAML file.
            ch: input channel count, used when the config has no ``ch`` entry.
            nc: optional class count overriding the config's ``nc``.
            anchors: optional value overriding the config's ``anchors``
                (passed through ``round``).
        """
        super(Model, self).__init__()
        self.traced = False
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.SafeLoader)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors. Independent checks (not elif) so a head that
        # is an instance of several of these classes is handled once per match.
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            self._build_strides(m, ch, self._initialize_biases)
        if isinstance(m, IDetect):
            self._build_strides(m, ch, self._initialize_biases)
        if isinstance(m, IAuxDetect):
            # the training-mode output also carries the auxiliary maps; only
            # the first m.nl entries belong to the lead heads (was a literal 4)
            self._build_strides(m, ch, self._initialize_aux_biases, n_outputs=m.nl)
        if isinstance(m, IBin):
            self._build_strides(m, ch, self._initialize_biases_bin)
        if isinstance(m, IKeypoint):
            self._build_strides(m, ch, self._initialize_biases_kpt)

        # Init weights, biases
        initialize_weights(self)
        self.info()
        logger.info('')

    def _build_strides(self, m, ch, init_biases, n_outputs=None):
        """Measure head strides with a dummy pass, rescale anchors, then init biases.

        Args:
            m: the detection head (last module of ``self.model``).
            ch: input channel count for the dummy image.
            init_biases: zero-argument callable run once after strides are set.
            n_outputs: when given, only the first ``n_outputs`` forward
                outputs are used to compute strides.
        """
        s = 256  # 2x min stride
        outputs = self.forward(torch.zeros(1, ch, s, s))  # forward
        if n_outputs is not None:
            outputs = outputs[:n_outputs]
        m.stride = torch.tensor([s / x.shape[-2] for x in outputs])
        check_anchor_order(m)
        m.anchors /= m.stride.view(-1, 1, 1)
        self.stride = m.stride
        init_biases()  # only run once
        # print('Strides: %s' % m.stride.tolist())

    def forward(self, x, augment=False, profile=False):
        """Run the model; with ``augment`` merge predictions from scaled/flipped copies.

        Returns the ``forward_once`` output, or ``(predictions, None)`` when
        ``augment`` is set.
        """
        if augment:
            img_size = x.shape[-2:]  # height, width
            s = [1, 0.83, 0.67]  # scales
            f = [None, 3, None]  # flips (2-ud, 3-lr)
            y = []  # outputs
            for si, fi in zip(s, f):
                xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
                yi = self.forward_once(xi)[0]  # forward
                # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
                yi[..., :4] /= si  # de-scale
                if fi == 2:
                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
                elif fi == 3:
                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
                y.append(yi)
            return torch.cat(y, 1), None  # augmented inference, train
        else:
            return self.forward_once(x, profile)  # single-scale inference, train

    def forward_once(self, x, profile=False):
        """Run every layer in order, routing saved outputs by each layer's ``f`` index.

        When ``self.traced`` is set the loop stops before the detection head and
        the features feeding it are returned. With ``profile`` each layer is
        timed and a summary line is printed per layer.
        """
        y, dt = [], []  # outputs
        if not hasattr(self, 'traced'):  # attribute may be missing on older pickled models — TODO confirm
            self.traced = False

        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            if self.traced and isinstance(m, (Detect, IDetect, IAuxDetect, IKeypoint)):
                break

            if profile:
                c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
                # heads overwrite their input list in place, so profile them on a copy
                o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPS
                for _ in range(10):
                    m(x.copy() if c else x)
                t = time_synchronized()
                for _ in range(10):
                    m(x.copy() if c else x)
                dt.append((time_synchronized() - t) * 100)
                print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))

            x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output

        if profile:
            print('%.1fms total' % sum(dt))
        return x

    @staticmethod
    def _init_obj_cls_bias(conv, na, nc, s, cf=None):
        """Shift the objectness (col 4) and class (cols 5+) biases of one output conv.

        https://arxiv.org/abs/1708.02002 section 3.3
        """
        b = conv.bias.view(na, -1)  # conv.bias(255) to (3,85)
        b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
        b.data[:, 5:] += math.log(0.6 / (nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
        conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        """Initialise the head's output-conv biases; ``cf`` is an optional class-frequency tensor."""
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            self._init_obj_cls_bias(mi, m.na, m.nc, s, cf)

    def _initialize_aux_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        """Same as ``_initialize_biases`` but also covers the auxiliary convs ``m.m2``."""
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, mi2, s in zip(m.m, m.m2, m.stride):  # from
            self._init_obj_cls_bias(mi, m.na, m.nc, s, cf)
            self._init_obj_cls_bias(mi2, m.na, m.nc, s, cf)

    def _initialize_biases_bin(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        """Initialise biases of a binned head.

        Columns ``0, 1, 2`` and ``bin_count + 3`` keep their original values;
        the other columns before the objectness column get the bin prior, the
        objectness column the object prior and the rest the class prior.
        """
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Bin() module
        bc = m.bin_count
        keep = [0, 1, 2, bc + 3]  # columns whose bias must stay untouched
        obj_idx = 2 * bc + 4
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            old = b.data[:, keep].clone()
            b.data[:, :obj_idx] += math.log(0.6 / (bc + 1 - 0.99))
            b.data[:, obj_idx] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, (obj_idx + 1):] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            # Write through b.data: assigning `.data` on `b[:, keep]` only
            # rebinds a temporary copy (advanced indexing) and restores nothing.
            b.data[:, keep] = old
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _initialize_biases_kpt(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        """Initialise the keypoint head's output-conv biases (same rule as ``_initialize_biases``)."""
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            self._init_obj_cls_bias(mi, m.na, m.nc, s, cf)

    def _print_biases(self):
        """Print per-conv mean biases: the first five columns individually, then the rest pooled."""
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        """Fold re-parameterisable blocks, Conv+BatchNorm pairs and implicit heads for inference; returns ``self``."""
        print('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()
            elif isinstance(m, RepConv_OREPA):
                m.switch_to_deploy()
            elif type(m) is Conv and hasattr(m, 'bn'):  # exact type: subclasses of Conv are left alone
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
            elif isinstance(m, (IDetect, IAuxDetect)):
                m.fuse()
                m.forward = m.fuseforward
        self.info()
        return self

    def nms(self, mode=True):  # add or remove NMS module
        """Append (``mode=True``) or strip (``mode=False``) a trailing ``NMS`` module; returns ``self``."""
        present = type(self.model[-1]) is NMS  # last layer is NMS
        if mode and not present:
            print('Adding NMS... ')
            m = NMS()  # module
            m.f = -1  # from
            m.i = self.model[-1].i + 1  # index
            self.model.add_module(name='%s' % m.i, module=m)  # add
            self.eval()
        elif not mode and present:
            print('Removing NMS... ')
            self.model = self.model[:-1]  # remove
        return self

    def autoshape(self):  # add autoShape module
        """Wrap the model in ``autoShape`` and copy selected attributes onto the wrapper."""
        print('Adding autoShape... ')
        m = autoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m

    def info(self, verbose=False, img_size=640):  # print model information
        """Delegate to ``model_info`` to report on the model."""
        model_info(self, verbose, img_size)
734
+
735
+
736
def parse_model(d, ch):  # model_dict, input_channels(3)
    """Instantiate the layers listed in a model dict.

    Args:
        d: model dict with ``anchors``, ``nc``, ``depth_multiple``,
            ``width_multiple``, ``backbone`` and ``head`` entries; each layer
            row is ``[from, number, module, args]``.
        ch: list whose last item is the number of input channels.

    Returns:
        ``(nn.Sequential of the layers, sorted list of layer indices whose
        outputs later layers read)``.

    NOTE(security): module names and string arguments are passed to ``eval``,
    so the dict must come from a trusted source.
    """
    logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except Exception:
                # best effort: a string that is not an expression (e.g. a mode
                # name) stays a plain string; Ctrl-C is no longer swallowed
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
                 SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
                 Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
                 RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
                 Res, ResCSPA, ResCSPB, ResCSPC,
                 RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
                 ResX, ResXCSPA, ResXCSPB, ResXCSPC,
                 RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
                 Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
                 SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
                 SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [DownC, SPPCSPC, GhostSPPCSPC,
                     BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
                     RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
                     ResCSPA, ResCSPB, ResCSPC,
                     RepResCSPA, RepResCSPB, RepResCSPC,
                     ResXCSPA, ResXCSPB, ResXCSPC,
                     RepResXCSPA, RepResXCSPB, RepResXCSPC,
                     GhostCSPA, GhostCSPB, GhostCSPC,
                     STCSPA, STCSPB, STCSPC,
                     ST2CSPA, ST2CSPB, ST2CSPC]:
                # these blocks repeat internally, so the count becomes an argument
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat or m is Chuncat:
            c2 = sum([ch[x] for x in f])
        elif m is Shortcut:
            c2 = ch[f[0]]
        elif m is Foldcut:
            c2 = ch[f] // 2
        elif m in [Detect, IDetect, IAuxDetect, IBin, IKeypoint]:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is ReOrg:
            c2 = ch[f] * 4
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        n_params = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, n_params  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, n_params, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # drop the input-image entry so ch[k] is layer k's output width
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
814
+
815
+
816
if __name__ == '__main__':
    # Script entry point: build the model named by --cfg on the selected
    # device and, with --profile, time one forward pass layer by layer.
    arg_parser = argparse.ArgumentParser()
    for flag, spec in (
            ('--cfg', dict(type=str, default='yolor-csp-c.yaml', help='model.yaml')),
            ('--device', dict(default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')),
            ('--profile', dict(action='store_true', help='profile model speed')),
    ):
        arg_parser.add_argument(flag, **spec)
    opt = arg_parser.parse_args()
    opt.cfg = check_file(opt.cfg)  # check file
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg)
    model = model.to(device)
    model.train()

    if opt.profile:
        img = torch.rand(1, 3, 640, 640)
        img = img.to(device)
        y = model(img, profile=True)

    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
    # y = model(img, profile=True)

    # Tensorboard
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter()
    # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(model.model, img)  # add model to tensorboard
    # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard