mamechin committed
Commit 48e83a1
1 Parent(s): 7df6ce7
Files changed (4)
  1. models/__init__.py +1 -0
  2. models/common.py +2019 -0
  3. models/experimental.py +272 -0
  4. models/yolo.py +843 -0
models/__init__.py ADDED
@@ -0,0 +1 @@
+ # init
models/common.py ADDED
@@ -0,0 +1,2019 @@
1
+ import math
2
+ from copy import copy
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import requests
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torchvision.ops import DeformConv2d
12
+ from PIL import Image
13
+ from torch.cuda import amp
14
+
15
+ from utils.datasets import letterbox
16
+ from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
17
+ from utils.plots import color_list, plot_one_box
18
+ from utils.torch_utils import time_synchronized
19
+
20
+
21
+ ##### basic #####
22
+
23
+ def autopad(k, p=None): # kernel, padding
24
+ # Pad to 'same'
25
+ if p is None:
26
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
27
+ return p
28
+
29
+
30
+ class MP(nn.Module):
31
+ def __init__(self, k=2):
32
+ super(MP, self).__init__()
33
+ self.m = nn.MaxPool2d(kernel_size=k, stride=k)
34
+
35
+ def forward(self, x):
36
+ return self.m(x)
37
+
38
+
39
+ class SP(nn.Module):
40
+ def __init__(self, k=3, s=1):
41
+ super(SP, self).__init__()
42
+ self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
43
+
44
+ def forward(self, x):
45
+ return self.m(x)
46
+
47
+
48
+ class ReOrg(nn.Module):
49
+ def __init__(self):
50
+ super(ReOrg, self).__init__()
51
+
52
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
53
+ return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
54
+
55
+
56
+ class Concat(nn.Module):
57
+ def __init__(self, dimension=1):
58
+ super(Concat, self).__init__()
59
+ self.d = dimension
60
+
61
+ def forward(self, x):
62
+ return torch.cat(x, self.d)
63
+
64
+
65
+ class Chuncat(nn.Module):
66
+ def __init__(self, dimension=1):
67
+ super(Chuncat, self).__init__()
68
+ self.d = dimension
69
+
70
+ def forward(self, x):
71
+ x1 = []
72
+ x2 = []
73
+ for xi in x:
74
+ xi1, xi2 = xi.chunk(2, self.d)
75
+ x1.append(xi1)
76
+ x2.append(xi2)
77
+ return torch.cat(x1+x2, self.d)
78
+
79
+
80
+ class Shortcut(nn.Module):
81
+ def __init__(self, dimension=0):
82
+ super(Shortcut, self).__init__()
83
+ self.d = dimension
84
+
85
+ def forward(self, x):
86
+ return x[0]+x[1]
87
+
88
+
89
+ class Foldcut(nn.Module):
90
+ def __init__(self, dimension=0):
91
+ super(Foldcut, self).__init__()
92
+ self.d = dimension
93
+
94
+ def forward(self, x):
95
+ x1, x2 = x.chunk(2, self.d)
96
+ return x1+x2
97
+
98
+
99
+ class Conv(nn.Module):
100
+ # Standard convolution
101
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
102
+ super(Conv, self).__init__()
103
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
104
+ self.bn = nn.BatchNorm2d(c2)
105
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
106
+
107
+ def forward(self, x):
108
+ return self.act(self.bn(self.conv(x)))
109
+
110
+ def fuseforward(self, x):
111
+ return self.act(self.conv(x))
112
+
113
+
114
+ class RobustConv(nn.Module):
115
+ # Robust convolution (use large kernel sizes, 7-11, for downsampling and other layers). Train for 300-450 epochs.
116
+ def __init__(self, c1, c2, k=7, s=1, p=None, g=1, act=True, layer_scale_init_value=1e-6): # ch_in, ch_out, kernel, stride, padding, groups
117
+ super(RobustConv, self).__init__()
118
+ self.conv_dw = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
119
+ self.conv1x1 = nn.Conv2d(c1, c2, 1, 1, 0, groups=1, bias=True)
120
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None
121
+
122
+ def forward(self, x):
123
+ x = x.to(memory_format=torch.channels_last)
124
+ x = self.conv1x1(self.conv_dw(x))
125
+ if self.gamma is not None:
126
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
127
+ return x
128
+
129
+
130
+ class RobustConv2(nn.Module):
131
+ # Robust convolution 2 (use [32, 5, 2] or [32, 7, 4] or [32, 11, 8] for one of the paths in CSP).
132
+ def __init__(self, c1, c2, k=7, s=4, p=None, g=1, act=True, layer_scale_init_value=1e-6): # ch_in, ch_out, kernel, stride, padding, groups
133
+ super(RobustConv2, self).__init__()
134
+ self.conv_strided = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
135
+ self.conv_deconv = nn.ConvTranspose2d(in_channels=c1, out_channels=c2, kernel_size=s, stride=s,
136
+ padding=0, bias=True, dilation=1, groups=1
137
+ )
138
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None
139
+
140
+ def forward(self, x):
141
+ x = self.conv_deconv(self.conv_strided(x))
142
+ if self.gamma is not None:
143
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
144
+ return x
145
+
146
+
147
+ def DWConv(c1, c2, k=1, s=1, act=True):
148
+ # Depthwise convolution
149
+ return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
150
+
151
+
152
+ class GhostConv(nn.Module):
153
+ # Ghost Convolution https://github.com/huawei-noah/ghostnet
154
+ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
155
+ super(GhostConv, self).__init__()
156
+ c_ = c2 // 2 # hidden channels
157
+ self.cv1 = Conv(c1, c_, k, s, None, g, act)
158
+ self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
159
+
160
+ def forward(self, x):
161
+ y = self.cv1(x)
162
+ return torch.cat([y, self.cv2(y)], 1)
163
+
164
+
165
+ class Stem(nn.Module):
166
+ # Stem
167
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
168
+ super(Stem, self).__init__()
169
+ c_ = int(c2/2) # hidden channels
170
+ self.cv1 = Conv(c1, c_, 3, 2)
171
+ self.cv2 = Conv(c_, c_, 1, 1)
172
+ self.cv3 = Conv(c_, c_, 3, 2)
173
+ self.pool = torch.nn.MaxPool2d(2, stride=2)
174
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
175
+
176
+ def forward(self, x):
177
+ x = self.cv1(x)
178
+ return self.cv4(torch.cat((self.cv3(self.cv2(x)), self.pool(x)), dim=1))
179
+
180
+
181
+ class DownC(nn.Module):
182
+ # Downsampling block combining a strided convolution path and a max-pooling path
183
+ def __init__(self, c1, c2, n=1, k=2):
184
+ super(DownC, self).__init__()
185
+ c_ = int(c1) # hidden channels
186
+ self.cv1 = Conv(c1, c_, 1, 1)
187
+ self.cv2 = Conv(c_, c2//2, 3, k)
188
+ self.cv3 = Conv(c1, c2//2, 1, 1)
189
+ self.mp = nn.MaxPool2d(kernel_size=k, stride=k)
190
+
191
+ def forward(self, x):
192
+ return torch.cat((self.cv2(self.cv1(x)), self.cv3(self.mp(x))), dim=1)
193
+
194
+
195
+ class SPP(nn.Module):
196
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
197
+ def __init__(self, c1, c2, k=(5, 9, 13)):
198
+ super(SPP, self).__init__()
199
+ c_ = c1 // 2 # hidden channels
200
+ self.cv1 = Conv(c1, c_, 1, 1)
201
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
202
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
203
+
204
+ def forward(self, x):
205
+ x = self.cv1(x)
206
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
207
+
208
+
209
+ class Bottleneck(nn.Module):
210
+ # Darknet bottleneck
211
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
212
+ super(Bottleneck, self).__init__()
213
+ c_ = int(c2 * e) # hidden channels
214
+ self.cv1 = Conv(c1, c_, 1, 1)
215
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
216
+ self.add = shortcut and c1 == c2
217
+
218
+ def forward(self, x):
219
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
220
+
221
+
222
+ class Res(nn.Module):
223
+ # ResNet bottleneck
224
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
225
+ super(Res, self).__init__()
226
+ c_ = int(c2 * e) # hidden channels
227
+ self.cv1 = Conv(c1, c_, 1, 1)
228
+ self.cv2 = Conv(c_, c_, 3, 1, g=g)
229
+ self.cv3 = Conv(c_, c2, 1, 1)
230
+ self.add = shortcut and c1 == c2
231
+
232
+ def forward(self, x):
233
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
234
+
235
+
236
+ class ResX(Res):
237
+ # ResNeXt bottleneck (grouped convolution, g=32)
238
+ def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
239
+ super().__init__(c1, c2, shortcut, g, e)
240
+ c_ = int(c2 * e) # hidden channels
241
+
242
+
243
+ class Ghost(nn.Module):
244
+ # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
245
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
246
+ super(Ghost, self).__init__()
247
+ c_ = c2 // 2
248
+ self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
249
+ DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
250
+ GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
251
+ self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
252
+ Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
253
+
254
+ def forward(self, x):
255
+ return self.conv(x) + self.shortcut(x)
256
+
257
+ ##### end of basic #####
258
+
259
+
260
+ ##### cspnet #####
261
+
262
+ class SPPCSPC(nn.Module):
263
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
264
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
265
+ super(SPPCSPC, self).__init__()
266
+ c_ = int(2 * c2 * e) # hidden channels
267
+ self.cv1 = Conv(c1, c_, 1, 1)
268
+ self.cv2 = Conv(c1, c_, 1, 1)
269
+ self.cv3 = Conv(c_, c_, 3, 1)
270
+ self.cv4 = Conv(c_, c_, 1, 1)
271
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
272
+ self.cv5 = Conv(4 * c_, c_, 1, 1)
273
+ self.cv6 = Conv(c_, c_, 3, 1)
274
+ self.cv7 = Conv(2 * c_, c2, 1, 1)
275
+
276
+ def forward(self, x):
277
+ x1 = self.cv4(self.cv3(self.cv1(x)))
278
+ y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
279
+ y2 = self.cv2(x)
280
+ return self.cv7(torch.cat((y1, y2), dim=1))
281
+
282
+ class GhostSPPCSPC(SPPCSPC):
283
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
284
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
285
+ super().__init__(c1, c2, n, shortcut, g, e, k)
286
+ c_ = int(2 * c2 * e) # hidden channels
287
+ self.cv1 = GhostConv(c1, c_, 1, 1)
288
+ self.cv2 = GhostConv(c1, c_, 1, 1)
289
+ self.cv3 = GhostConv(c_, c_, 3, 1)
290
+ self.cv4 = GhostConv(c_, c_, 1, 1)
291
+ self.cv5 = GhostConv(4 * c_, c_, 1, 1)
292
+ self.cv6 = GhostConv(c_, c_, 3, 1)
293
+ self.cv7 = GhostConv(2 * c_, c2, 1, 1)
294
+
295
+
296
+ class GhostStem(Stem):
297
+ # Stem
298
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
299
+ super().__init__(c1, c2, k, s, p, g, act)
300
+ c_ = int(c2/2) # hidden channels
301
+ self.cv1 = GhostConv(c1, c_, 3, 2)
302
+ self.cv2 = GhostConv(c_, c_, 1, 1)
303
+ self.cv3 = GhostConv(c_, c_, 3, 2)
304
+ self.cv4 = GhostConv(2 * c_, c2, 1, 1)
305
+
306
+
307
+ class BottleneckCSPA(nn.Module):
308
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
309
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
310
+ super(BottleneckCSPA, self).__init__()
311
+ c_ = int(c2 * e) # hidden channels
312
+ self.cv1 = Conv(c1, c_, 1, 1)
313
+ self.cv2 = Conv(c1, c_, 1, 1)
314
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
315
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
316
+
317
+ def forward(self, x):
318
+ y1 = self.m(self.cv1(x))
319
+ y2 = self.cv2(x)
320
+ return self.cv3(torch.cat((y1, y2), dim=1))
321
+
322
+
323
+ class BottleneckCSPB(nn.Module):
324
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
325
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
326
+ super(BottleneckCSPB, self).__init__()
327
+ c_ = int(c2) # hidden channels
328
+ self.cv1 = Conv(c1, c_, 1, 1)
329
+ self.cv2 = Conv(c_, c_, 1, 1)
330
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
331
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
332
+
333
+ def forward(self, x):
334
+ x1 = self.cv1(x)
335
+ y1 = self.m(x1)
336
+ y2 = self.cv2(x1)
337
+ return self.cv3(torch.cat((y1, y2), dim=1))
338
+
339
+
340
+ class BottleneckCSPC(nn.Module):
341
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
342
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
343
+ super(BottleneckCSPC, self).__init__()
344
+ c_ = int(c2 * e) # hidden channels
345
+ self.cv1 = Conv(c1, c_, 1, 1)
346
+ self.cv2 = Conv(c1, c_, 1, 1)
347
+ self.cv3 = Conv(c_, c_, 1, 1)
348
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
349
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
350
+
351
+ def forward(self, x):
352
+ y1 = self.cv3(self.m(self.cv1(x)))
353
+ y2 = self.cv2(x)
354
+ return self.cv4(torch.cat((y1, y2), dim=1))
355
+
356
+
357
+ class ResCSPA(BottleneckCSPA):
358
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
359
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
360
+ super().__init__(c1, c2, n, shortcut, g, e)
361
+ c_ = int(c2 * e) # hidden channels
362
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
363
+
364
+
365
+ class ResCSPB(BottleneckCSPB):
366
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
367
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
368
+ super().__init__(c1, c2, n, shortcut, g, e)
369
+ c_ = int(c2) # hidden channels
370
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
371
+
372
+
373
+ class ResCSPC(BottleneckCSPC):
374
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
375
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
376
+ super().__init__(c1, c2, n, shortcut, g, e)
377
+ c_ = int(c2 * e) # hidden channels
378
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
379
+
380
+
381
+ class ResXCSPA(ResCSPA):
382
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
383
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
384
+ super().__init__(c1, c2, n, shortcut, g, e)
385
+ c_ = int(c2 * e) # hidden channels
386
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
387
+
388
+
389
+ class ResXCSPB(ResCSPB):
390
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
391
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
392
+ super().__init__(c1, c2, n, shortcut, g, e)
393
+ c_ = int(c2) # hidden channels
394
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
395
+
396
+
397
+ class ResXCSPC(ResCSPC):
398
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
399
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
400
+ super().__init__(c1, c2, n, shortcut, g, e)
401
+ c_ = int(c2 * e) # hidden channels
402
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
403
+
404
+
405
+ class GhostCSPA(BottleneckCSPA):
406
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
407
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
408
+ super().__init__(c1, c2, n, shortcut, g, e)
409
+ c_ = int(c2 * e) # hidden channels
410
+ self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])
411
+
412
+
413
+ class GhostCSPB(BottleneckCSPB):
414
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
415
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
416
+ super().__init__(c1, c2, n, shortcut, g, e)
417
+ c_ = int(c2) # hidden channels
418
+ self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])
419
+
420
+
421
+ class GhostCSPC(BottleneckCSPC):
422
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
423
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
424
+ super().__init__(c1, c2, n, shortcut, g, e)
425
+ c_ = int(c2 * e) # hidden channels
426
+ self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])
427
+
428
+ ##### end of cspnet #####
429
+
430
+
431
+ ##### yolor #####
432
+
433
+ class ImplicitA(nn.Module):
434
+ def __init__(self, channel, mean=0., std=.02):
435
+ super(ImplicitA, self).__init__()
436
+ self.channel = channel
437
+ self.mean = mean
438
+ self.std = std
439
+ self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
440
+ nn.init.normal_(self.implicit, mean=self.mean, std=self.std)
441
+
442
+ def forward(self, x):
443
+ return self.implicit + x
444
+
445
+
446
+ class ImplicitM(nn.Module):
447
+ def __init__(self, channel, mean=1., std=.02):
448
+ super(ImplicitM, self).__init__()
449
+ self.channel = channel
450
+ self.mean = mean
451
+ self.std = std
452
+ self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
453
+ nn.init.normal_(self.implicit, mean=self.mean, std=self.std)
454
+
455
+ def forward(self, x):
456
+ return self.implicit * x
457
+
458
+ ##### end of yolor #####
459
+
460
+
461
+ ##### repvgg #####
462
+
463
+ class RepConv(nn.Module):
464
+ # Represented convolution
465
+ # https://arxiv.org/abs/2101.03697
466
+
467
+ def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, deploy=False):
468
+ super(RepConv, self).__init__()
469
+
470
+ self.deploy = deploy
471
+ self.groups = g
472
+ self.in_channels = c1
473
+ self.out_channels = c2
474
+
475
+ assert k == 3
476
+ assert autopad(k, p) == 1
477
+
478
+ padding_11 = autopad(k, p) - k // 2
479
+
480
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
481
+
482
+ if deploy:
483
+ self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)
484
+
485
+ else:
486
+ self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
487
+
488
+ self.rbr_dense = nn.Sequential(
489
+ nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
490
+ nn.BatchNorm2d(num_features=c2),
491
+ )
492
+
493
+ self.rbr_1x1 = nn.Sequential(
494
+ nn.Conv2d( c1, c2, 1, s, padding_11, groups=g, bias=False),
495
+ nn.BatchNorm2d(num_features=c2),
496
+ )
497
+
498
+ def forward(self, inputs):
499
+ if hasattr(self, "rbr_reparam"):
500
+ return self.act(self.rbr_reparam(inputs))
501
+
502
+ if self.rbr_identity is None:
503
+ id_out = 0
504
+ else:
505
+ id_out = self.rbr_identity(inputs)
506
+
507
+ return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
508
+
509
+ def get_equivalent_kernel_bias(self):
510
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
511
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
512
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
513
+ return (
514
+ kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
515
+ bias3x3 + bias1x1 + biasid,
516
+ )
517
+
518
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
519
+ if kernel1x1 is None:
520
+ return 0
521
+ else:
522
+ return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
523
+
524
+ def _fuse_bn_tensor(self, branch):
525
+ if branch is None:
526
+ return 0, 0
527
+ if isinstance(branch, nn.Sequential):
528
+ kernel = branch[0].weight
529
+ running_mean = branch[1].running_mean
530
+ running_var = branch[1].running_var
531
+ gamma = branch[1].weight
532
+ beta = branch[1].bias
533
+ eps = branch[1].eps
534
+ else:
535
+ assert isinstance(branch, nn.BatchNorm2d)
536
+ if not hasattr(self, "id_tensor"):
537
+ input_dim = self.in_channels // self.groups
538
+ kernel_value = np.zeros(
539
+ (self.in_channels, input_dim, 3, 3), dtype=np.float32
540
+ )
541
+ for i in range(self.in_channels):
542
+ kernel_value[i, i % input_dim, 1, 1] = 1
543
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
544
+ kernel = self.id_tensor
545
+ running_mean = branch.running_mean
546
+ running_var = branch.running_var
547
+ gamma = branch.weight
548
+ beta = branch.bias
549
+ eps = branch.eps
550
+ std = (running_var + eps).sqrt()
551
+ t = (gamma / std).reshape(-1, 1, 1, 1)
552
+ return kernel * t, beta - running_mean * gamma / std
553
+
554
+ def repvgg_convert(self):
555
+ kernel, bias = self.get_equivalent_kernel_bias()
556
+ return (
557
+ kernel.detach().cpu().numpy(),
558
+ bias.detach().cpu().numpy(),
559
+ )
560
+
561
+ def fuse_conv_bn(self, conv, bn):
562
+
563
+ std = (bn.running_var + bn.eps).sqrt()
564
+ bias = bn.bias - bn.running_mean * bn.weight / std
565
+
566
+ t = (bn.weight / std).reshape(-1, 1, 1, 1)
567
+ weights = conv.weight * t
568
+
569
+ bn = nn.Identity()
570
+ conv = nn.Conv2d(in_channels = conv.in_channels,
571
+ out_channels = conv.out_channels,
572
+ kernel_size = conv.kernel_size,
573
+ stride=conv.stride,
574
+ padding = conv.padding,
575
+ dilation = conv.dilation,
576
+ groups = conv.groups,
577
+ bias = True,
578
+ padding_mode = conv.padding_mode)
579
+
580
+ conv.weight = torch.nn.Parameter(weights)
581
+ conv.bias = torch.nn.Parameter(bias)
582
+ return conv
583
+
584
+ def fuse_repvgg_block(self):
585
+ if self.deploy:
586
+ return
587
+ print("RepConv.fuse_repvgg_block")
588
+
589
+ self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
590
+
591
+ self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
592
+ rbr_1x1_bias = self.rbr_1x1.bias
593
+ weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
594
+
595
+ # Fuse self.rbr_identity
596
+ if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
597
+ # print(f"fuse: rbr_identity == BatchNorm2d or SyncBatchNorm")
598
+ identity_conv_1x1 = nn.Conv2d(
599
+ in_channels=self.in_channels,
600
+ out_channels=self.out_channels,
601
+ kernel_size=1,
602
+ stride=1,
603
+ padding=0,
604
+ groups=self.groups,
605
+ bias=False)
606
+ identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
607
+ identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
608
+ # print(f" identity_conv_1x1.weight = {identity_conv_1x1.weight.shape}")
609
+ identity_conv_1x1.weight.data.fill_(0.0)
610
+ identity_conv_1x1.weight.data.fill_diagonal_(1.0)
611
+ identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
612
+ # print(f" identity_conv_1x1.weight = {identity_conv_1x1.weight.shape}")
613
+
614
+ identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
615
+ bias_identity_expanded = identity_conv_1x1.bias
616
+ weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
617
+ else:
618
+ # print(f"fuse: rbr_identity != BatchNorm2d, rbr_identity = {self.rbr_identity}")
619
+ bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
620
+ weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )
621
+
622
+
623
+ #print(f"self.rbr_1x1.weight = {self.rbr_1x1.weight.shape}, ")
624
+ #print(f"weight_1x1_expanded = {weight_1x1_expanded.shape}, ")
625
+ #print(f"self.rbr_dense.weight = {self.rbr_dense.weight.shape}, ")
626
+
627
+ self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
628
+ self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
629
+
630
+ self.rbr_reparam = self.rbr_dense
631
+ self.deploy = True
632
+
633
+ if self.rbr_identity is not None:
634
+ del self.rbr_identity
635
+ self.rbr_identity = None
636
+
637
+ if self.rbr_1x1 is not None:
638
+ del self.rbr_1x1
639
+ self.rbr_1x1 = None
640
+
641
+ if self.rbr_dense is not None:
642
+ del self.rbr_dense
643
+ self.rbr_dense = None
644
+
645
+
646
+ class RepBottleneck(Bottleneck):
647
+ # Standard bottleneck
648
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
649
+ super().__init__(c1, c2, shortcut, g, e)
650
+ c_ = int(c2 * e) # hidden channels
651
+ self.cv2 = RepConv(c_, c2, 3, 1, g=g)
652
+
653
+
654
+ class RepBottleneckCSPA(BottleneckCSPA):
655
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
656
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
657
+ super().__init__(c1, c2, n, shortcut, g, e)
658
+ c_ = int(c2 * e) # hidden channels
659
+ self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
660
+
661
+
662
+ class RepBottleneckCSPB(BottleneckCSPB):
663
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
664
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
665
+ super().__init__(c1, c2, n, shortcut, g, e)
666
+ c_ = int(c2) # hidden channels
667
+ self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
668
+
669
+
670
+ class RepBottleneckCSPC(BottleneckCSPC):
671
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
672
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
673
+ super().__init__(c1, c2, n, shortcut, g, e)
674
+ c_ = int(c2 * e) # hidden channels
675
+ self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
676
+
677
+
678
+ class RepRes(Res):
679
+ # Standard bottleneck
680
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
681
+ super().__init__(c1, c2, shortcut, g, e)
682
+ c_ = int(c2 * e) # hidden channels
683
+ self.cv2 = RepConv(c_, c_, 3, 1, g=g)
684
+
685
+
686
+ class RepResCSPA(ResCSPA):
687
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
688
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
689
+ super().__init__(c1, c2, n, shortcut, g, e)
690
+ c_ = int(c2 * e) # hidden channels
691
+ self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
692
+
693
+
694
+ class RepResCSPB(ResCSPB):
695
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
696
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
697
+ super().__init__(c1, c2, n, shortcut, g, e)
698
+ c_ = int(c2) # hidden channels
699
+ self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
700
+
701
+
702
+ class RepResCSPC(ResCSPC):
703
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
704
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
705
+ super().__init__(c1, c2, n, shortcut, g, e)
706
+ c_ = int(c2 * e) # hidden channels
707
+ self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
708
+
709
+
710
+ class RepResX(ResX):
711
+ # Standard bottleneck
712
+ def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
713
+ super().__init__(c1, c2, shortcut, g, e)
714
+ c_ = int(c2 * e) # hidden channels
715
+ self.cv2 = RepConv(c_, c_, 3, 1, g=g)
716
+
717
+
718
+ class RepResXCSPA(ResXCSPA):
719
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
720
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
721
+ super().__init__(c1, c2, n, shortcut, g, e)
722
+ c_ = int(c2 * e) # hidden channels
723
+ self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
724
+
725
+
726
+ class RepResXCSPB(ResXCSPB):
727
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
728
+ def __init__(self, c1, c2, n=1, shortcut=False, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
729
+ super().__init__(c1, c2, n, shortcut, g, e)
730
+ c_ = int(c2) # hidden channels
731
+ self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
732
+
733
+
734
+ class RepResXCSPC(ResXCSPC):
735
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
736
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
737
+ super().__init__(c1, c2, n, shortcut, g, e)
738
+ c_ = int(c2 * e) # hidden channels
739
+ self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
740
+
741
+ ##### end of repvgg #####
742
+
743
+
744
+ ##### transformer #####
745
+
746
+ class TransformerLayer(nn.Module):
747
+ # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
748
+ def __init__(self, c, num_heads):
749
+ super().__init__()
750
+ self.q = nn.Linear(c, c, bias=False)
751
+ self.k = nn.Linear(c, c, bias=False)
752
+ self.v = nn.Linear(c, c, bias=False)
753
+ self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
754
+ self.fc1 = nn.Linear(c, c, bias=False)
755
+ self.fc2 = nn.Linear(c, c, bias=False)
756
+
757
+ def forward(self, x):
758
+ x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
759
+ x = self.fc2(self.fc1(x)) + x
760
+ return x
761
+
762
+
763
+ class TransformerBlock(nn.Module):
764
+ # Vision Transformer https://arxiv.org/abs/2010.11929
765
+ def __init__(self, c1, c2, num_heads, num_layers):
766
+ super().__init__()
767
+ self.conv = None
768
+ if c1 != c2:
769
+ self.conv = Conv(c1, c2)
770
+ self.linear = nn.Linear(c2, c2) # learnable position embedding
771
+ self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
772
+ self.c2 = c2
773
+
774
+ def forward(self, x):
775
+ if self.conv is not None:
776
+ x = self.conv(x)
777
+ b, _, w, h = x.shape
778
+ p = x.flatten(2)
779
+ p = p.unsqueeze(0)
780
+ p = p.transpose(0, 3)
781
+ p = p.squeeze(3)
782
+ e = self.linear(p)
783
+ x = p + e
784
+
785
+ x = self.tr(x)
786
+ x = x.unsqueeze(3)
787
+ x = x.transpose(0, 3)
788
+ x = x.reshape(b, self.c2, w, h)
789
+ return x
790
+
791
+ ##### end of transformer #####
792
+
793
+
794
+ ##### yolov5 #####
795
+
796
+ class Focus(nn.Module):
797
+ # Focus wh information into c-space
798
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
799
+ super(Focus, self).__init__()
800
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
801
+ # self.contract = Contract(gain=2)
802
+
803
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
804
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
805
+ # return self.conv(self.contract(x))
806
+
807
+
808
+ class SPPF(nn.Module):
809
+ # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
810
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
811
+ super().__init__()
812
+ c_ = c1 // 2 # hidden channels
813
+ self.cv1 = Conv(c1, c_, 1, 1)
814
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
815
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
816
+
817
+ def forward(self, x):
818
+ x = self.cv1(x)
819
+ y1 = self.m(x)
820
+ y2 = self.m(y1)
821
+ return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
822
+
823
+
824
+ class Contract(nn.Module):
825
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
826
+ def __init__(self, gain=2):
827
+ super().__init__()
828
+ self.gain = gain
829
+
830
+ def forward(self, x):
831
+ N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
832
+ s = self.gain
833
+ x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
834
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
835
+ return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)
836
+
837
+
838
+ class Expand(nn.Module):
839
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
840
+ def __init__(self, gain=2):
841
+ super().__init__()
842
+ self.gain = gain
843
+
844
+ def forward(self, x):
845
+ N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
846
+ s = self.gain
847
+ x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
848
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
849
+ return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)
850
+
851
+
852
+ class NMS(nn.Module):
853
+ # Non-Maximum Suppression (NMS) module
854
+ conf = 0.25 # confidence threshold
855
+ iou = 0.45 # IoU threshold
856
+ classes = None # (optional list) filter by class
857
+
858
+ def __init__(self):
859
+ super(NMS, self).__init__()
860
+
861
+ def forward(self, x):
862
+ return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
863
+
864
+
865
+ class autoShape(nn.Module):
866
+ # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
867
+ conf = 0.25 # NMS confidence threshold
868
+ iou = 0.45 # NMS IoU threshold
869
+ classes = None # (optional list) filter by class
870
+
871
+ def __init__(self, model):
872
+ super(autoShape, self).__init__()
873
+ self.model = model.eval()
874
+
875
+ def autoshape(self):
876
+ print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
877
+ return self
878
+
879
+ @torch.no_grad()
880
+ def forward(self, imgs, size=640, augment=False, profile=False):
881
+ # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
882
+ # filename: imgs = 'data/samples/zidane.jpg'
883
+ # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
884
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
885
+ # PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
886
+ # numpy: = np.zeros((640,1280,3)) # HWC
887
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
888
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
889
+
890
+ t = [time_synchronized()]
891
+ p = next(self.model.parameters()) # for device and type
892
+ if isinstance(imgs, torch.Tensor): # torch
893
+ with amp.autocast(enabled=p.device.type != 'cpu'):
894
+ return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
895
+
896
+ # Pre-process
897
+ n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
898
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
899
+ for i, im in enumerate(imgs):
900
+ f = f'image{i}' # filename
901
+ if isinstance(im, str): # filename or uri
902
+ im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
903
+ elif isinstance(im, Image.Image): # PIL Image
904
+ im, f = np.asarray(im), getattr(im, 'filename', f) or f
905
+ files.append(Path(f).with_suffix('.jpg').name)
906
+ if im.shape[0] < 5: # image in CHW
907
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
908
+ im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
909
+ s = im.shape[:2] # HW
910
+ shape0.append(s) # image shape
911
+ g = (size / max(s)) # gain
912
+ shape1.append([y * g for y in s])
913
+ imgs[i] = im # update
914
+ shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
915
+ x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
916
+ x = np.stack(x, 0) if n > 1 else x[0][None] # stack
917
+ x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
918
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
919
+ t.append(time_synchronized())
920
+
921
+ with amp.autocast(enabled=p.device.type != 'cpu'):
922
+ # Inference
923
+ y = self.model(x, augment, profile)[0] # forward
924
+ t.append(time_synchronized())
925
+
926
+ # Post-process
927
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
928
+ for i in range(n):
929
+ scale_coords(shape1, y[i][:, :4], shape0[i])
930
+
931
+ t.append(time_synchronized())
932
+ return Detections(imgs, y, files, t, self.names, x.shape)
933
+
934
+
935
+ class Detections:
936
+ # detections class for YOLOv5 inference results
937
+ def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
938
+ super(Detections, self).__init__()
939
+ d = pred[0].device # device
940
+ gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
941
+ self.imgs = imgs # list of images as numpy arrays
942
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
943
+ self.names = names # class names
944
+ self.files = files # image filenames
945
+ self.xyxy = pred # xyxy pixels
946
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
947
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
948
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
949
+ self.n = len(self.pred) # number of images (batch size)
950
+ self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
951
+ self.s = shape # inference BCHW shape
952
+
953
+ def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
954
+ colors = color_list()
955
+ for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
956
+ str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
957
+ if pred is not None:
958
+ for c in pred[:, -1].unique():
959
+ n = (pred[:, -1] == c).sum() # detections per class
960
+ str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
961
+ if show or save or render:
962
+ for *box, conf, cls in pred: # xyxy, confidence, class
963
+ label = f'{self.names[int(cls)]} {conf:.2f}'
964
+ plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
965
+ img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
966
+ if pprint:
967
+ print(str.rstrip(', '))
968
+ if show:
969
+ img.show(self.files[i]) # show
970
+ if save:
971
+ f = self.files[i]
972
+ img.save(Path(save_dir) / f) # save
973
+ print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
974
+ if render:
975
+ self.imgs[i] = np.asarray(img)
976
+
977
+ def print(self):
978
+ self.display(pprint=True) # print results
979
+ print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
980
+
981
+ def show(self):
982
+ self.display(show=True) # show results
983
+
984
+ def save(self, save_dir='runs/hub/exp'):
985
+ save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp') # increment save_dir
986
+ Path(save_dir).mkdir(parents=True, exist_ok=True)
987
+ self.display(save=True, save_dir=save_dir) # save results
988
+
989
+ def render(self):
990
+ self.display(render=True) # render results
991
+ return self.imgs
992
+
993
+ def pandas(self):
994
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
995
+ new = copy(self) # return copy
996
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
997
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
998
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
999
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
1000
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
1001
+ return new
1002
+
1003
+ def tolist(self):
1004
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
1005
+ x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
1006
+ for d in x:
1007
+ for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
1008
+ setattr(d, k, getattr(d, k)[0]) # pop out of list
1009
+ return x
1010
+
1011
+ def __len__(self):
1012
+ return self.n
1013
+
1014
+
1015
+ class Classify(nn.Module):
1016
+ # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
1017
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
1018
+ super(Classify, self).__init__()
1019
+ self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
1020
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
1021
+ self.flat = nn.Flatten()
1022
+
1023
+ def forward(self, x):
1024
+ z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
1025
+ return self.flat(self.conv(z)) # flatten to x(b,c2)
1026
+
1027
+ ##### end of yolov5 #####
1028
+
1029
+
1030
+ ##### orepa #####
1031
+
1032
+ def transI_fusebn(kernel, bn):
1033
+ gamma = bn.weight
1034
+ std = (bn.running_var + bn.eps).sqrt()
1035
+ return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std
1036
+
1037
+
1038
+ class ConvBN(nn.Module):
1039
+ def __init__(self, in_channels, out_channels, kernel_size,
1040
+ stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None):
1041
+ super().__init__()
1042
+ if nonlinear is None:
1043
+ self.nonlinear = nn.Identity()
1044
+ else:
1045
+ self.nonlinear = nonlinear
1046
+ if deploy:
1047
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
1048
+ stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
1049
+ else:
1050
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
1051
+ stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
1052
+ self.bn = nn.BatchNorm2d(num_features=out_channels)
1053
+
1054
+ def forward(self, x):
1055
+ if hasattr(self, 'bn'):
1056
+ return self.nonlinear(self.bn(self.conv(x)))
1057
+ else:
1058
+ return self.nonlinear(self.conv(x))
1059
+
1060
+ def switch_to_deploy(self):
1061
+ kernel, bias = transI_fusebn(self.conv.weight, self.bn)
1062
+ conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size,
1063
+ stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True)
1064
+ conv.weight.data = kernel
1065
+ conv.bias.data = bias
1066
+ for para in self.parameters():
1067
+ para.detach_()
1068
+ self.__delattr__('conv')
1069
+ self.__delattr__('bn')
1070
+ self.conv = conv
1071
+
1072
+ class OREPA_3x3_RepConv(nn.Module):
1073
+
1074
+ def __init__(self, in_channels, out_channels, kernel_size,
1075
+ stride=1, padding=0, dilation=1, groups=1,
1076
+ internal_channels_1x1_3x3=None,
1077
+ deploy=False, nonlinear=None, single_init=False):
1078
+ super(OREPA_3x3_RepConv, self).__init__()
1079
+ self.deploy = deploy
1080
+
1081
+ if nonlinear is None:
1082
+ self.nonlinear = nn.Identity()
1083
+ else:
1084
+ self.nonlinear = nonlinear
1085
+
1086
+ self.kernel_size = kernel_size
1087
+ self.in_channels = in_channels
1088
+ self.out_channels = out_channels
1089
+ self.groups = groups
1090
+ assert padding == kernel_size // 2
1091
+
1092
+ self.stride = stride
1093
+ self.padding = padding
1094
+ self.dilation = dilation
1095
+
1096
+ self.branch_counter = 0
1097
+
1098
+ self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), kernel_size, kernel_size))
1099
+ nn.init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0))
1100
+ self.branch_counter += 1
1101
+
1102
+
1103
+ if groups < out_channels:
1104
+ self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
1105
+ self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
1106
+ nn.init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0)
1107
+ nn.init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0)
1108
+ self.weight_rbr_avg_conv.data
1109
+ self.weight_rbr_pfir_conv.data
1110
+ self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0/kernel_size/kernel_size))
1111
+ self.branch_counter += 1
1112
+
1113
+ else:
1114
+ raise NotImplementedError
1115
+ self.branch_counter += 1
1116
+
1117
+ if internal_channels_1x1_3x3 is None:
1118
+ internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
1119
+
1120
+ if internal_channels_1x1_3x3 == in_channels:
1121
+ self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels/self.groups), 1, 1))
1122
+ id_value = np.zeros((in_channels, int(in_channels/self.groups), 1, 1))
1123
+ for i in range(in_channels):
1124
+ id_value[i, i % int(in_channels/self.groups), 0, 0] = 1
1125
+ id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1)
1126
+ self.register_buffer('id_tensor', id_tensor)
1127
+
1128
+ else:
1129
+ self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels/self.groups), 1, 1))
1130
+ nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0))
1131
+ self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3/self.groups), kernel_size, kernel_size))
1132
+ nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0))
1133
+ self.branch_counter += 1
1134
+
1135
+ expand_ratio = 8
1136
+ self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels*expand_ratio, 1, kernel_size, kernel_size))
1137
+ self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels*expand_ratio, 1, 1))
1138
+ nn.init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0))
1139
+ nn.init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0))
1140
+ self.branch_counter += 1
1141
+
1142
+ if out_channels == in_channels and stride == 1:
1143
+ self.branch_counter += 1
1144
+
1145
+ self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
1146
+ self.bn = nn.BatchNorm2d(out_channels)
1147
+
1148
+ self.fre_init()
1149
+
1150
+ nn.init.constant_(self.vector[0, :], 0.25) #origin
1151
+ nn.init.constant_(self.vector[1, :], 0.25) #avg
1152
+ nn.init.constant_(self.vector[2, :], 0.0) #prior
1153
+ nn.init.constant_(self.vector[3, :], 0.5) #1x1_kxk
1154
+ nn.init.constant_(self.vector[4, :], 0.5) #dws_conv
1155
+
1156
+
1157
+ def fre_init(self):
1158
+ prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size)
1159
+ half_fg = self.out_channels/2
1160
+ for i in range(self.out_channels):
1161
+ for h in range(3):
1162
+ for w in range(3):
1163
+ if i < half_fg:
1164
+ prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
1165
+ else:
1166
+ prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_fg)/3)
1167
+
1168
+ self.register_buffer('weight_rbr_prior', prior_tensor)
1169
+
1170
+ def weight_gen(self):
1171
+
1172
+ weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :])
1173
+
1174
+ weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :])
1175
+
1176
+ weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :])
1177
+
1178
+ weight_rbr_1x1_kxk_conv1 = None
1179
+ if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'):
1180
+ weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze()
1181
+ elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'):
1182
+ weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze()
1183
+ else:
1184
+ raise NotImplementedError
1185
+ weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2
1186
+
1187
+ if self.groups > 1:
1188
+ g = self.groups
1189
+ t, ig = weight_rbr_1x1_kxk_conv1.size()
1190
+ o, tg, h, w = weight_rbr_1x1_kxk_conv2.size()
1191
+ weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t/g), ig)
1192
+ weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o/g), tg, h, w)
1193
+ weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w)
1194
+ else:
1195
+ weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2)
1196
+
1197
+ weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :])
1198
+
1199
+ weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels)
1200
+ weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :])
1201
+
1202
+ weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv
1203
+
1204
+ return weight
1205
+
1206
+ def dwsc2full(self, weight_dw, weight_pw, groups):
1207
+
1208
+ t, ig, h, w = weight_dw.size()
1209
+ o, _, _, _ = weight_pw.size()
1210
+ tg = int(t/groups)
1211
+ i = int(ig*groups)
1212
+ weight_dw = weight_dw.view(groups, tg, ig, h, w)
1213
+ weight_pw = weight_pw.squeeze().view(o, groups, tg)
1214
+
1215
+ weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw)
1216
+ return weight_dsc.view(o, i, h, w)
1217
+
1218
+ def forward(self, inputs):
1219
+ weight = self.weight_gen()
1220
+ out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
1221
+
1222
+ return self.nonlinear(self.bn(out))
1223
+
1224
+ class RepConv_OREPA(nn.Module):
1225
+
1226
+ def __init__(self, c1, c2, k=3, s=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.SiLU()):
1227
+ super(RepConv_OREPA, self).__init__()
1228
+ self.deploy = deploy
1229
+ self.groups = groups
1230
+ self.in_channels = c1
1231
+ self.out_channels = c2
1232
+
1233
+ self.padding = padding
1234
+ self.dilation = dilation
1235
+ self.groups = groups
1236
+
1237
+ assert k == 3
1238
+ assert padding == 1
1239
+
1240
+ padding_11 = padding - k // 2
1241
+
1242
+ if nonlinear is None:
1243
+ self.nonlinearity = nn.Identity()
1244
+ else:
1245
+ self.nonlinearity = nonlinear
1246
+
1247
+ if use_se:
1248
+ self.se = SEBlock(self.out_channels, internal_neurons=self.out_channels // 16)
1249
+ else:
1250
+ self.se = nn.Identity()
1251
+
1252
+ if deploy:
1253
+ self.rbr_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s,
1254
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
1255
+
1256
+ else:
1257
+ self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.out_channels == self.in_channels and s == 1 else None
1258
+ self.rbr_dense = OREPA_3x3_RepConv(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s, padding=padding, groups=groups, dilation=1)
1259
+ self.rbr_1x1 = ConvBN(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=s, padding=padding_11, groups=groups, dilation=1)
1260
+ print('RepVGG Block, identity = ', self.rbr_identity)
1261
+
1262
+
1263
+ def forward(self, inputs):
1264
+ if hasattr(self, 'rbr_reparam'):
1265
+ return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
1266
+
1267
+ if self.rbr_identity is None:
1268
+ id_out = 0
1269
+ else:
1270
+ id_out = self.rbr_identity(inputs)
1271
+
1272
+ out1 = self.rbr_dense(inputs)
1273
+ out2 = self.rbr_1x1(inputs)
1274
+ out3 = id_out
1275
+ out = out1 + out2 + out3
1276
+
1277
+ return self.nonlinearity(self.se(out))
1278
+
1279
+
1280
+ # Optional. This improves the accuracy and facilitates quantization.
1281
+ # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
1282
+ # 2. Use like this.
1283
+ # loss = criterion(....)
1284
+ # for every RepVGGBlock blk:
1285
+ # loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
1286
+ # optimizer.zero_grad()
1287
+ # loss.backward()
1288
+
1289
+ # Not used for OREPA
1290
+ def get_custom_L2(self):
1291
+ K3 = self.rbr_dense.weight_gen()
1292
+ K1 = self.rbr_1x1.conv.weight
1293
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
1294
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
1295
+
1296
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
1297
+ eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
1298
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
1299
+ return l2_loss_eq_kernel + l2_loss_circle
1300
+
1301
+ def get_equivalent_kernel_bias(self):
1302
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
1303
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
1304
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
1305
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
1306
+
1307
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
1308
+ if kernel1x1 is None:
1309
+ return 0
1310
+ else:
1311
+ return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
1312
+
1313
+ def _fuse_bn_tensor(self, branch):
1314
+ if branch is None:
1315
+ return 0, 0
1316
+ if not isinstance(branch, nn.BatchNorm2d):
1317
+ if isinstance(branch, OREPA_3x3_RepConv):
1318
+ kernel = branch.weight_gen()
1319
+ elif isinstance(branch, ConvBN):
1320
+ kernel = branch.conv.weight
1321
+ else:
1322
+ raise NotImplementedError
1323
+ running_mean = branch.bn.running_mean
1324
+ running_var = branch.bn.running_var
1325
+ gamma = branch.bn.weight
1326
+ beta = branch.bn.bias
1327
+ eps = branch.bn.eps
1328
+ else:
1329
+ if not hasattr(self, 'id_tensor'):
1330
+ input_dim = self.in_channels // self.groups
1331
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
1332
+ for i in range(self.in_channels):
1333
+ kernel_value[i, i % input_dim, 1, 1] = 1
1334
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
1335
+ kernel = self.id_tensor
1336
+ running_mean = branch.running_mean
1337
+ running_var = branch.running_var
1338
+ gamma = branch.weight
1339
+ beta = branch.bias
1340
+ eps = branch.eps
1341
+ std = (running_var + eps).sqrt()
1342
+ t = (gamma / std).reshape(-1, 1, 1, 1)
1343
+ return kernel * t, beta - running_mean * gamma / std
1344
+
1345
+ def switch_to_deploy(self):
1346
+ if hasattr(self, 'rbr_reparam'):
1347
+ return
1348
+ print(f"RepConv_OREPA.switch_to_deploy")
1349
+ kernel, bias = self.get_equivalent_kernel_bias()
1350
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels,
1351
+ kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride,
1352
+ padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True)
1353
+ self.rbr_reparam.weight.data = kernel
1354
+ self.rbr_reparam.bias.data = bias
1355
+ for para in self.parameters():
1356
+ para.detach_()
1357
+ self.__delattr__('rbr_dense')
1358
+ self.__delattr__('rbr_1x1')
1359
+ if hasattr(self, 'rbr_identity'):
1360
+ self.__delattr__('rbr_identity')
1361
+
1362
+ ##### end of orepa #####
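A hedged usage sketch for RepConv_OREPA (channel sizes and the input shape are illustrative, and it assumes models.common from this commit is importable): in eval mode the single re-parameterised convolution is expected to reproduce the multi-branch output up to numerical tolerance.

import torch
from models.common import RepConv_OREPA       # assumes this commit's models/common.py

block = RepConv_OREPA(64, 64, k=3, s=1).eval()
x = torch.randn(1, 64, 56, 56)
with torch.no_grad():
    y_branches = block(x)                     # rbr_dense + rbr_1x1 + BN identity
    block.switch_to_deploy()                  # folds all branches into rbr_reparam
    y_deploy = block(x)
print(torch.allclose(y_branches, y_deploy, atol=1e-4))       # expected: True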
1363
+
1364
+
1365
+ ##### swin transformer #####
1366
+
1367
+ class WindowAttention(nn.Module):
1368
+
1369
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
1370
+
1371
+ super().__init__()
1372
+ self.dim = dim
1373
+ self.window_size = window_size # Wh, Ww
1374
+ self.num_heads = num_heads
1375
+ head_dim = dim // num_heads
1376
+ self.scale = qk_scale or head_dim ** -0.5
1377
+
1378
+ # define a parameter table of relative position bias
1379
+ self.relative_position_bias_table = nn.Parameter(
1380
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
1381
+
1382
+ # get pair-wise relative position index for each token inside the window
1383
+ coords_h = torch.arange(self.window_size[0])
1384
+ coords_w = torch.arange(self.window_size[1])
1385
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1386
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1387
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
1388
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
1389
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
1390
+ relative_coords[:, :, 1] += self.window_size[1] - 1
1391
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
1392
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1393
+ self.register_buffer("relative_position_index", relative_position_index)
1394
+
1395
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
1396
+ self.attn_drop = nn.Dropout(attn_drop)
1397
+ self.proj = nn.Linear(dim, dim)
1398
+ self.proj_drop = nn.Dropout(proj_drop)
1399
+
1400
+ nn.init.normal_(self.relative_position_bias_table, std=.02)
1401
+ self.softmax = nn.Softmax(dim=-1)
1402
+
1403
+ def forward(self, x, mask=None):
1404
+
1405
+ B_, N, C = x.shape
1406
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
1407
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
1408
+
1409
+ q = q * self.scale
1410
+ attn = (q @ k.transpose(-2, -1))
1411
+
1412
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
1413
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
1414
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
1415
+ attn = attn + relative_position_bias.unsqueeze(0)
1416
+
1417
+ if mask is not None:
1418
+ nW = mask.shape[0]
1419
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
1420
+ attn = attn.view(-1, self.num_heads, N, N)
1421
+ attn = self.softmax(attn)
1422
+ else:
1423
+ attn = self.softmax(attn)
1424
+
1425
+ attn = self.attn_drop(attn)
1426
+
1427
+ # print(attn.dtype, v.dtype)
1428
+ try:
1429
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
1430
+ except Exception:  # fall back when attn and v dtypes differ (e.g. fp16 inference)
1431
+ #print(attn.dtype, v.dtype)
1432
+ x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
1433
+ x = self.proj(x)
1434
+ x = self.proj_drop(x)
1435
+ return x
1436
+
1437
+ class Mlp(nn.Module):
1438
+
1439
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
1440
+ super().__init__()
1441
+ out_features = out_features or in_features
1442
+ hidden_features = hidden_features or in_features
1443
+ self.fc1 = nn.Linear(in_features, hidden_features)
1444
+ self.act = act_layer()
1445
+ self.fc2 = nn.Linear(hidden_features, out_features)
1446
+ self.drop = nn.Dropout(drop)
1447
+
1448
+ def forward(self, x):
1449
+ x = self.fc1(x)
1450
+ x = self.act(x)
1451
+ x = self.drop(x)
1452
+ x = self.fc2(x)
1453
+ x = self.drop(x)
1454
+ return x
1455
+
1456
+ def window_partition(x, window_size):
1457
+
1458
+ B, H, W, C = x.shape
1459
+ assert H % window_size == 0 and W % window_size == 0, 'feature map h and w must be divisible by window size'
1460
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
1461
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
1462
+ return windows
1463
+
1464
+ def window_reverse(windows, window_size, H, W):
1465
+
1466
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
1467
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
1468
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
1469
+ return x
1470
+
1471
+
1472
+ class SwinTransformerLayer(nn.Module):
1473
+
1474
+ def __init__(self, dim, num_heads, window_size=8, shift_size=0,
1475
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
1476
+ act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
1477
+ super().__init__()
1478
+ self.dim = dim
1479
+ self.num_heads = num_heads
1480
+ self.window_size = window_size
1481
+ self.shift_size = shift_size
1482
+ self.mlp_ratio = mlp_ratio
1483
+ # if min(self.input_resolution) <= self.window_size:
1484
+ # # if window size is larger than input resolution, we don't partition windows
1485
+ # self.shift_size = 0
1486
+ # self.window_size = min(self.input_resolution)
1487
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
1488
+
1489
+ self.norm1 = norm_layer(dim)
1490
+ self.attn = WindowAttention(
1491
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
1492
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
1493
+
1494
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1495
+ self.norm2 = norm_layer(dim)
1496
+ mlp_hidden_dim = int(dim * mlp_ratio)
1497
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
1498
+
1499
+ def create_mask(self, H, W):
1500
+ # calculate attention mask for SW-MSA
1501
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
1502
+ h_slices = (slice(0, -self.window_size),
1503
+ slice(-self.window_size, -self.shift_size),
1504
+ slice(-self.shift_size, None))
1505
+ w_slices = (slice(0, -self.window_size),
1506
+ slice(-self.window_size, -self.shift_size),
1507
+ slice(-self.shift_size, None))
1508
+ cnt = 0
1509
+ for h in h_slices:
1510
+ for w in w_slices:
1511
+ img_mask[:, h, w, :] = cnt
1512
+ cnt += 1
1513
+
1514
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
1515
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1516
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1517
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
1518
+
1519
+ return attn_mask
1520
+
1521
+ def forward(self, x):
1522
+ # reshape x[b c h w] to x[b l c]
1523
+ _, _, H_, W_ = x.shape
1524
+
1525
+ Padding = False
1526
+ if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
1527
+ Padding = True
1528
+ # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
1529
+ pad_r = (self.window_size - W_ % self.window_size) % self.window_size
1530
+ pad_b = (self.window_size - H_ % self.window_size) % self.window_size
1531
+ x = F.pad(x, (0, pad_r, 0, pad_b))
1532
+
1533
+ # print('2', x.shape)
1534
+ B, C, H, W = x.shape
1535
+ L = H * W
1536
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
1537
+
1538
+ # create mask from init to forward
1539
+ if self.shift_size > 0:
1540
+ attn_mask = self.create_mask(H, W).to(x.device)
1541
+ else:
1542
+ attn_mask = None
1543
+
1544
+ shortcut = x
1545
+ x = self.norm1(x)
1546
+ x = x.view(B, H, W, C)
1547
+
1548
+ # cyclic shift
1549
+ if self.shift_size > 0:
1550
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
1551
+ else:
1552
+ shifted_x = x
1553
+
1554
+ # partition windows
1555
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
1556
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
1557
+
1558
+ # W-MSA/SW-MSA
1559
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
1560
+
1561
+ # merge windows
1562
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
1563
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
1564
+
1565
+ # reverse cyclic shift
1566
+ if self.shift_size > 0:
1567
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
1568
+ else:
1569
+ x = shifted_x
1570
+ x = x.view(B, H * W, C)
1571
+
1572
+ # FFN
1573
+ x = shortcut + self.drop_path(x)
1574
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1575
+
1576
+ x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
1577
+
1578
+ if Padding:
1579
+ x = x[:, :, :H_, :W_] # reverse padding
1580
+
1581
+ return x
1582
+
1583
+
1584
+ class SwinTransformerBlock(nn.Module):
1585
+ def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
1586
+ super().__init__()
1587
+ self.conv = None
1588
+ if c1 != c2:
1589
+ self.conv = Conv(c1, c2)
1590
+
1591
+ # remove input_resolution
1592
+ self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
1593
+ shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])
1594
+
1595
+ def forward(self, x):
1596
+ if self.conv is not None:
1597
+ x = self.conv(x)
1598
+ x = self.blocks(x)
1599
+ return x
1600
+
1601
+
1602
+ class STCSPA(nn.Module):
1603
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1604
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1605
+ super(STCSPA, self).__init__()
1606
+ c_ = int(c2 * e) # hidden channels
1607
+ self.cv1 = Conv(c1, c_, 1, 1)
1608
+ self.cv2 = Conv(c1, c_, 1, 1)
1609
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
1610
+ num_heads = c_ // 32
1611
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
1612
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1613
+
1614
+ def forward(self, x):
1615
+ y1 = self.m(self.cv1(x))
1616
+ y2 = self.cv2(x)
1617
+ return self.cv3(torch.cat((y1, y2), dim=1))
1618
+
1619
+
1620
+ class STCSPB(nn.Module):
1621
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1622
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1623
+ super(STCSPB, self).__init__()
1624
+ c_ = int(c2) # hidden channels
1625
+ self.cv1 = Conv(c1, c_, 1, 1)
1626
+ self.cv2 = Conv(c_, c_, 1, 1)
1627
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
1628
+ num_heads = c_ // 32
1629
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
1630
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1631
+
1632
+ def forward(self, x):
1633
+ x1 = self.cv1(x)
1634
+ y1 = self.m(x1)
1635
+ y2 = self.cv2(x1)
1636
+ return self.cv3(torch.cat((y1, y2), dim=1))
1637
+
1638
+
1639
+ class STCSPC(nn.Module):
1640
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1641
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1642
+ super(STCSPC, self).__init__()
1643
+ c_ = int(c2 * e) # hidden channels
1644
+ self.cv1 = Conv(c1, c_, 1, 1)
1645
+ self.cv2 = Conv(c1, c_, 1, 1)
1646
+ self.cv3 = Conv(c_, c_, 1, 1)
1647
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
1648
+ num_heads = c_ // 32
1649
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
1650
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1651
+
1652
+ def forward(self, x):
1653
+ y1 = self.cv3(self.m(self.cv1(x)))
1654
+ y2 = self.cv2(x)
1655
+ return self.cv4(torch.cat((y1, y2), dim=1))
1656
+
1657
+ ##### end of swin transformer #####
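A small shape sketch for the blocks above (sizes are illustrative, and it assumes models.common from this commit is importable). The hidden width c_ = int(c2 * e) is divided by 32 to derive the number of attention heads, so c2 * e should be a multiple of 32; spatial sizes that are not multiples of the window size are padded inside SwinTransformerLayer and cropped back afterwards.

import torch
from models.common import STCSPA              # assumes this commit's models/common.py

m = STCSPA(c1=128, c2=128, n=1).eval()        # c_ = 64 -> num_heads = 2, window_size = 8
x = torch.randn(1, 128, 52, 52)               # 52 % 8 != 0 -> padded to 56, cropped back on output
y = m(x)
print(y.shape)                                # expected: torch.Size([1, 128, 52, 52])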
1658
+
1659
+
1660
+ ##### swin transformer v2 #####
1661
+
1662
+ class WindowAttention_v2(nn.Module):
1663
+
1664
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
1665
+ pretrained_window_size=[0, 0]):
1666
+
1667
+ super().__init__()
1668
+ self.dim = dim
1669
+ self.window_size = window_size # Wh, Ww
1670
+ self.pretrained_window_size = pretrained_window_size
1671
+ self.num_heads = num_heads
1672
+
1673
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
1674
+
1675
+ # mlp to generate continuous relative position bias
1676
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
1677
+ nn.ReLU(inplace=True),
1678
+ nn.Linear(512, num_heads, bias=False))
1679
+
1680
+ # get relative_coords_table
1681
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
1682
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
1683
+ relative_coords_table = torch.stack(
1684
+ torch.meshgrid([relative_coords_h,
1685
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
1686
+ if pretrained_window_size[0] > 0:
1687
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
1688
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
1689
+ else:
1690
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
1691
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
1692
+ relative_coords_table *= 8 # normalize to -8, 8
1693
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1694
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
1695
+
1696
+ self.register_buffer("relative_coords_table", relative_coords_table)
1697
+
1698
+ # get pair-wise relative position index for each token inside the window
1699
+ coords_h = torch.arange(self.window_size[0])
1700
+ coords_w = torch.arange(self.window_size[1])
1701
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1702
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1703
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
1704
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
1705
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
1706
+ relative_coords[:, :, 1] += self.window_size[1] - 1
1707
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
1708
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1709
+ self.register_buffer("relative_position_index", relative_position_index)
1710
+
1711
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
1712
+ if qkv_bias:
1713
+ self.q_bias = nn.Parameter(torch.zeros(dim))
1714
+ self.v_bias = nn.Parameter(torch.zeros(dim))
1715
+ else:
1716
+ self.q_bias = None
1717
+ self.v_bias = None
1718
+ self.attn_drop = nn.Dropout(attn_drop)
1719
+ self.proj = nn.Linear(dim, dim)
1720
+ self.proj_drop = nn.Dropout(proj_drop)
1721
+ self.softmax = nn.Softmax(dim=-1)
1722
+
1723
+ def forward(self, x, mask=None):
1724
+
1725
+ B_, N, C = x.shape
1726
+ qkv_bias = None
1727
+ if self.q_bias is not None:
1728
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
1729
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
1730
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
1731
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
1732
+
1733
+ # cosine attention
1734
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
1735
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
1736
+ attn = attn * logit_scale
1737
+
1738
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
1739
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
1740
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
1741
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
1742
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
1743
+ attn = attn + relative_position_bias.unsqueeze(0)
1744
+
1745
+ if mask is not None:
1746
+ nW = mask.shape[0]
1747
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
1748
+ attn = attn.view(-1, self.num_heads, N, N)
1749
+ attn = self.softmax(attn)
1750
+ else:
1751
+ attn = self.softmax(attn)
1752
+
1753
+ attn = self.attn_drop(attn)
1754
+
1755
+ try:
1756
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
1757
+ except Exception:  # fall back when attn and v dtypes differ (e.g. fp16 inference)
1758
+ x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
1759
+
1760
+ x = self.proj(x)
1761
+ x = self.proj_drop(x)
1762
+ return x
1763
+
1764
+ def extra_repr(self) -> str:
1765
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
1766
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
1767
+
1768
+ def flops(self, N):
1769
+ # calculate flops for 1 window with token length of N
1770
+ flops = 0
1771
+ # qkv = self.qkv(x)
1772
+ flops += N * self.dim * 3 * self.dim
1773
+ # attn = (q @ k.transpose(-2, -1))
1774
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
1775
+ # x = (attn @ v)
1776
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
1777
+ # x = self.proj(x)
1778
+ flops += N * self.dim * self.dim
1779
+ return flops
1780
+
1781
+ class Mlp_v2(nn.Module):
1782
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
1783
+ super().__init__()
1784
+ out_features = out_features or in_features
1785
+ hidden_features = hidden_features or in_features
1786
+ self.fc1 = nn.Linear(in_features, hidden_features)
1787
+ self.act = act_layer()
1788
+ self.fc2 = nn.Linear(hidden_features, out_features)
1789
+ self.drop = nn.Dropout(drop)
1790
+
1791
+ def forward(self, x):
1792
+ x = self.fc1(x)
1793
+ x = self.act(x)
1794
+ x = self.drop(x)
1795
+ x = self.fc2(x)
1796
+ x = self.drop(x)
1797
+ return x
1798
+
1799
+
1800
+ def window_partition_v2(x, window_size):
1801
+
1802
+ B, H, W, C = x.shape
1803
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
1804
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
1805
+ return windows
1806
+
1807
+
1808
+ def window_reverse_v2(windows, window_size, H, W):
1809
+
1810
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
1811
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
1812
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
1813
+ return x
1814
+
1815
+
1816
+ class SwinTransformerLayer_v2(nn.Module):
1817
+
1818
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
1819
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
1820
+ act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
1821
+ super().__init__()
1822
+ self.dim = dim
1823
+ #self.input_resolution = input_resolution
1824
+ self.num_heads = num_heads
1825
+ self.window_size = window_size
1826
+ self.shift_size = shift_size
1827
+ self.mlp_ratio = mlp_ratio
1828
+ #if min(self.input_resolution) <= self.window_size:
1829
+ # # if window size is larger than input resolution, we don't partition windows
1830
+ # self.shift_size = 0
1831
+ # self.window_size = min(self.input_resolution)
1832
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
1833
+
1834
+ self.norm1 = norm_layer(dim)
1835
+ self.attn = WindowAttention_v2(
1836
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
1837
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
1838
+ pretrained_window_size=(pretrained_window_size, pretrained_window_size))
1839
+
1840
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1841
+ self.norm2 = norm_layer(dim)
1842
+ mlp_hidden_dim = int(dim * mlp_ratio)
1843
+ self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
1844
+
1845
+ def create_mask(self, H, W):
1846
+ # calculate attention mask for SW-MSA
1847
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
1848
+ h_slices = (slice(0, -self.window_size),
1849
+ slice(-self.window_size, -self.shift_size),
1850
+ slice(-self.shift_size, None))
1851
+ w_slices = (slice(0, -self.window_size),
1852
+ slice(-self.window_size, -self.shift_size),
1853
+ slice(-self.shift_size, None))
1854
+ cnt = 0
1855
+ for h in h_slices:
1856
+ for w in w_slices:
1857
+ img_mask[:, h, w, :] = cnt
1858
+ cnt += 1
1859
+
1860
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
1861
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1862
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1863
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
1864
+
1865
+ return attn_mask
1866
+
1867
+ def forward(self, x):
1868
+ # reshape x[b c h w] to x[b l c]
1869
+ _, _, H_, W_ = x.shape
1870
+
1871
+ Padding = False
1872
+ if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
1873
+ Padding = True
1874
+ # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
1875
+ pad_r = (self.window_size - W_ % self.window_size) % self.window_size
1876
+ pad_b = (self.window_size - H_ % self.window_size) % self.window_size
1877
+ x = F.pad(x, (0, pad_r, 0, pad_b))
1878
+
1879
+ # print('2', x.shape)
1880
+ B, C, H, W = x.shape
1881
+ L = H * W
1882
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
1883
+
1884
+ # create mask from init to forward
1885
+ if self.shift_size > 0:
1886
+ attn_mask = self.create_mask(H, W).to(x.device)
1887
+ else:
1888
+ attn_mask = None
1889
+
1890
+ shortcut = x
1891
+ x = x.view(B, H, W, C)
1892
+
1893
+ # cyclic shift
1894
+ if self.shift_size > 0:
1895
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
1896
+ else:
1897
+ shifted_x = x
1898
+
1899
+ # partition windows
1900
+ x_windows = window_partition_v2(shifted_x, self.window_size) # nW*B, window_size, window_size, C
1901
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
1902
+
1903
+ # W-MSA/SW-MSA
1904
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
1905
+
1906
+ # merge windows
1907
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
1908
+ shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W) # B H' W' C
1909
+
1910
+ # reverse cyclic shift
1911
+ if self.shift_size > 0:
1912
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
1913
+ else:
1914
+ x = shifted_x
1915
+ x = x.view(B, H * W, C)
1916
+ x = shortcut + self.drop_path(self.norm1(x))
1917
+
1918
+ # FFN
1919
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
1920
+ x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
1921
+
1922
+ if Padding:
1923
+ x = x[:, :, :H_, :W_] # reverse padding
1924
+
1925
+ return x
1926
+
1927
+ def extra_repr(self) -> str:
1928
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
1929
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
1930
+
1931
+ def flops(self):  # note: relies on self.input_resolution, which this port no longer stores (not called here)
1932
+ flops = 0
1933
+ H, W = self.input_resolution
1934
+ # norm1
1935
+ flops += self.dim * H * W
1936
+ # W-MSA/SW-MSA
1937
+ nW = H * W / self.window_size / self.window_size
1938
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
1939
+ # mlp
1940
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
1941
+ # norm2
1942
+ flops += self.dim * H * W
1943
+ return flops
1944
+
1945
+
1946
+ class SwinTransformer2Block(nn.Module):
1947
+ def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
1948
+ super().__init__()
1949
+ self.conv = None
1950
+ if c1 != c2:
1951
+ self.conv = Conv(c1, c2)
1952
+
1953
+ # remove input_resolution
1954
+ self.blocks = nn.Sequential(*[SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
1955
+ shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])
1956
+
1957
+ def forward(self, x):
1958
+ if self.conv is not None:
1959
+ x = self.conv(x)
1960
+ x = self.blocks(x)
1961
+ return x
1962
+
1963
+
1964
+ class ST2CSPA(nn.Module):
1965
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1966
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1967
+ super(ST2CSPA, self).__init__()
1968
+ c_ = int(c2 * e) # hidden channels
1969
+ self.cv1 = Conv(c1, c_, 1, 1)
1970
+ self.cv2 = Conv(c1, c_, 1, 1)
1971
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
1972
+ num_heads = c_ // 32
1973
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
1974
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1975
+
1976
+ def forward(self, x):
1977
+ y1 = self.m(self.cv1(x))
1978
+ y2 = self.cv2(x)
1979
+ return self.cv3(torch.cat((y1, y2), dim=1))
1980
+
1981
+
1982
+ class ST2CSPB(nn.Module):
1983
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1984
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1985
+ super(ST2CSPB, self).__init__()
1986
+ c_ = int(c2) # hidden channels
1987
+ self.cv1 = Conv(c1, c_, 1, 1)
1988
+ self.cv2 = Conv(c_, c_, 1, 1)
1989
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
1990
+ num_heads = c_ // 32
1991
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
1992
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1993
+
1994
+ def forward(self, x):
1995
+ x1 = self.cv1(x)
1996
+ y1 = self.m(x1)
1997
+ y2 = self.cv2(x1)
1998
+ return self.cv3(torch.cat((y1, y2), dim=1))
1999
+
2000
+
2001
+ class ST2CSPC(nn.Module):
2002
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
2003
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
2004
+ super(ST2CSPC, self).__init__()
2005
+ c_ = int(c2 * e) # hidden channels
2006
+ self.cv1 = Conv(c1, c_, 1, 1)
2007
+ self.cv2 = Conv(c1, c_, 1, 1)
2008
+ self.cv3 = Conv(c_, c_, 1, 1)
2009
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
2010
+ num_heads = c_ // 32
2011
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
2012
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
2013
+
2014
+ def forward(self, x):
2015
+ y1 = self.cv3(self.m(self.cv1(x)))
2016
+ y2 = self.cv2(x)
2017
+ return self.cv4(torch.cat((y1, y2), dim=1))
2018
+
2019
+ ##### end of swin transformer v2 #####
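For orientation, the v2 attention above replaces dot-product / sqrt(d) scaling with cosine similarity times a learned, clamped per-head temperature, and draws the relative position bias from a small MLP over log-spaced relative coordinates. A compact restatement of the scoring step on free-standing tensors (shapes are illustrative):

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 32)                  # (batch*windows, heads, tokens, head_dim)
k = torch.randn(1, 2, 4, 32)
logit_scale = torch.log(10 * torch.ones(2, 1, 1))
scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * scale
# WindowAttention_v2 then adds 16 * sigmoid(cpb_mlp(relative_coords_table)), indexed per token pair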
models/experimental.py ADDED
@@ -0,0 +1,272 @@
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models.common import Conv, DWConv
7
+ from utils.google_utils import attempt_download
8
+
9
+
10
+ class CrossConv(nn.Module):
11
+ # Cross Convolution Downsample
12
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
13
+ # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
14
+ super(CrossConv, self).__init__()
15
+ c_ = int(c2 * e) # hidden channels
16
+ self.cv1 = Conv(c1, c_, (1, k), (1, s))
17
+ self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
18
+ self.add = shortcut and c1 == c2
19
+
20
+ def forward(self, x):
21
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
22
+
23
+
24
+ class Sum(nn.Module):
25
+ # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
26
+ def __init__(self, n, weight=False): # n: number of inputs
27
+ super(Sum, self).__init__()
28
+ self.weight = weight # apply weights boolean
29
+ self.iter = range(n - 1) # iter object
30
+ if weight:
31
+ self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
32
+
33
+ def forward(self, x):
34
+ y = x[0] # no weight
35
+ if self.weight:
36
+ w = torch.sigmoid(self.w) * 2
37
+ for i in self.iter:
38
+ y = y + x[i + 1] * w[i]
39
+ else:
40
+ for i in self.iter:
41
+ y = y + x[i + 1]
42
+ return y
43
+
44
+
45
+ class MixConv2d(nn.Module):
46
+ # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
47
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
48
+ super(MixConv2d, self).__init__()
49
+ groups = len(k)
50
+ if equal_ch: # equal c_ per group
51
+ i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
52
+ c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
53
+ else: # equal weight.numel() per group
54
+ b = [c2] + [0] * groups
55
+ a = np.eye(groups + 1, groups, k=-1)
56
+ a -= np.roll(a, 1, axis=1)
57
+ a *= np.array(k) ** 2
58
+ a[0] = 1
59
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
60
+
61
+ self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
62
+ self.bn = nn.BatchNorm2d(c2)
63
+ self.act = nn.LeakyReLU(0.1, inplace=True)
64
+
65
+ def forward(self, x):
66
+ return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
67
+
68
+
69
+ class Ensemble(nn.ModuleList):
70
+ # Ensemble of models
71
+ def __init__(self):
72
+ super(Ensemble, self).__init__()
73
+
74
+ def forward(self, x, augment=False):
75
+ y = []
76
+ for module in self:
77
+ y.append(module(x, augment)[0])
78
+ # y = torch.stack(y).max(0)[0] # max ensemble
79
+ # y = torch.stack(y).mean(0) # mean ensemble
80
+ y = torch.cat(y, 1) # nms ensemble
81
+ return y, None # inference, train output
82
+
83
+
84
+
85
+
86
+
87
+ class ORT_NMS(torch.autograd.Function):
88
+ '''ONNX-Runtime NMS operation'''
89
+ @staticmethod
90
+ def forward(ctx,
91
+ boxes,
92
+ scores,
93
+ max_output_boxes_per_class=torch.tensor([100]),
94
+ iou_threshold=torch.tensor([0.45]),
95
+ score_threshold=torch.tensor([0.25])):
96
+ device = boxes.device
97
+ batch = scores.shape[0]
98
+ num_det = random.randint(0, 100)
99
+ batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
100
+ idxs = torch.arange(100, 100 + num_det).to(device)
101
+ zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
102
+ selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
103
+ selected_indices = selected_indices.to(torch.int64)
104
+ return selected_indices
105
+
106
+ @staticmethod
107
+ def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
108
+ return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
109
+
110
+
111
+ class TRT_NMS(torch.autograd.Function):
112
+ '''TensorRT NMS operation'''
113
+ @staticmethod
114
+ def forward(
115
+ ctx,
116
+ boxes,
117
+ scores,
118
+ background_class=-1,
119
+ box_coding=1,
120
+ iou_threshold=0.45,
121
+ max_output_boxes=100,
122
+ plugin_version="1",
123
+ score_activation=0,
124
+ score_threshold=0.25,
125
+ ):
126
+ batch_size, num_boxes, num_classes = scores.shape
127
+ num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
128
+ det_boxes = torch.randn(batch_size, max_output_boxes, 4)
129
+ det_scores = torch.randn(batch_size, max_output_boxes)
130
+ det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
131
+ return num_det, det_boxes, det_scores, det_classes
132
+
133
+ @staticmethod
134
+ def symbolic(g,
135
+ boxes,
136
+ scores,
137
+ background_class=-1,
138
+ box_coding=1,
139
+ iou_threshold=0.45,
140
+ max_output_boxes=100,
141
+ plugin_version="1",
142
+ score_activation=0,
143
+ score_threshold=0.25):
144
+ out = g.op("TRT::EfficientNMS_TRT",
145
+ boxes,
146
+ scores,
147
+ background_class_i=background_class,
148
+ box_coding_i=box_coding,
149
+ iou_threshold_f=iou_threshold,
150
+ max_output_boxes_i=max_output_boxes,
151
+ plugin_version_s=plugin_version,
152
+ score_activation_i=score_activation,
153
+ score_threshold_f=score_threshold,
154
+ outputs=4)
155
+ nums, boxes, scores, classes = out
156
+ return nums, boxes, scores, classes
157
+
158
+
159
+ class ONNX_ORT(nn.Module):
160
+ '''onnx module with ONNX-Runtime NMS operation.'''
161
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
162
+ super().__init__()
163
+ self.device = device if device else torch.device("cpu")
164
+ self.max_obj = torch.tensor([max_obj]).to(device)
165
+ self.iou_threshold = torch.tensor([iou_thres]).to(device)
166
+ self.score_threshold = torch.tensor([score_thres]).to(device)
167
+ self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
168
+ self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
169
+ dtype=torch.float32,
170
+ device=self.device)
171
+ self.n_classes=n_classes
172
+
173
+ def forward(self, x):
174
+ boxes = x[:, :, :4]
175
+ conf = x[:, :, 4:5]
176
+ scores = x[:, :, 5:]
177
+ if self.n_classes == 1:
178
+ scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
179
+ # so there is no need to multiply.
180
+ else:
181
+ scores *= conf # conf = obj_conf * cls_conf
182
+ boxes @= self.convert_matrix
183
+ max_score, category_id = scores.max(2, keepdim=True)
184
+ dis = category_id.float() * self.max_wh
185
+ nmsbox = boxes + dis
186
+ max_score_tp = max_score.transpose(1, 2).contiguous()
187
+ selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
188
+ X, Y = selected_indices[:, 0], selected_indices[:, 2]
189
+ selected_boxes = boxes[X, Y, :]
190
+ selected_categories = category_id[X, Y, :].float()
191
+ selected_scores = max_score[X, Y, :]
192
+ X = X.unsqueeze(1).float()
193
+ return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
194
+
195
+ class ONNX_TRT(nn.Module):
196
+ '''onnx module with TensorRT NMS operation.'''
197
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
198
+ super().__init__()
199
+ assert max_wh is None
200
+ self.device = device if device else torch.device('cpu')
201
+ self.background_class = -1
202
+ self.box_coding = 1
203
+ self.iou_threshold = iou_thres
204
+ self.max_obj = max_obj
205
+ self.plugin_version = '1'
206
+ self.score_activation = 0
207
+ self.score_threshold = score_thres
208
+ self.n_classes=n_classes
209
+
210
+ def forward(self, x):
211
+ boxes = x[:, :, :4]
212
+ conf = x[:, :, 4:5]
213
+ scores = x[:, :, 5:]
214
+ if self.n_classes == 1:
215
+ scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
216
+ # so there is no need to multiply.
217
+ else:
218
+ scores *= conf # conf = obj_conf * cls_conf
219
+ num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
220
+ self.iou_threshold, self.max_obj,
221
+ self.plugin_version, self.score_activation,
222
+ self.score_threshold)
223
+ return num_det, det_boxes, det_scores, det_classes
224
+
225
+
226
+ class End2End(nn.Module):
227
+ '''export onnx or tensorrt model with NMS operation.'''
228
+ def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
229
+ super().__init__()
230
+ device = device if device else torch.device('cpu')
231
+ assert isinstance(max_wh, int) or max_wh is None
232
+ self.model = model.to(device)
233
+ self.model.model[-1].end2end = True
234
+ self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
235
+ self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
236
+ self.end2end.eval()
237
+
238
+ def forward(self, x):
239
+ x = self.model(x)
240
+ x = self.end2end(x)
241
+ return x
242
+
243
+
244
+
245
+
246
+
247
+ def attempt_load(weights, map_location=None):
248
+ # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
249
+ model = Ensemble()
250
+ for w in weights if isinstance(weights, list) else [weights]:
251
+ attempt_download(w)
252
+ ckpt = torch.load(w, map_location=map_location) # load
253
+ model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
254
+
255
+ # Compatibility updates
256
+ for m in model.modules():
257
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
258
+ m.inplace = True # pytorch 1.7.0 compatibility
259
+ elif type(m) is nn.Upsample:
260
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
261
+ elif type(m) is Conv:
262
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
263
+
264
+ if len(model) == 1:
265
+ return model[-1] # return model
266
+ else:
267
+ print('Ensemble created with %s\n' % weights)
268
+ for k in ['names', 'stride']:
269
+ setattr(model, k, getattr(model[-1], k))
270
+ return model # return ensemble
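A hedged usage sketch (the checkpoint filename is hypothetical, and the repository's utils package must be importable): attempt_load restores an FP32, fused model, and End2End wraps it so the exported graph carries an NMS node; max_wh=None selects the TensorRT EfficientNMS path, an integer max_wh selects the ONNX-Runtime NonMaxSuppression path.

import torch
from models.experimental import attempt_load, End2End

model = attempt_load('yolov7.pt', map_location=torch.device('cpu'))   # hypothetical checkpoint
e2e = End2End(model, max_obj=100, iou_thres=0.45, score_thres=0.25,
              max_wh=None, device=torch.device('cpu'))
# torch.onnx.export(e2e, torch.zeros(1, 3, 640, 640), 'yolov7-nms.onnx', opset_version=12)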
271
+
272
+
models/yolo.py ADDED
@@ -0,0 +1,843 @@
1
+ import argparse
2
+ import logging
3
+ import sys
4
+ from copy import deepcopy
5
+
6
+ sys.path.append('./') # to run '$ python *.py' files in subdirectories
7
+ logger = logging.getLogger(__name__)
8
+ import torch
9
+ from models.common import *
10
+ from models.experimental import *
11
+ from utils.autoanchor import check_anchor_order
12
+ from utils.general import make_divisible, check_file, set_logging
13
+ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
14
+ select_device, copy_attr
15
+ from utils.loss import SigmoidBin
16
+
17
+ try:
18
+ import thop # for FLOPS computation
19
+ except ImportError:
20
+ thop = None
21
+
22
+
23
+ class Detect(nn.Module):
24
+ stride = None # strides computed during build
25
+ export = False # onnx export
26
+ end2end = False
27
+ include_nms = False
28
+ concat = False
29
+
30
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
31
+ super(Detect, self).__init__()
32
+ self.nc = nc # number of classes
33
+ self.no = nc + 5 # number of outputs per anchor
34
+ self.nl = len(anchors) # number of detection layers
35
+ self.na = len(anchors[0]) // 2 # number of anchors
36
+ self.grid = [torch.zeros(1)] * self.nl # init grid
37
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
38
+ self.register_buffer('anchors', a) # shape(nl,na,2)
39
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
40
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
41
+
42
+ def forward(self, x):
43
+ # x = x.copy() # for profiling
44
+ z = [] # inference output
45
+ self.training |= self.export
46
+ for i in range(self.nl):
47
+ x[i] = self.m[i](x[i]) # conv
48
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
49
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
50
+
51
+ if not self.training: # inference
52
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
53
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
54
+ y = x[i].sigmoid()
55
+ if not torch.onnx.is_in_onnx_export():
56
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
57
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
58
+ else:
59
+ xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
60
+ xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5)) # new xy
61
+ wh = wh ** 2 * (4 * self.anchor_grid[i].data) # new wh
62
+ y = torch.cat((xy, wh, conf), 4)
63
+ z.append(y.view(bs, -1, self.no))
64
+
65
+ if self.training:
66
+ out = x
67
+ elif self.end2end:
68
+ out = torch.cat(z, 1)
69
+ elif self.include_nms:
70
+ z = self.convert(z)
71
+ out = (z, )
72
+ elif self.concat:
73
+ out = torch.cat(z, 1)
74
+ else:
75
+ out = (torch.cat(z, 1), x)
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def _make_grid(nx=20, ny=20):
81
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
82
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
83
+
84
+ def convert(self, z):
85
+ z = torch.cat(z, 1)
86
+ box = z[:, :, :4]
87
+ conf = z[:, :, 4:5]
88
+ score = z[:, :, 5:]
89
+ score *= conf
90
+ convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
91
+ dtype=torch.float32,
92
+ device=z.device)
93
+ box @= convert_matrix
94
+ return (box, score)
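The inference branch above decodes each sigmoid output into image coordinates. Restated on a single cell (numbers are illustrative): the centre offset lies in (-0.5, 1.5) grid cells before multiplying by the stride, and width/height are (2*sigmoid)^2 times the matched anchor, i.e. bounded by four anchor sizes.

import torch

p = torch.sigmoid(torch.tensor([0.2, -0.3, 0.1, 0.4]))   # raw xywh outputs for one anchor
cell = torch.tensor([5., 7.])                             # grid cell indices
stride, anchor = 8., torch.tensor([12., 16.])
xy = (p[:2] * 2. - 0.5 + cell) * stride                   # centre in pixels
wh = (p[2:] * 2.) ** 2 * anchor                           # size in pixels
print(xy, wh)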
95
+
96
+
97
+ class IDetect(nn.Module):
98
+ stride = None # strides computed during build
99
+ export = False # onnx export
100
+ end2end = False
101
+ include_nms = False
102
+ concat = False
103
+
104
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
105
+ super(IDetect, self).__init__()
106
+ self.nc = nc # number of classes
107
+ self.no = nc + 5 # number of outputs per anchor
108
+ self.nl = len(anchors) # number of detection layers
109
+ self.na = len(anchors[0]) // 2 # number of anchors
110
+ self.grid = [torch.zeros(1)] * self.nl # init grid
111
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
112
+ self.register_buffer('anchors', a) # shape(nl,na,2)
113
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
114
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
115
+
116
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
117
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
118
+
119
+ def forward(self, x):
120
+ # x = x.copy() # for profiling
121
+ z = [] # inference output
122
+ self.training |= self.export
123
+ for i in range(self.nl):
124
+ x[i] = self.m[i](self.ia[i](x[i])) # conv
125
+ x[i] = self.im[i](x[i])
126
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
127
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
128
+
129
+ if not self.training: # inference
130
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
131
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
132
+
133
+ y = x[i].sigmoid()
134
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
135
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
136
+ z.append(y.view(bs, -1, self.no))
137
+
138
+ return x if self.training else (torch.cat(z, 1), x)
139
+
140
+ def fuseforward(self, x):
141
+ # x = x.copy() # for profiling
142
+ z = [] # inference output
143
+ self.training |= self.export
144
+ for i in range(self.nl):
145
+ x[i] = self.m[i](x[i]) # conv
146
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
147
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
148
+
149
+ if not self.training: # inference
150
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
151
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
152
+
153
+ y = x[i].sigmoid()
154
+ if not torch.onnx.is_in_onnx_export():
155
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
156
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
157
+ else:
158
+ xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
159
+ xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5)) # new xy
160
+ wh = wh ** 2 * (4 * self.anchor_grid[i].data) # new wh
161
+ y = torch.cat((xy, wh, conf), 4)
162
+ z.append(y.view(bs, -1, self.no))
163
+
164
+ if self.training:
165
+ out = x
166
+ elif self.end2end:
167
+ out = torch.cat(z, 1)
168
+ elif self.include_nms:
169
+ z = self.convert(z)
170
+ out = (z, )
171
+ elif self.concat:
172
+ out = torch.cat(z, 1)
173
+ else:
174
+ out = (torch.cat(z, 1), x)
175
+
176
+ return out
177
+
178
+ def fuse(self):
179
+ print("IDetect.fuse")
180
+ # fuse ImplicitA and Convolution
181
+ for i in range(len(self.m)):
182
+ c1,c2,_,_ = self.m[i].weight.shape
183
+ c1_,c2_, _,_ = self.ia[i].implicit.shape
184
+ self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)
185
+
186
+ # fuse ImplicitM and Convolution
187
+ for i in range(len(self.m)):
188
+ c1,c2, _,_ = self.im[i].implicit.shape
189
+ self.m[i].bias *= self.im[i].implicit.reshape(c2)
190
+ self.m[i].weight *= self.im[i].implicit.transpose(0,1)
191
+
192
+ @staticmethod
193
+ def _make_grid(nx=20, ny=20):
194
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
195
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
196
+
197
+ def convert(self, z):
198
+ z = torch.cat(z, 1)
199
+ box = z[:, :, :4]
200
+ conf = z[:, :, 4:5]
201
+ score = z[:, :, 5:]
202
+ score *= conf
203
+ convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
204
+ dtype=torch.float32,
205
+ device=z.device)
206
+ box @= convert_matrix
207
+ return (box, score)
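IDetect.fuse above folds the implicit add (ImplicitA) into the conv bias and the implicit multiply (ImplicitM) into the conv weight and bias, which is valid because the 1x1 head is affine. A minimal check on free-standing tensors (shapes are illustrative):

import torch
import torch.nn.functional as F

W, b = torch.randn(255, 128, 1, 1), torch.randn(255)      # 1x1 detection head
a = torch.randn(1, 128, 1, 1)                              # ImplicitA offset
m = torch.randn(1, 255, 1, 1)                              # ImplicitM scale
x = torch.randn(1, 128, 20, 20)

y_ref = m * F.conv2d(x + a, W, b)                          # ImplicitM(conv(ImplicitA(x)))
b_fused = (b + torch.matmul(W.reshape(255, 128), a.reshape(128, 1)).squeeze(1)) * m.reshape(255)
W_fused = W * m.reshape(255, 1, 1, 1)
y_fused = F.conv2d(x, W_fused, b_fused)
print(torch.allclose(y_ref, y_fused, atol=1e-4))           # expected: True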
208
+
209
+
210
+ class IKeypoint(nn.Module):
211
+ stride = None # strides computed during build
212
+ export = False # onnx export
213
+
214
+ def __init__(self, nc=80, anchors=(), nkpt=17, ch=(), inplace=True, dw_conv_kpt=False): # detection layer
215
+ super(IKeypoint, self).__init__()
216
+ self.nc = nc # number of classes
217
+ self.nkpt = nkpt
218
+ self.dw_conv_kpt = dw_conv_kpt
219
+ self.no_det=(nc + 5) # number of outputs per anchor for box and class
220
+ self.no_kpt = 3*self.nkpt ## number of outputs per anchor for keypoints
221
+ self.no = self.no_det+self.no_kpt
222
+ self.nl = len(anchors) # number of detection layers
223
+ self.na = len(anchors[0]) // 2 # number of anchors
224
+ self.grid = [torch.zeros(1)] * self.nl # init grid
225
+ self.flip_test = False
226
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
227
+ self.register_buffer('anchors', a) # shape(nl,na,2)
228
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
229
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no_det * self.na, 1) for x in ch) # output conv
230
+
231
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
232
+ self.im = nn.ModuleList(ImplicitM(self.no_det * self.na) for _ in ch)
233
+
234
+ if self.nkpt is not None:
235
+ if self.dw_conv_kpt: #keypoint head is slightly more complex
236
+ self.m_kpt = nn.ModuleList(
237
+ nn.Sequential(DWConv(x, x, k=3), Conv(x,x),
238
+ DWConv(x, x, k=3), Conv(x, x),
239
+ DWConv(x, x, k=3), Conv(x,x),
240
+ DWConv(x, x, k=3), Conv(x, x),
241
+ DWConv(x, x, k=3), Conv(x, x),
242
+ DWConv(x, x, k=3), nn.Conv2d(x, self.no_kpt * self.na, 1)) for x in ch)
243
+ else: #keypoint head is a single convolution
244
+ self.m_kpt = nn.ModuleList(nn.Conv2d(x, self.no_kpt * self.na, 1) for x in ch)
245
+
246
+ self.inplace = inplace # use in-place ops (e.g. slice assignment)
247
+
248
+ def forward(self, x):
249
+ # x = x.copy() # for profiling
250
+ z = [] # inference output
251
+ self.training |= self.export
252
+ for i in range(self.nl):
253
+ if self.nkpt is None or self.nkpt==0:
254
+ x[i] = self.im[i](self.m[i](self.ia[i](x[i]))) # conv
255
+ else :
256
+ x[i] = torch.cat((self.im[i](self.m[i](self.ia[i](x[i]))), self.m_kpt[i](x[i])), axis=1)
257
+
258
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
259
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
260
+ x_det = x[i][..., :6]
261
+ x_kpt = x[i][..., 6:]
262
+
263
+ if not self.training: # inference
264
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
265
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
266
+ kpt_grid_x = self.grid[i][..., 0:1]
267
+ kpt_grid_y = self.grid[i][..., 1:2]
268
+
269
+ if self.nkpt == 0:
270
+ y = x[i].sigmoid()
271
+ else:
272
+ y = x_det.sigmoid()
273
+
274
+ if self.inplace:
275
+ xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
276
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
277
+ if self.nkpt != 0:
278
+ x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5 + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i] # xy
279
+ x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i] # xy
280
+ #x_kpt[..., 0::3] = (x_kpt[..., ::3] + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i] # xy
281
+ #x_kpt[..., 1::3] = (x_kpt[..., 1::3] + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i] # xy
282
+ #print('=============')
283
+ #print(self.anchor_grid[i].shape)
284
+ #print(self.anchor_grid[i][...,0].unsqueeze(4).shape)
285
+ #print(x_kpt[..., 0::3].shape)
286
+ #x_kpt[..., 0::3] = ((x_kpt[..., 0::3].tanh() * 2.) ** 3 * self.anchor_grid[i][...,0].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_x.repeat(1,1,1,1,17) * self.stride[i] # xy
287
+ #x_kpt[..., 1::3] = ((x_kpt[..., 1::3].tanh() * 2.) ** 3 * self.anchor_grid[i][...,1].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_y.repeat(1,1,1,1,17) * self.stride[i] # xy
288
+ #x_kpt[..., 0::3] = (((x_kpt[..., 0::3].sigmoid() * 4.) ** 2 - 8.) * self.anchor_grid[i][...,0].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_x.repeat(1,1,1,1,17) * self.stride[i] # xy
289
+ #x_kpt[..., 1::3] = (((x_kpt[..., 1::3].sigmoid() * 4.) ** 2 - 8.) * self.anchor_grid[i][...,1].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_y.repeat(1,1,1,1,17) * self.stride[i] # xy
290
+ x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid()
291
+
292
+ y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim = -1)
293
+
294
+ else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
295
+ xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
296
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
297
+ if self.nkpt != 0:
298
+ y[..., 6:] = (y[..., 6:] * 2. - 0.5 + self.grid[i].repeat((1,1,1,1,self.nkpt))) * self.stride[i] # xy
299
+ y = torch.cat((xy, wh, y[..., 4:]), -1)
300
+
301
+ z.append(y.view(bs, -1, self.no))
302
+
303
+ return x if self.training else (torch.cat(z, 1), x)
304
+
305
+ @staticmethod
306
+ def _make_grid(nx=20, ny=20):
307
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
308
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
309
+
310
+
311
+ class IAuxDetect(nn.Module):
312
+ stride = None # strides computed during build
313
+ export = False # onnx export
314
+ end2end = False
315
+ include_nms = False
316
+ concat = False
317
+
318
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
319
+ super(IAuxDetect, self).__init__()
320
+ self.nc = nc # number of classes
321
+ self.no = nc + 5 # number of outputs per anchor
322
+ self.nl = len(anchors) # number of detection layers
323
+ self.na = len(anchors[0]) // 2 # number of anchors
324
+ self.grid = [torch.zeros(1)] * self.nl # init grid
325
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
326
+ self.register_buffer('anchors', a) # shape(nl,na,2)
327
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
328
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[:self.nl]) # output conv
329
+ self.m2 = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[self.nl:]) # output conv
330
+
331
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
332
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch[:self.nl])
333
+
334
+ def forward(self, x):
335
+ # x = x.copy() # for profiling
336
+ z = [] # inference output
337
+ self.training |= self.export
338
+ for i in range(self.nl):
339
+ x[i] = self.m[i](self.ia[i](x[i])) # conv
340
+ x[i] = self.im[i](x[i])
341
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
342
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
343
+
344
+ x[i+self.nl] = self.m2[i](x[i+self.nl])
345
+ x[i+self.nl] = x[i+self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
346
+
347
+ if not self.training: # inference
348
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
349
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
350
+
351
+ y = x[i].sigmoid()
352
+ if not torch.onnx.is_in_onnx_export():
353
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
354
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
355
+ else:
356
+ xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
357
+ xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5)) # new xy
358
+ wh = wh ** 2 * (4 * self.anchor_grid[i].data) # new wh
359
+ y = torch.cat((xy, wh, conf), 4)
360
+ z.append(y.view(bs, -1, self.no))
361
+
362
+ return x if self.training else (torch.cat(z, 1), x[:self.nl])
363
+
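A minimal smoke test of this head at inference time might look like the sketch below, assuming IAuxDetect and the ImplicitA/ImplicitM layers it instantiates are importable (e.g. from models.yolo); the anchors and channel counts are example values, with the first three ch entries feeding the lead heads and the last three the auxiliary heads:

import torch

anchors = ([12,16, 19,36, 40,28], [36,75, 76,55, 72,146], [142,110, 192,243, 459,401])
det = IAuxDetect(nc=80, anchors=anchors, ch=(128, 256, 512, 128, 256, 512))
det.stride = torch.tensor([8., 16., 32.])   # normally filled in by Model() during build
det.eval()
feats = [torch.zeros(1, c, s, s) for c, s in zip((128, 256, 512, 128, 256, 512),
                                                 (80, 40, 20, 80, 40, 20))]
pred, lead = det(feats)      # aux heads only contribute to training losses
print(pred.shape)            # torch.Size([1, 25200, 85]) = 3 * (80*80 + 40*40 + 20*20) anchors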
364
+ def fuseforward(self, x):
365
+ # x = x.copy() # for profiling
366
+ z = [] # inference output
367
+ self.training |= self.export
368
+ for i in range(self.nl):
369
+ x[i] = self.m[i](x[i]) # conv
370
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
371
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
372
+
373
+ if not self.training: # inference
374
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
375
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
376
+
377
+ y = x[i].sigmoid()
378
+ if not torch.onnx.is_in_onnx_export():
379
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
380
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
381
+ else:
382
+ xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
383
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh
384
+ y = torch.cat((xy, wh, y[..., 4:]), -1)
385
+ z.append(y.view(bs, -1, self.no))
386
+
387
+ if self.training:
388
+ out = x
389
+ elif self.end2end:
390
+ out = torch.cat(z, 1)
391
+ elif self.include_nms:
392
+ z = self.convert(z)
393
+ out = (z, )
394
+ elif self.concat:
395
+ out = torch.cat(z, 1)
396
+ else:
397
+ out = (torch.cat(z, 1), x)
398
+
399
+ return out
400
+
401
+ def fuse(self):
402
+ print("IAuxDetect.fuse")
403
+ # fuse ImplicitA and Convolution
404
+ for i in range(len(self.m)):
405
+ c1,c2,_,_ = self.m[i].weight.shape
406
+ c1_,c2_, _,_ = self.ia[i].implicit.shape
407
+ self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)
408
+
409
+ # fuse ImplicitM and Convolution
410
+ for i in range(len(self.m)):
411
+ c1,c2, _,_ = self.im[i].implicit.shape
412
+ self.m[i].bias *= self.im[i].implicit.reshape(c2)
413
+ self.m[i].weight *= self.im[i].implicit.transpose(0,1)
414
+
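The two loops above fold ImplicitA (an additive per-channel shift before the 1x1 conv) and ImplicitM (a multiplicative scale after it) into the conv itself: m * (W(x + a) + b) = (m * W) x + m * (W a + b). A small sketch with the 1x1 conv written as a plain matrix verifies the algebra:

import torch
torch.manual_seed(0)
W, b = torch.randn(4, 3), torch.randn(4)   # a 1x1 conv viewed as a matrix plus bias
a, m = torch.randn(3), torch.randn(4)      # ImplicitA / ImplicitM parameters
x = torch.randn(3)
unfused = m * (W @ (x + a) + b)
W_f, b_f = m[:, None] * W, m * (W @ a + b) # what fuse() folds into the conv
print(torch.allclose(unfused, W_f @ x + b_f))  # True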
415
+ @staticmethod
416
+ def _make_grid(nx=20, ny=20):
417
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
418
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
419
+
420
+ def convert(self, z):
421
+ z = torch.cat(z, 1)
422
+ box = z[:, :, :4]
423
+ conf = z[:, :, 4:5]
424
+ score = z[:, :, 5:]
425
+ score *= conf
426
+ convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
427
+ dtype=torch.float32,
428
+ device=z.device)
429
+ box @= convert_matrix
430
+ return (box, score)
431
+
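convert_matrix maps (cx, cy, w, h) boxes to (x1, y1, x2, y2) corners for the downstream NMS plugin; a quick numeric check:

import torch
cm = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], dtype=torch.float32)
box = torch.tensor([[50., 60., 20., 10.]])   # cx, cy, w, h
print(box @ cm)                              # tensor([[40., 55., 60., 65.]]) -> x1, y1, x2, y2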
432
+
433
+ class IBin(nn.Module):
434
+ stride = None # strides computed during build
435
+ export = False # onnx export
436
+
437
+ def __init__(self, nc=80, anchors=(), ch=(), bin_count=21): # detection layer
438
+ super(IBin, self).__init__()
439
+ self.nc = nc # number of classes
440
+ self.bin_count = bin_count
441
+
442
+ self.w_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
443
+ self.h_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
444
+ # classes, x,y,obj
445
+ self.no = nc + 3 + \
446
+ self.w_bin_sigmoid.get_length() + self.h_bin_sigmoid.get_length() # w-bce, h-bce
447
+ # + self.x_bin_sigmoid.get_length() + self.y_bin_sigmoid.get_length()
448
+
449
+ self.nl = len(anchors) # number of detection layers
450
+ self.na = len(anchors[0]) // 2 # number of anchors
451
+ self.grid = [torch.zeros(1)] * self.nl # init grid
452
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
453
+ self.register_buffer('anchors', a) # shape(nl,na,2)
454
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
455
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
456
+
457
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
458
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
459
+
460
+ def forward(self, x):
461
+
462
+ #self.x_bin_sigmoid.use_fw_regression = True
463
+ #self.y_bin_sigmoid.use_fw_regression = True
464
+ self.w_bin_sigmoid.use_fw_regression = True
465
+ self.h_bin_sigmoid.use_fw_regression = True
466
+
467
+ # x = x.copy() # for profiling
468
+ z = [] # inference output
469
+ self.training |= self.export
470
+ for i in range(self.nl):
471
+ x[i] = self.m[i](self.ia[i](x[i])) # conv
472
+ x[i] = self.im[i](x[i])
473
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
474
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
475
+
476
+ if not self.training: # inference
477
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
478
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
479
+
480
+ y = x[i].sigmoid()
481
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
482
+ #y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
483
+
484
+
485
+ #px = (self.x_bin_sigmoid.forward(y[..., 0:12]) + self.grid[i][..., 0]) * self.stride[i]
486
+ #py = (self.y_bin_sigmoid.forward(y[..., 12:24]) + self.grid[i][..., 1]) * self.stride[i]
487
+
488
+ pw = self.w_bin_sigmoid.forward(y[..., 2:24]) * self.anchor_grid[i][..., 0]
489
+ ph = self.h_bin_sigmoid.forward(y[..., 24:46]) * self.anchor_grid[i][..., 1]
490
+
491
+ #y[..., 0] = px
492
+ #y[..., 1] = py
493
+ y[..., 2] = pw
494
+ y[..., 3] = ph
495
+
496
+ y = torch.cat((y[..., 0:4], y[..., 46:]), dim=-1)
497
+
498
+ z.append(y.view(bs, -1, y.shape[-1]))
499
+
500
+ return x if self.training else (torch.cat(z, 1), x)
501
+
502
+ @staticmethod
503
+ def _make_grid(nx=20, ny=20):
504
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
505
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
506
+
507
+
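The channel layout of this head follows from the slices above (assuming SigmoidBin.get_length() == bin_count + 1, which is what the 22-channel slices imply): x, y, 22 width bins, 22 height bins, objectness, then classes. A short index sketch:

bin_count = 21
w_len = h_len = bin_count + 1   # channels per SigmoidBin head, matching the y[..., 2:24] / y[..., 24:46] slices
obj_idx = 2 * bin_count + 4     # = 46, the objectness channel kept by y[..., 46:] and used in _initialize_biases_bin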
508
+ class Model(nn.Module):
509
+ def __init__(self, cfg='yolor-csp-c.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
510
+ super(Model, self).__init__()
511
+ self.traced = False
512
+ if isinstance(cfg, dict):
513
+ self.yaml = cfg # model dict
514
+ else: # is *.yaml
515
+ import yaml # for torch hub
516
+ self.yaml_file = Path(cfg).name
517
+ with open(cfg) as f:
518
+ self.yaml = yaml.load(f, Loader=yaml.SafeLoader) # model dict
519
+
520
+ # Define model
521
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
522
+ if nc and nc != self.yaml['nc']:
523
+ logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
524
+ self.yaml['nc'] = nc # override yaml value
525
+ if anchors:
526
+ logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
527
+ self.yaml['anchors'] = round(anchors) # override yaml value
528
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
529
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
530
+ # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
531
+
532
+ # Build strides, anchors
533
+ m = self.model[-1] # Detect()
534
+ if isinstance(m, Detect):
535
+ s = 256 # 2x min stride
536
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
537
+ check_anchor_order(m)
538
+ m.anchors /= m.stride.view(-1, 1, 1)
539
+ self.stride = m.stride
540
+ self._initialize_biases() # only run once
541
+ # print('Strides: %s' % m.stride.tolist())
542
+ if isinstance(m, IDetect):
543
+ s = 256 # 2x min stride
544
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
545
+ check_anchor_order(m)
546
+ m.anchors /= m.stride.view(-1, 1, 1)
547
+ self.stride = m.stride
548
+ self._initialize_biases() # only run once
549
+ # print('Strides: %s' % m.stride.tolist())
550
+ if isinstance(m, IAuxDetect):
551
+ s = 256 # 2x min stride
552
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[:4]]) # forward
553
+ #print(m.stride)
554
+ check_anchor_order(m)
555
+ m.anchors /= m.stride.view(-1, 1, 1)
556
+ self.stride = m.stride
557
+ self._initialize_aux_biases() # only run once
558
+ # print('Strides: %s' % m.stride.tolist())
559
+ if isinstance(m, IBin):
560
+ s = 256 # 2x min stride
561
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
562
+ check_anchor_order(m)
563
+ m.anchors /= m.stride.view(-1, 1, 1)
564
+ self.stride = m.stride
565
+ self._initialize_biases_bin() # only run once
566
+ # print('Strides: %s' % m.stride.tolist())
567
+ if isinstance(m, IKeypoint):
568
+ s = 256 # 2x min stride
569
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
570
+ check_anchor_order(m)
571
+ m.anchors /= m.stride.view(-1, 1, 1)
572
+ self.stride = m.stride
573
+ self._initialize_biases_kpt() # only run once
574
+ # print('Strides: %s' % m.stride.tolist())
575
+
576
+ # Init weights, biases
577
+ initialize_weights(self)
578
+ self.info()
579
+ logger.info('')
580
+
581
+ def forward(self, x, augment=False, profile=False):
582
+ if augment:
583
+ img_size = x.shape[-2:] # height, width
584
+ s = [1, 0.83, 0.67] # scales
585
+ f = [None, 3, None] # flips (2-ud, 3-lr)
586
+ y = [] # outputs
587
+ for si, fi in zip(s, f):
588
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
589
+ yi = self.forward_once(xi)[0] # forward
590
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
591
+ yi[..., :4] /= si # de-scale
592
+ if fi == 2:
593
+ yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
594
+ elif fi == 3:
595
+ yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
596
+ y.append(yi)
597
+ return torch.cat(y, 1), None # augmented inference, train
598
+ else:
599
+ return self.forward_once(x, profile) # single-scale inference, train
600
+
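The augmented branch runs the model at scales 1.0 / 0.83 / 0.67, with a left-right flip on the middle pass, then undoes the augmentation on the predictions; a scalar sketch of that de-augmentation with illustrative values:

img_h, img_w = 640, 640
si = 0.83                                   # scale used for this pass
x_scaled, w_scaled = 200., 40.              # a box coordinate / width predicted on the scaled input
x0, w0 = x_scaled / si, w_scaled / si       # de-scale: yi[..., :4] /= si
x_unflipped = img_w - x0                    # de-flip lr (fi == 3): yi[..., 0] = img_size[1] - yi[..., 0]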
601
+ def forward_once(self, x, profile=False):
602
+ y, dt = [], [] # outputs
603
+ for m in self.model:
604
+ if m.f != -1: # if not from previous layer
605
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
606
+
607
+ if not hasattr(self, 'traced'):
608
+ self.traced = False
609
+
610
+ if self.traced:
611
+ if isinstance(m, (Detect, IDetect, IAuxDetect, IKeypoint)):
612
+ break
613
+
614
+ if profile:
615
+ c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
616
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
617
+ for _ in range(10):
618
+ m(x.copy() if c else x)
619
+ t = time_synchronized()
620
+ for _ in range(10):
621
+ m(x.copy() if c else x)
622
+ dt.append((time_synchronized() - t) * 100)
623
+ print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
624
+
625
+ x = m(x) # run
626
+
627
+ y.append(x if m.i in self.save else None) # save output
628
+
629
+ if profile:
630
+ print('%.1fms total' % sum(dt))
631
+ return x
632
+
633
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
634
+ # https://arxiv.org/abs/1708.02002 section 3.3
635
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
636
+ m = self.model[-1] # Detect() module
637
+ for mi, s in zip(m.m, m.stride): # from
638
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
639
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
640
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
641
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
642
+
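The constants above implement the prior-probability bias initialization from the focal-loss paper; evaluated for a typical stride-8 level with nc=80 they give:

import math
s, nc = 8, 80
obj_prior = math.log(8 / (640 / s) ** 2)   # ~ -6.69: roughly 8 objects expected per 640x640 image
cls_prior = math.log(0.6 / (nc - 0.99))    # ~ -4.88: small, near-uniform class prior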
643
+ def _initialize_aux_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
644
+ # https://arxiv.org/abs/1708.02002 section 3.3
645
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
646
+ m = self.model[-1] # Detect() module
647
+ for mi, mi2, s in zip(m.m, m.m2, m.stride): # from
648
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
649
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
650
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
651
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
652
+ b2 = mi2.bias.view(m.na, -1) # conv.bias(255) to (3,85)
653
+ b2.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
654
+ b2.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
655
+ mi2.bias = torch.nn.Parameter(b2.view(-1), requires_grad=True)
656
+
657
+ def _initialize_biases_bin(self, cf=None): # initialize biases into Detect(), cf is class frequency
658
+ # https://arxiv.org/abs/1708.02002 section 3.3
659
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
660
+ m = self.model[-1] # Bin() module
661
+ bc = m.bin_count
662
+ for mi, s in zip(m.m, m.stride): # from
663
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
664
+ old = b[:, (0,1,2,bc+3)].data
665
+ obj_idx = 2*bc+4
666
+ b[:, :obj_idx].data += math.log(0.6 / (bc + 1 - 0.99))
667
+ b[:, obj_idx].data += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
668
+ b[:, (obj_idx+1):].data += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
669
+ b[:, (0,1,2,bc+3)].data = old
670
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
671
+
672
+ def _initialize_biases_kpt(self, cf=None): # initialize biases into Detect(), cf is class frequency
673
+ # https://arxiv.org/abs/1708.02002 section 3.3
674
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
675
+ m = self.model[-1] # Detect() module
676
+ for mi, s in zip(m.m, m.stride): # from
677
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
678
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
679
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
680
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
681
+
682
+ def _print_biases(self):
683
+ m = self.model[-1] # Detect() module
684
+ for mi in m.m: # from
685
+ b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
686
+ print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
687
+
688
+ # def _print_weights(self):
689
+ # for m in self.model.modules():
690
+ # if type(m) is Bottleneck:
691
+ # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
692
+
693
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
694
+ print('Fusing layers... ')
695
+ for m in self.model.modules():
696
+ if isinstance(m, RepConv):
697
+ #print(f" fuse_repvgg_block")
698
+ m.fuse_repvgg_block()
699
+ elif isinstance(m, RepConv_OREPA):
700
+ #print(f" switch_to_deploy")
701
+ m.switch_to_deploy()
702
+ elif type(m) is Conv and hasattr(m, 'bn'):
703
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
704
+ delattr(m, 'bn') # remove batchnorm
705
+ m.forward = m.fuseforward # update forward
706
+ elif isinstance(m, (IDetect, IAuxDetect)):
707
+ m.fuse()
708
+ m.forward = m.fuseforward
709
+ self.info()
710
+ return self
711
+
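A minimal sketch for checking that fusion leaves inference unchanged, assuming the yolor-csp-c.yaml config is available on disk; the calls are wrapped in no_grad because the Implicit* folding updates parameters in place:

import torch
model = Model('yolor-csp-c.yaml').eval()
x = torch.zeros(1, 3, 256, 256)
with torch.no_grad():
    y_ref = model(x)[0]
    model.fuse()                        # folds BN, RepConv branches and Implicit* layers
    y_fused = model(x)[0]
print((y_ref - y_fused).abs().max())    # expected to be ~1e-5 or smaller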
712
+ def nms(self, mode=True): # add or remove NMS module
713
+ present = type(self.model[-1]) is NMS # last layer is NMS
714
+ if mode and not present:
715
+ print('Adding NMS... ')
716
+ m = NMS() # module
717
+ m.f = -1 # from
718
+ m.i = self.model[-1].i + 1 # index
719
+ self.model.add_module(name='%s' % m.i, module=m) # add
720
+ self.eval()
721
+ elif not mode and present:
722
+ print('Removing NMS... ')
723
+ self.model = self.model[:-1] # remove
724
+ return self
725
+
726
+ def autoshape(self): # add autoShape module
727
+ print('Adding autoShape... ')
728
+ m = autoShape(self) # wrap model
729
+ copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
730
+ return m
731
+
732
+ def info(self, verbose=False, img_size=640): # print model information
733
+ model_info(self, verbose, img_size)
734
+
735
+
736
+ def parse_model(d, ch): # model_dict, input_channels(3)
737
+ logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
738
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
739
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
740
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
741
+
742
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
743
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
744
+ m = eval(m) if isinstance(m, str) else m # eval strings
745
+ for j, a in enumerate(args):
746
+ try:
747
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
748
+ except Exception: # keep string args that fail to eval (e.g. 'nearest')
749
+ pass
750
+
751
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
752
+ if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
753
+ SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
754
+ Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
755
+ RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
756
+ Res, ResCSPA, ResCSPB, ResCSPC,
757
+ RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
758
+ ResX, ResXCSPA, ResXCSPB, ResXCSPC,
759
+ RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
760
+ Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
761
+ SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
762
+ SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC]:
763
+ c1, c2 = ch[f], args[0]
764
+ if c2 != no: # if not output
765
+ c2 = make_divisible(c2 * gw, 8)
766
+
767
+ args = [c1, c2, *args[1:]]
768
+ if m in [DownC, SPPCSPC, GhostSPPCSPC,
769
+ BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
770
+ RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
771
+ ResCSPA, ResCSPB, ResCSPC,
772
+ RepResCSPA, RepResCSPB, RepResCSPC,
773
+ ResXCSPA, ResXCSPB, ResXCSPC,
774
+ RepResXCSPA, RepResXCSPB, RepResXCSPC,
775
+ GhostCSPA, GhostCSPB, GhostCSPC,
776
+ STCSPA, STCSPB, STCSPC,
777
+ ST2CSPA, ST2CSPB, ST2CSPC]:
778
+ args.insert(2, n) # number of repeats
779
+ n = 1
780
+ elif m is nn.BatchNorm2d:
781
+ args = [ch[f]]
782
+ elif m is Concat:
783
+ c2 = sum([ch[x] for x in f])
784
+ elif m is Chuncat:
785
+ c2 = sum([ch[x] for x in f])
786
+ elif m is Shortcut:
787
+ c2 = ch[f[0]]
788
+ elif m is Foldcut:
789
+ c2 = ch[f] // 2
790
+ elif m in [Detect, IDetect, IAuxDetect, IBin, IKeypoint]:
791
+ args.append([ch[x] for x in f])
792
+ if isinstance(args[1], int): # number of anchors
793
+ args[1] = [list(range(args[1] * 2))] * len(f)
794
+ elif m is ReOrg:
795
+ c2 = ch[f] * 4
796
+ elif m is Contract:
797
+ c2 = ch[f] * args[0] ** 2
798
+ elif m is Expand:
799
+ c2 = ch[f] // args[0] ** 2
800
+ else:
801
+ c2 = ch[f]
802
+
803
+ m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
804
+ t = str(m)[8:-2].replace('__main__.', '') # module type
805
+ np = sum([x.numel() for x in m_.parameters()]) # number params
806
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
807
+ logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
808
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
809
+ layers.append(m_)
810
+ if i == 0:
811
+ ch = []
812
+ ch.append(c2)
813
+ return nn.Sequential(*layers), sorted(save)
814
+
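A minimal, hypothetical config dict shows what parse_model expects; the sketch assumes Conv, Detect and the module-level logger from this file are in scope, and passes a throwaway dict because parse_model mutates the args lists:

cfg = {
    'nc': 2, 'depth_multiple': 1.0, 'width_multiple': 1.0,
    'anchors': [[10,13, 16,30, 33,23]],              # one detection layer, three anchors
    'backbone': [[-1, 1, Conv, [32, 3, 2]]],         # from, number, module, args
    'head': [[[-1], 1, Detect, ['nc', 'anchors']]],  # the ch list is appended automatically
}
model, save = parse_model(cfg, ch=[3])
print(model)   # Sequential of a Conv layer followed by a Detect head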
815
+
816
+ if __name__ == '__main__':
817
+ parser = argparse.ArgumentParser()
818
+ parser.add_argument('--cfg', type=str, default='yolor-csp-c.yaml', help='model.yaml')
819
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
820
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
821
+ opt = parser.parse_args()
822
+ opt.cfg = check_file(opt.cfg) # check file
823
+ set_logging()
824
+ device = select_device(opt.device)
825
+
826
+ # Create model
827
+ model = Model(opt.cfg).to(device)
828
+ model.train()
829
+
830
+ if opt.profile:
831
+ img = torch.rand(1, 3, 640, 640).to(device)
832
+ y = model(img, profile=True)
833
+
834
+ # Profile
835
+ # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
836
+ # y = model(img, profile=True)
837
+
838
+ # Tensorboard
839
+ # from torch.utils.tensorboard import SummaryWriter
840
+ # tb_writer = SummaryWriter()
841
+ # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
842
+ # tb_writer.add_graph(model.model, img) # add model to tensorboard
843
+ # tb_writer.add_image('test', img[0], dataformats='CHW') # add a sample image to tensorboard