DraconicDragon committed · Commit 6b5de5c (verified) · 1 Parent(s): c97f7c3

Upload 3 files

Files changed (3)
  1. lsnet/lsnet.py +405 -0
  2. lsnet/lsnet_artist.py +248 -0
  3. lsnet/ska.py +61 -0
lsnet/lsnet.py ADDED
@@ -0,0 +1,405 @@
import torch
import torch.nn as nn
import itertools

from timm.models.vision_transformer import trunc_normal_
from timm.layers import SqueezeExcite
from timm.models import register_model
from .ska import SKA

from timm.models import build_model_with_cfg
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


class Conv2d_BN(torch.nn.Sequential):
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(0), w.shape[2:],
                            stride=self.c.stride, padding=self.c.padding,
                            dilation=self.c.dilation, groups=self.c.groups,
                            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class BN_Linear(torch.nn.Sequential):
    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', torch.nn.BatchNorm1d(a))
        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            torch.nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        bn, l = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        b = bn.bias - self.bn.running_mean * \
            self.bn.weight / (bn.running_var + bn.eps)**0.5
        w = l.weight * w[None, :]
        if l.bias is None:
            b = b @ self.l.weight.T
        else:
            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
        m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class Residual(torch.nn.Module):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(
                x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)


class FFN(torch.nn.Module):
    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h)
        self.act = torch.nn.ReLU()
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x


class Attention(torch.nn.Module):
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = Conv2d_BN(dim, h, ks=1)
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            self.dh, dim, bn_weight_init=0))
        self.dw = Conv2d_BN(nh_kd, nh_kd, 3, 1, 1, groups=nh_kd)
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, -1, H, W).split([self.nh_kd, self.nh_kd, self.dh], dim=1)
        q = self.dw(q)
        q, k, v = q.view(B, self.num_heads, -1, N), k.view(B, self.num_heads, -1, N), v.view(B, self.num_heads, -1, N)
        attn = (
            (q.transpose(-2, -1) @ k) * self.scale
            +
            (self.attention_biases[:, self.attention_bias_idxs]
             if self.training else self.ab)
        )
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).reshape(B, -1, H, W)
        x = self.proj(x)
        return x


class RepVGGDW(torch.nn.Module):
    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed

    def forward(self, x):
        return self.conv(x) + self.conv1(x) + x

    @torch.no_grad()
    def fuse(self):
        conv = self.conv.fuse()
        conv1 = self.conv1.fuse()

        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias

        conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])

        identity = torch.nn.functional.pad(
            torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1, 1, 1, 1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        return conv


class LKP(nn.Module):
    def __init__(self, dim, lks, sks, groups):
        super().__init__()
        self.cv1 = Conv2d_BN(dim, dim // 2)
        self.act = nn.ReLU()
        self.cv2 = Conv2d_BN(dim // 2, dim // 2, ks=lks, pad=(lks - 1) // 2, groups=dim // 2)
        self.cv3 = Conv2d_BN(dim // 2, dim // 2)
        self.cv4 = nn.Conv2d(dim // 2, sks ** 2 * dim // groups, kernel_size=1)
        self.norm = nn.GroupNorm(num_groups=dim // groups, num_channels=sks ** 2 * dim // groups)

        self.sks = sks
        self.groups = groups
        self.dim = dim

    def forward(self, x):
        x = self.act(self.cv3(self.cv2(self.act(self.cv1(x)))))
        w = self.norm(self.cv4(x))
        b, _, h, width = w.size()
        w = w.view(b, self.dim // self.groups, self.sks ** 2, h, width)
        return w


class LSConv(nn.Module):
    def __init__(self, dim):
        super(LSConv, self).__init__()
        self.lkp = LKP(dim, lks=7, sks=3, groups=8)
        self.ska = SKA()
        self.bn = nn.BatchNorm2d(dim)

    def forward(self, x):
        return self.bn(self.ska(x, self.lkp(x))) + x


class Block(torch.nn.Module):
    def __init__(self,
                 ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 stage=-1, depth=-1):
        super().__init__()

        if depth % 2 == 0:
            self.mixer = RepVGGDW(ed)
            self.se = SqueezeExcite(ed, 0.25)
        else:
            self.se = torch.nn.Identity()
            if stage == 3:
                self.mixer = Residual(Attention(ed, kd, nh, ar, resolution=resolution))
            else:
                self.mixer = LSConv(ed)

        self.ffn = Residual(FFN(ed, int(ed * 2)))

    def forward(self, x):
        return self.ffn(self.se(self.mixer(x)))


class LSNet(torch.nn.Module):
    def __init__(self, img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=[64, 128, 192, 256],
                 key_dim=[16, 16, 16, 16],
                 depth=[1, 2, 3, 4],
                 num_heads=[4, 4, 4, 4],
                 distillation=False,
                 **kwargs):
        super().__init__()

        default_cfg = kwargs.pop('default_cfg', None)
        pretrained_cfg = kwargs.pop('pretrained_cfg', None)
        pretrained_cfg_overlay = kwargs.pop('pretrained_cfg_overlay', None)

        if default_cfg is not None:
            self.default_cfg = default_cfg
        if pretrained_cfg is not None:
            self.pretrained_cfg = pretrained_cfg
        if pretrained_cfg_overlay is not None:
            self.pretrained_cfg_overlay = pretrained_cfg_overlay

        if kwargs:
            self.extra_init_kwargs = kwargs

        resolution = img_size
        self.patch_embed = torch.nn.Sequential(
            Conv2d_BN(in_chans, embed_dim[0] // 4, 3, 2, 1), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1)
        )

        resolution = img_size // patch_size
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
        self.blocks1 = nn.Sequential()
        self.blocks2 = nn.Sequential()
        self.blocks3 = nn.Sequential()
        self.blocks4 = nn.Sequential()
        blocks = [self.blocks1, self.blocks2, self.blocks3, self.blocks4]

        for i, (ed, kd, dpth, nh, ar) in enumerate(
                zip(embed_dim, key_dim, depth, num_heads, attn_ratio)):
            for d in range(dpth):
                blocks[i].append(Block(ed, kd, nh, ar, resolution, stage=i, depth=d))

            if i != len(depth) - 1:
                blk = blocks[i + 1]
                resolution_ = (resolution - 1) // 2 + 1
                blk.append(Conv2d_BN(embed_dim[i], embed_dim[i], ks=3, stride=2, pad=1, groups=embed_dim[i]))
                blk.append(Conv2d_BN(embed_dim[i], embed_dim[i + 1], ks=1, stride=1, pad=0))
                resolution = resolution_

        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

        self.num_classes = num_classes
        self.num_features = embed_dim[-1]

    @torch.jit.ignore  # type: ignore
    def no_weight_decay(self):
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = self.blocks4(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.0.c', 'classifier': ('head.linear', 'head_dist.linear'),
        **kwargs
    }


def _with_hf_hub(kwargs):
    """Normalize the hf hub config field name across different timm versions."""
    if 'hf_hub' in kwargs and 'hf_hub_id' not in kwargs:
        kwargs['hf_hub_id'] = kwargs.pop('hf_hub')
    return kwargs


default_cfgs = dict(
    lsnet_t=_cfg(**_with_hf_hub({'hf_hub': 'jameslahm/lsnet_t'})),
    lsnet_t_distill=_cfg(**_with_hf_hub({'hf_hub': 'jameslahm/lsnet_t_distill'})),
    lsnet_s=_cfg(**_with_hf_hub({'hf_hub': 'jameslahm/lsnet_s'})),
    lsnet_s_distill=_cfg(**_with_hf_hub({'hf_hub': 'jameslahm/lsnet_s_distill'})),
    lsnet_b=_cfg(**_with_hf_hub({'hf_hub': 'jameslahm/lsnet_b'})),
    lsnet_b_distill=_cfg(**_with_hf_hub({'hf_hub': 'jameslahm/lsnet_b_distill'})),
)


def _create_lsnet(variant, pretrained=False, **kwargs):
    cfg = default_cfgs.get(variant, None)
    if cfg is not None:
        kwargs.setdefault('default_cfg', cfg)
        kwargs.setdefault('pretrained_cfg', cfg)
    model = build_model_with_cfg(
        LSNet,
        variant,
        pretrained,
        **kwargs,
    )
    return model


@register_model
def lsnet_t(num_classes=1000, distillation=False, pretrained=False, **kwargs):
    model = _create_lsnet("lsnet_t" + ("_distill" if distillation else ""),
                          pretrained=pretrained,
                          num_classes=num_classes,
                          distillation=distillation,
                          img_size=224,
                          patch_size=8,
                          embed_dim=[64, 128, 256, 384],
                          depth=[0, 2, 8, 10],
                          num_heads=[3, 3, 3, 4],
                          )
    return model


@register_model
def lsnet_s(num_classes=1000, distillation=False, pretrained=False, **kwargs):
    model = _create_lsnet("lsnet_s" + ("_distill" if distillation else ""),
                          pretrained=pretrained,
                          num_classes=num_classes,
                          distillation=distillation,
                          img_size=224,
                          patch_size=8,
                          embed_dim=[96, 192, 320, 448],
                          depth=[1, 2, 8, 10],
                          num_heads=[3, 3, 3, 4],
                          )
    return model


@register_model
def lsnet_b(num_classes=1000, distillation=False, pretrained=False, **kwargs):
    model = _create_lsnet("lsnet_b" + ("_distill" if distillation else ""),
                          pretrained=pretrained,
                          num_classes=num_classes,
                          distillation=distillation,
                          img_size=224,
                          patch_size=8,
                          embed_dim=[128, 256, 384, 512],
                          depth=[4, 6, 8, 10],
                          num_heads=[3, 3, 3, 4],
                          )
    return model


@register_model
def lsnet_t_distill(**kwargs):
    kwargs["distillation"] = True
    return lsnet_t(**kwargs)


@register_model
def lsnet_s_distill(**kwargs):
    kwargs["distillation"] = True
    return lsnet_s(**kwargs)


@register_model
def lsnet_b_distill(**kwargs):
    kwargs["distillation"] = True
    return lsnet_b(**kwargs)
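
A minimal usage sketch for the file above. It assumes the repository root is on the Python path so the lsnet directory imports as a package, and that timm is installed; the batch size and random input are illustrative only.

import torch
from lsnet.lsnet import lsnet_t   # importing the module also runs the @register_model decorators

model = lsnet_t(num_classes=1000, distillation=False, pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)   # three stride-2 stem convs give an /8 downsample (28x28 tokens at 224px)
with torch.no_grad():
    logits = model(x)             # torch.Size([1, 1000])

Because the variants are registered with timm, timm.create_model("lsnet_t", pretrained=False) should build the same model once lsnet.lsnet has been imported.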
lsnet/lsnet_artist.py ADDED
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
from .lsnet import LSNet, Conv2d_BN, BN_Linear
from timm.models import register_model
from timm.models import build_model_with_cfg


class LSNetArtist(LSNet):
    def __init__(self,
                 img_size=224,
                 patch_size=8,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=[64, 128, 256, 384],
                 key_dim=[16, 16, 16, 16],
                 depth=[0, 2, 8, 10],
                 num_heads=[3, 3, 3, 4],
                 distillation=False,
                 feature_dim=None,      # feature vector dimension; defaults to embed_dim[-1]
                 use_projection=True,   # whether to use a projection layer
                 **kwargs):
        default_cfg = kwargs.pop('default_cfg', None)
        pretrained_cfg = kwargs.pop('pretrained_cfg', None)
        pretrained_cfg_overlay = kwargs.pop('pretrained_cfg_overlay', None)

        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            embed_dim=embed_dim,
            key_dim=key_dim,
            depth=depth,
            num_heads=num_heads,
            distillation=distillation,
            default_cfg=default_cfg,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **kwargs
        )

        self.feature_dim = feature_dim if feature_dim is not None else embed_dim[-1]
        self.use_projection = use_projection

        # If a projection layer is requested, add a mapping that produces fixed-dimension features
        if self.use_projection and self.feature_dim != embed_dim[-1]:
            self.projection = nn.Sequential(
                BN_Linear(embed_dim[-1], self.feature_dim),
                nn.ReLU(),
            )
        else:
            self.projection = nn.Identity()

        # Redefine the classification head(s) on top of the feature dimension
        if num_classes > 0:
            self.head = BN_Linear(self.feature_dim, num_classes)
            if distillation:
                self.head_dist = BN_Linear(self.feature_dim, num_classes)

    def forward_features(self, x):
        """
        Extract features without passing through the classification head.
        Used for clustering or feature extraction.
        """
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = self.blocks4(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        x = self.projection(x)
        return x

    def forward(self, x, return_features=False):
        """
        x: input images
        return_features: if True, return only the feature vectors (for clustering);
                         if False, return classification logits.

        return_features=True  -> feature vectors of shape (batch_size, feature_dim)
        return_features=False -> classification logits of shape (batch_size, num_classes)
        """
        features = self.forward_features(x)

        if return_features:
            # Return feature vectors for clustering
            return features

        # Return classification results
        if self.distillation:
            x = self.head(features), self.head_dist(features)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(features)

        return x

    def get_features(self, x):
        """
        Extract feature vectors.
        """
        return self.forward(x, return_features=True)

    def classify(self, x):
        """
        Run classification.
        """
        return self.forward(x, return_features=False)


def _cfg_artist(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (4, 4),
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'mean': (0.485, 0.456, 0.406),
        'std': (0.229, 0.224, 0.225),
        'first_conv': 'patch_embed.0.c',
        'classifier': ('head.linear', 'head_dist.linear'),
        **kwargs
    }


default_cfgs_artist = dict(
    lsnet_t_artist=_cfg_artist(),
    lsnet_s_artist=_cfg_artist(),
    lsnet_b_artist=_cfg_artist(),
    lsnet_l_artist=_cfg_artist(),
    lsnet_xl_artist=_cfg_artist(),
)


def _create_lsnet_artist(variant, pretrained=False, **kwargs):
    cfg = default_cfgs_artist.get(variant, None)
    if cfg is not None:
        kwargs.setdefault('default_cfg', cfg)
        kwargs.setdefault('pretrained_cfg', cfg)
    model = build_model_with_cfg(
        LSNetArtist,
        variant,
        pretrained,
        **kwargs,
    )
    return model


@register_model
def lsnet_t_artist(num_classes=1000, distillation=False, pretrained=False,
                   feature_dim=None, use_projection=True, **kwargs):
    model = _create_lsnet_artist(
        "lsnet_t_artist",
        pretrained=pretrained,
        num_classes=num_classes,
        distillation=distillation,
        img_size=224,
        patch_size=8,
        embed_dim=[64, 128, 256, 384],
        depth=[0, 2, 8, 10],
        num_heads=[3, 3, 3, 4],
        feature_dim=feature_dim,
        use_projection=use_projection,
        **kwargs
    )
    return model


@register_model
def lsnet_s_artist(num_classes=1000, distillation=False, pretrained=False,
                   feature_dim=None, use_projection=True, **kwargs):
    model = _create_lsnet_artist(
        "lsnet_s_artist",
        pretrained=pretrained,
        num_classes=num_classes,
        distillation=distillation,
        img_size=224,
        patch_size=8,
        embed_dim=[96, 192, 320, 448],
        depth=[1, 2, 8, 10],
        num_heads=[3, 3, 3, 4],
        feature_dim=feature_dim,
        use_projection=use_projection,
        **kwargs
    )
    return model


@register_model
def lsnet_b_artist(num_classes=1000, distillation=False, pretrained=False,
                   feature_dim=None, use_projection=True, **kwargs):
    model = _create_lsnet_artist(
        "lsnet_b_artist",
        pretrained=pretrained,
        num_classes=num_classes,
        distillation=distillation,
        img_size=224,
        patch_size=8,
        embed_dim=[128, 256, 384, 512],
        depth=[4, 6, 8, 10],
        num_heads=[3, 3, 3, 4],
        feature_dim=feature_dim,
        use_projection=use_projection,
        **kwargs
    )
    return model


@register_model
def lsnet_l_artist(num_classes=1000, distillation=False, pretrained=False,
                   feature_dim=None, use_projection=True, **kwargs):
    model = _create_lsnet_artist(
        "lsnet_l_artist",
        pretrained=pretrained,
        num_classes=num_classes,
        distillation=distillation,
        img_size=224,
        patch_size=8,
        embed_dim=[160, 320, 480, 640],
        depth=[6, 8, 12, 14],
        num_heads=[4, 4, 4, 4],
        feature_dim=feature_dim,
        use_projection=use_projection,
        **kwargs
    )
    return model


@register_model
def lsnet_xl_artist(num_classes=1000, distillation=False, pretrained=False,
                    feature_dim=None, use_projection=True, **kwargs):
    model = _create_lsnet_artist(
        "lsnet_xl_artist",
        pretrained=pretrained,
        num_classes=num_classes,
        distillation=distillation,
        img_size=224,
        patch_size=8,
        embed_dim=[192, 384, 576, 768],
        depth=[8, 12, 16, 20],
        num_heads=[6, 6, 6, 6],
        feature_dim=feature_dim,
        use_projection=use_projection,
        **kwargs
    )
    return model
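
LSNetArtist keeps the LSNet backbone but adds a pooled-and-projected feature path alongside the classification head. A small sketch of both paths (the sizes are illustrative; it assumes the lsnet directory is importable as a package):

import torch
from lsnet.lsnet_artist import lsnet_t_artist

# feature_dim=512 differs from embed_dim[-1]=384, so a BN_Linear + ReLU projection is inserted
model = lsnet_t_artist(num_classes=1000, feature_dim=512, use_projection=True, pretrained=False)
model.eval()

images = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    feats = model.get_features(images)   # (4, 512) embeddings, e.g. for clustering
    logits = model.classify(images)      # (4, 1000) logits from the redefined head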
lsnet/ska.py ADDED
@@ -0,0 +1,61 @@
import math

import torch
from torch.autograd import Function
from torch.nn import functional as F


class PyTorchSkaFn(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Get kernel size and padding from the weight tensor shape
        # w shape is (n, wc, ks*ks, h, w)
        ks = int(math.sqrt(w.shape[2]))
        pad = (ks - 1) // 2

        n, ic, h, width = x.shape
        wc = w.shape[1]  # wc = weight channels

        # 1. Extract patches from the input tensor
        # This creates a "view" of the input where each (h*w) column
        # contains the flattened data for a ks x ks patch.
        # Shape: (n, ic * ks * ks, h * w)
        x_unfolded = F.unfold(x, kernel_size=ks, padding=pad)

        # 2. Reshape the unfolded input for element-wise multiplication
        # Shape: (n, ic, ks * ks, h * w)
        x_unfolded = x_unfolded.view(n, ic, ks * ks, h * width)

        # 3. Prepare the weights for multiplication
        # The original weights have wc channels, which are repeated across the
        # input channels 'ic'. We need to reshape w to match the unfolded input.
        # w original shape: (n, wc, ks*ks, h, w)
        # w reshaped:       (n, wc, ks*ks, h*w)
        w = w.view(n, wc, ks * ks, h * width)

        # If the number of input channels is not equal to weight channels,
        # it implies the weights are grouped/repeated.
        if ic != wc:
            # This handles the "ci % wc" logic from the Triton kernel,
            # repeating the weight channels to match the input channels.
            repeats = ic // wc
            w = w.repeat(1, repeats, 1, 1)

        # 4. Perform the core operation: element-wise multiplication and sum
        # This is the equivalent of the Triton kernel's main loop.
        # (x_unfolded * w) -> shape: (n, ic, ks*ks, h*w)
        # .sum(dim=2) sums across the kernel dimension (ks*ks).
        # output shape: (n, ic, h*w)
        output = (x_unfolded * w).sum(dim=2)

        # 5. Reshape the output back to the original image format
        # Shape: (n, ic, h, w)
        output = output.view(n, ic, h, width)

        return output


class SKA(torch.nn.Module):
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return PyTorchSkaFn.apply(x, w)  # type: ignore
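
As the comments above describe, this fallback applies a position-specific ks x ks kernel to every input channel, with channel c reusing weight channel c % wc; the weight tensor is exactly what LKP.forward produces. The sketch below is a tiny sanity check of that grouping convention against an explicit per-channel loop; all sizes are made up, and it assumes the lsnet directory is importable as a package. Note that only forward is defined on the Function, so gradients will not propagate through this pure-PyTorch path as written.

import torch
import torch.nn.functional as F

from lsnet.ska import SKA

n, ic, h, w_ = 2, 8, 5, 5          # illustrative sizes
wc, ks = 4, 3                      # fewer weight channels than input channels: channel c uses weight channel c % wc

x = torch.randn(n, ic, h, w_)
wk = torch.randn(n, wc, ks * ks, h, w_)   # per-pixel kernels, shaped like LKP's output

out = SKA()(x, wk)                 # (n, ic, h, w_)

# Explicit per-channel reference using the same unfold, to confirm the grouping convention.
x_unf = F.unfold(x, kernel_size=ks, padding=(ks - 1) // 2).view(n, ic, ks * ks, h * w_)
ref = torch.stack(
    [(x_unf[:, c] * wk[:, c % wc].reshape(n, ks * ks, h * w_)).sum(dim=1) for c in range(ic)],
    dim=1,
).view(n, ic, h, w_)

print(torch.allclose(out, ref))    # expected: True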