InPeerReview committed
Commit 7ad9dfd · verified · 1 parent: 61e36aa

Upload 6 files

Files changed (6)
  1. models/Blocks.py +476 -0
  2. models/STNR.py +327 -0
  3. models/__init__.py +13 -0
  4. models/loss.py +155 -0
  5. models/mamba_customer.py +569 -0
  6. models/resnet.py +358 -0
models/Blocks.py ADDED
@@ -0,0 +1,476 @@
+from __future__ import annotations
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from monai.networks.layers.utils import get_act_layer
+import warnings
+warnings.filterwarnings("ignore")
+import math
+from functools import partial
+from typing import Callable
+from timm.models.layers import DropPath, to_2tuple
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from einops import rearrange, repeat
+
+
+class CAB(nn.Module):
+    def __init__(self, in_channels, out_channels=None, ratio=16, activation='relu'):
+        super(CAB, self).__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        if self.in_channels < ratio:
+            ratio = self.in_channels
+        self.reduced_channels = self.in_channels // ratio
+        if self.out_channels is None:
+            self.out_channels = in_channels
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+        self.activation = get_act_layer(activation)
+        self.fc1 = nn.Conv2d(self.in_channels, self.reduced_channels, 1, bias=False)
+        self.fc2 = nn.Conv2d(self.reduced_channels, self.out_channels, 1, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
+        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='relu')
+
+    def forward(self, x):
+        avg_out = self.fc2(self.activation(self.fc1(self.avg_pool(x))))
+        max_out = self.fc2(self.activation(self.fc1(self.max_pool(x))))
+        attention = self.sigmoid(avg_out + max_out)
+        return attention
+
+
+class SAB(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SAB, self).__init__()
+        assert kernel_size in (3, 7, 11), "kernel_size must be 3, 7 or 11"
+        padding = kernel_size // 2
+
+        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+
+        x_cat = torch.cat([avg_out, max_out], dim=1)  # shape: [B, 2, H, W]
+        attention = self.sigmoid(self.conv(x_cat))
+
+        return attention
+
+# --------------------------------
+
+
+class ChannelAttention(nn.Module):
+    """Channel attention used in RCAN.
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+    """
+
+    def __init__(self, num_feat, squeeze_factor=16):
+        super(ChannelAttention, self).__init__()
+        squeeze_channels = max(num_feat // squeeze_factor, 4)  # keep at least 4 channels
+        self.attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(num_feat, squeeze_channels, 1, padding=0),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(squeeze_channels, num_feat, 1, padding=0),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.attention(x)
+        return x * y
+
+
+# NOTE: this second `CAB` definition shadows the attention-map `CAB` above, so
+# `from models.Blocks import CAB` resolves to this feature-extraction block.
+class CAB(nn.Module):
+    def __init__(self, num_feat, is_light_sr=False, compress_ratio=3, squeeze_factor=30):
+        super(CAB, self).__init__()
+        mid_channels = max(num_feat // compress_ratio, 4)  # keep at least 4 channels
+        if is_light_sr:  # use a depth-wise conv for light SR, which is more efficient
+            self.cab = nn.Sequential(
+                nn.Conv2d(num_feat, num_feat, 3, 1, 1, groups=num_feat),
+                ChannelAttention(num_feat, squeeze_factor)
+            )
+        else:  # for classic SR
+            self.cab = nn.Sequential(
+                nn.Conv2d(num_feat, mid_channels, 3, 1, 1),
+                nn.GELU(),
+                nn.Conv2d(mid_channels, num_feat, 3, 1, 1),
+                ChannelAttention(num_feat, squeeze_factor)
+            )
+
+    def forward(self, x):
+        return self.cab(x)
+
+
+class SS2D(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        d_state=16,
+        d_conv=3,
+        expand=2.,
+        dt_rank="auto",
+        dt_min=0.001,
+        dt_max=0.1,
+        dt_init="random",
+        dt_scale=1.0,
+        dt_init_floor=1e-4,
+        dropout=0.,
+        conv_bias=True,
+        bias=False,
+        device=None,
+        dtype=None,
+        **kwargs,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.d_model = d_model
+        self.d_state = d_state
+        self.d_conv = d_conv
+        self.expand = expand
+        self.d_inner = int(self.expand * self.d_model)
+        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+
+        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
+        self.conv2d = nn.Conv2d(
+            in_channels=self.d_inner,
+            out_channels=self.d_inner,
+            groups=self.d_inner,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            padding=(d_conv - 1) // 2,
+            **factory_kwargs,
+        )
+        self.act = nn.SiLU()
+
+        self.x_proj = (
+            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
+            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
+            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
+            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
+        )
+        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
+        del self.x_proj
+
+        self.dt_projs = (
+            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
+                         **factory_kwargs),
+            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
+                         **factory_kwargs),
+            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
+                         **factory_kwargs),
+            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
+                         **factory_kwargs),
+        )
+        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
+        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
+        del self.dt_projs
+
+        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4 * D, N)
+        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)
+
+        self.selective_scan = selective_scan_fn
+
+        self.out_norm = nn.LayerNorm(self.d_inner)
+        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
+        self.dropout = nn.Dropout(dropout) if dropout > 0. else None
+
+    @staticmethod
+    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
+                **factory_kwargs):
+        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
+
+        # Initialize special dt projection to preserve variance at initialization
+        dt_init_std = dt_rank ** -0.5 * dt_scale
+        if dt_init == "constant":
+            nn.init.constant_(dt_proj.weight, dt_init_std)
+        elif dt_init == "random":
+            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
+        else:
+            raise NotImplementedError
+
+        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+        dt = torch.exp(
+            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+            + math.log(dt_min)
+        ).clamp(min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        with torch.no_grad():
+            dt_proj.bias.copy_(inv_dt)
+        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
+        dt_proj.bias._no_reinit = True
+
+        return dt_proj
+
+    @staticmethod
+    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
+        # S4D real initialization
+        A = repeat(
+            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
+            "n -> d n",
+            d=d_inner,
+        ).contiguous()
+        A_log = torch.log(A)  # Keep A_log in fp32
+        if copies > 1:
+            A_log = repeat(A_log, "d n -> r d n", r=copies)
+            if merge:
+                A_log = A_log.flatten(0, 1)
+        A_log = nn.Parameter(A_log)
+        A_log._no_weight_decay = True
+        return A_log
+
+    @staticmethod
+    def D_init(d_inner, copies=1, device=None, merge=True):
+        # D "skip" parameter
+        D = torch.ones(d_inner, device=device)
+        if copies > 1:
+            D = repeat(D, "n1 -> r n1", r=copies)
+            if merge:
+                D = D.flatten(0, 1)
+        D = nn.Parameter(D)  # Keep in fp32
+        D._no_weight_decay = True
+        return D
+
+    def forward_core(self, x: torch.Tensor):
+        B, C, H, W = x.shape
+        L = H * W
+        K = 4
+        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
+        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # e.g. (1, 4, 192, 3136)
+
+        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
+        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
+        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
+        xs = xs.float().view(B, -1, L)
+        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
+        Bs = Bs.float().view(B, K, -1, L)
+        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
+        Ds = self.Ds.float().view(-1)
+        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)
+        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)
+        out_y = self.selective_scan(
+            xs, dts,
+            As, Bs, Cs, Ds, z=None,
+            delta_bias=dt_projs_bias,
+            delta_softplus=True,
+            return_last_state=False,
+        ).view(B, K, -1, L)
+        assert out_y.dtype == torch.float
+
+        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
+        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
+        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
+
+        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
+
+    def forward(self, x: torch.Tensor, **kwargs):
+        B, H, W, C = x.shape
+
+        xz = self.in_proj(x)
+        x, z = xz.chunk(2, dim=-1)
+
+        x = x.permute(0, 3, 1, 2).contiguous()
+        x = self.act(self.conv2d(x))
+        y1, y2, y3, y4 = self.forward_core(x)
+        assert y1.dtype == torch.float32
+        y = y1 + y2 + y3 + y4
+        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
+        y = self.out_norm(y)
+        y = y * F.silu(z)
+        out = self.out_proj(y)
+        if self.dropout is not None:
+            out = self.dropout(out)
+        return out
+
+
+class VSSBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int = 0,
+        drop_path: float = 0,
+        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        attn_drop_rate: float = 0,
+        d_state: int = 16,
+        expand: float = 2.,
+        is_light_sr: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        self.ln_1 = norm_layer(hidden_dim)
+        self.self_attention = SS2D(d_model=hidden_dim, d_state=d_state, expand=expand, dropout=attn_drop_rate, **kwargs)
+        self.drop_path = DropPath(drop_path)
+        self.skip_scale = nn.Parameter(torch.ones(hidden_dim))
+        self.conv_blk = CAB(hidden_dim, is_light_sr)
+        self.ln_2 = nn.LayerNorm(hidden_dim)
+        self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim))
+
+    def forward(self, input, x_size):
+        # input: [B, HW, C]
+        B, L, C = input.shape
+        input = input.view(B, *x_size, C).contiguous()  # [B, H, W, C]
+        x = self.ln_1(input)
+        x = input * self.skip_scale + self.drop_path(self.self_attention(x))
+        x = x * self.skip_scale2 + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
+        x = x.view(B, -1, C).contiguous()
+        return x
+
+
+class RoPE(nn.Module):
+    def __init__(self, embed_dim, num_heads):
+        super().__init__()
+        self.head_dim = embed_dim // num_heads
+        self.num_heads = num_heads
+
+    def forward(self, x_size):
+        H, W = x_size
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        pos_h = torch.arange(H, dtype=torch.float32, device=device)
+        pos_w = torch.arange(W, dtype=torch.float32, device=device)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim))
+        sin_h = torch.sin(torch.einsum("i,j->ij", pos_h, inv_freq))
+        cos_h = torch.cos(torch.einsum("i,j->ij", pos_h, inv_freq))
+        sin_w = torch.sin(torch.einsum("i,j->ij", pos_w, inv_freq))
+        cos_w = torch.cos(torch.einsum("i,j->ij", pos_w, inv_freq))
+        sin = torch.einsum("i,j->ij", sin_h[:, 0], sin_w[:, 0]).unsqueeze(0).unsqueeze(0)
+        cos = torch.einsum("i,j->ij", cos_h[:, 0], cos_w[:, 0]).unsqueeze(0).unsqueeze(0)
+        sin = sin.expand(self.num_heads, -1, -1, -1).contiguous()
+        cos = cos.expand(self.num_heads, -1, -1, -1).contiguous()
+        return sin, cos
+
+
+def rotate_every_two(x):
+    if x.shape[-1] % 2 != 0:
+        x = F.pad(x, (0, 1), mode='constant', value=0)
+        pad = True
+    else:
+        pad = False
+
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    out = torch.stack((-x2, x1), -1).reshape(*x.shape[:-1], -1)
+
+    return out[..., :x.shape[-1] - 1] if pad else out
+
+
+def theta_shift(x, sin, cos):
+    if sin.shape[-1] < x.shape[-1]:
+        pad = x.shape[-1] - sin.shape[-1]
+        sin = F.pad(sin, (0, pad), mode='constant', value=0)
+        cos = F.pad(cos, (0, pad), mode='constant', value=1)
+    elif sin.shape[-1] > x.shape[-1]:
+        sin = sin[..., :x.shape[-1]]
+        cos = cos[..., :x.shape[-1]]
+
+    return (x * cos) + (rotate_every_two(x) * sin)
+
+
+class OverlapWindowAttention(nn.Module):
+    def __init__(self, dim, num_heads=4, window_size=7, shift_size=3):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
+        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+    def forward(self, x, sin, cos):
+        B, C, H, W = x.shape
+        ws = self.window_size
+        pad_h = (ws - H % ws) % ws
+        pad_w = (ws - W % ws) % ws
+        x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
+        H_pad, W_pad = x.shape[2], x.shape[3]
+
+        if self.shift_size > 0:
+            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
+
+        qkv = self.qkv(x)
+        qkv = rearrange(qkv, 'b (m c) h w -> m b c h w', m=3)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+        q = q.view(B, self.num_heads, self.head_dim, H_pad, W_pad)
+        k = k.view(B, self.num_heads, self.head_dim, H_pad, W_pad)
+        v = v.view(B, self.num_heads, self.head_dim, H_pad, W_pad)
+        q = theta_shift(q, sin, cos) * self.scale
+        k = theta_shift(k, sin, cos)
+
+        q = q.view(B, C, H_pad, W_pad)
+        k = k.view(B, C, H_pad, W_pad)
+        v = v.view(B, C, H_pad, W_pad)
+
+        q = rearrange(q, 'b c (h ws1) (w ws2) -> b (h w) (ws1 ws2) c', ws1=ws, ws2=ws)
+        k = rearrange(k, 'b c (h ws1) (w ws2) -> b (h w) (ws1 ws2) c', ws1=ws, ws2=ws)
+        v = rearrange(v, 'b c (h ws1) (w ws2) -> b (h w) (ws1 ws2) c', ws1=ws, ws2=ws)
+
+        B, num_windows, window_len, C_new = q.shape
+        assert C_new % self.num_heads == 0, f"C_new={C_new} is not divisible by num_heads={self.num_heads}"
+        head_dim_new = C_new // self.num_heads
+
+        q = q.view(B, num_windows, window_len, self.num_heads, head_dim_new).transpose(2, 3)
+        k = k.view(B, num_windows, window_len, self.num_heads, head_dim_new).transpose(2, 3)
+        v = v.view(B, num_windows, window_len, self.num_heads, head_dim_new).transpose(2, 3)
+
+        attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
+        out = (attn @ v).transpose(2, 3).reshape(B, num_windows, window_len, self.num_heads * head_dim_new)
+        out = rearrange(out, 'b (h w) (ws1 ws2) c -> b c (h ws1) (w ws2)', h=H_pad // ws, ws1=ws, ws2=ws, w=W_pad // ws)
+
+        if self.shift_size > 0:
+            out = torch.roll(out, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
+
+        out = out[:, :, :H, :W]
+        out = self.proj(out)
+        return out
+
+
+class ShallowFusionAttnBlock(nn.Module):
+    def __init__(self, dim, num_heads=4, window_size=7, shift_size=3):
+        super().__init__()
+        self.dim = dim
+        self.attn = OverlapWindowAttention(dim, num_heads=num_heads, window_size=window_size, shift_size=shift_size)
+        self.rope = RoPE(embed_dim=dim, num_heads=num_heads)
+        self.conv1 = nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1)
+        self.vss = VSSBlock(dim)
+
+    def patch_unembed(self, x, h, w):
+        return x.transpose(1, 2).reshape(x.size(0), -1, h, w)
+
+    def patch_embed(self, x):
+        return x.flatten(2).transpose(1, 2)
+
+    def forward(self, I1, I2, h, w):
+        B, C, H, W = I1.shape
+
+        diff = torch.abs(I1 - I2)
+        H_pad = (self.attn.window_size - h % self.attn.window_size) % self.attn.window_size + h
+        W_pad = (self.attn.window_size - w % self.attn.window_size) % self.attn.window_size + w
+        sin, cos = self.rope((H_pad, W_pad))
+        diff_attn = self.attn(diff, sin, cos)
+        token_attn = self.patch_embed(diff_attn)  # [B, N, C]
+        I1_token = self.patch_embed(I1)
+        I2_token = self.patch_embed(I2)
+        I1 = I1_token + token_attn
+        I2 = I2_token + token_attn
+
+        I1_un = self.patch_unembed(I1, h, w)
+        I2_un = self.patch_unembed(I2, h, w)
+
+        I1_local = self.conv1(torch.cat([I1_un, I2_un], dim=1)) + I1_un
+        I2_local = self.conv2(torch.cat([I2_un, I1_un], dim=1)) + I2_un
+
+        I1_token = self.patch_embed(I1_local)
+        I2_token = self.patch_embed(I2_local)
+        vss_feat_1 = self.vss(I1_token, (h, w)).transpose(1, 2).view(B, C, h, w)
+        vss_feat_2 = self.vss(I2_token, (h, w)).transpose(1, 2).view(B, C, h, w)
+
+        I1_fuse = I1_local + vss_feat_1
+        I2_fuse = I2_local + vss_feat_2
+        return I1_fuse, I2_fuse
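
A quick shape sanity check for these blocks (not part of the commit; a minimal sketch assuming `mamba_ssm` is installed and a CUDA device is available, since `selective_scan_fn` runs on GPU only):

    import torch
    from models.Blocks import SAB, VSSBlock

    x = torch.randn(1, 64, 32, 32, device="cuda")   # [B, C, H, W]
    sab = SAB(kernel_size=7).cuda()
    print(sab(x).shape)                             # [1, 1, 32, 32] spatial attention map

    tokens = x.flatten(2).transpose(1, 2)           # [B, HW, C]
    vss = VSSBlock(hidden_dim=64).cuda()
    print(vss(tokens, (32, 32)).shape)              # [1, 1024, 64]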
models/STNR.py ADDED
@@ -0,0 +1,327 @@
+from __future__ import annotations
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from monai.networks.blocks.convolutions import Convolution
+from monai.networks.blocks.segresnet_block import get_conv_layer, get_upsample_layer
+from monai.networks.layers.factories import Dropout
+from monai.networks.layers.utils import get_act_layer, get_norm_layer
+from monai.utils import UpsampleMode
+from einops import rearrange
+from models.mamba_customer import ConvMamba, M3, PatchEmbed, PatchUnEmbed
+from models.Blocks import CAB, SAB, VSSBlock, ShallowFusionAttnBlock
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def get_dwconv_layer(
+    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
+    bias: bool = False
+):
+    depth_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels,
+                             strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True, groups=in_channels)
+    point_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels,
+                             strides=stride, kernel_size=1, bias=bias, conv_only=True, groups=1)
+    return torch.nn.Sequential(depth_conv, point_conv)
+
+
+class SRCMLayer(nn.Module):
+    def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2, conv_mode='deepwise'):
+        super().__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        # NOTE: the same LayerNorm instance is reused before and after the Mamba branch
+        self.norm = nn.LayerNorm(input_dim)
+        self.convmamba = ConvMamba(
+            d_model=input_dim,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            bimamba_type="v2",
+            conv_mode=conv_mode
+        )
+        self.proj = nn.Linear(input_dim, output_dim)
+        self.skip_scale = nn.Parameter(torch.ones(1))
+
+    def forward(self, x):
+        if x.dtype == torch.float16:
+            x = x.type(torch.float32)
+        B, C = x.shape[:2]
+        assert C == self.input_dim
+        n_tokens = x.shape[2:].numel()
+        img_dims = x.shape[2:]
+        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
+        x_norm = self.norm(x_flat)
+        x_mamba = self.convmamba(x_norm) + self.skip_scale * x_flat
+        x_mamba = self.norm(x_mamba)
+        x_mamba = self.proj(x_mamba)
+        out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
+        return out
+
+
+def get_srcm_layer(
+    spatial_dims: int, in_channels: int, out_channels: int, stride: int = 1, conv_mode: str = "deepwise"
+):
+    srcm_layer = SRCMLayer(input_dim=in_channels, output_dim=out_channels, conv_mode=conv_mode)
+    if stride != 1:
+        if spatial_dims == 2:
+            return nn.Sequential(srcm_layer, nn.MaxPool2d(kernel_size=stride, stride=stride))
+    return srcm_layer
+
+
+class SRCMBlock(nn.Module):
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        norm: tuple | str,
+        kernel_size: int = 3,
+        conv_mode: str = "deepwise",
+        act: tuple | str = ("RELU", {"inplace": True}),
+    ) -> None:
+        """
+        Args:
+            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
+            in_channels: number of input channels.
+            norm: feature normalization type and arguments.
+            kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.
+            act: activation type and arguments. Defaults to ``RELU``.
+        """
+
+        super().__init__()
+
+        if kernel_size % 2 != 1:
+            raise AssertionError("kernel_size should be an odd number.")
+        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
+        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
+        self.act = get_act_layer(act)
+        self.conv1 = get_srcm_layer(
+            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
+        )
+        self.conv2 = get_srcm_layer(
+            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
+        )
+
+    def forward(self, x):
+        identity = x
+
+        x = self.norm1(x)
+        x = self.act(x)
+        x = self.conv1(x)
+
+        x = self.norm2(x)
+        x = self.act(x)
+        x = self.conv2(x)
+
+        x += identity
+
+        return x
+
+
+class CSI(nn.Module):
+    def __init__(self, dim):
+        super(CSI, self).__init__()
+        self.shallow_fusion_attn = ShallowFusionAttnBlock(dim)
+        self.m3 = M3(dim)
+        self.vss = VSSBlock(hidden_dim=dim)
+        self.patch_embed = PatchEmbed(in_chans=dim, embed_dim=dim)
+        self.patch_unembed = PatchUnEmbed(in_chans=dim, embed_dim=dim)
+
+    def forward(self, I1, I2, h, w):
+        I1_fuse, I2_fuse = self.shallow_fusion_attn(I1, I2, h, w)
+        fusion = torch.abs(I1_fuse - I2_fuse)
+        I1_token = self.patch_embed(I1_fuse)
+        I2_token = self.patch_embed(I2_fuse)
+        fusion_token = self.patch_embed(fusion)
+        test_h, test_w = fusion.shape[2], fusion.shape[3]
+        fusion_token, _ = self.m3(I1_token, I2_token, fusion_token, test_h, test_w)
+        fusion_out = self.patch_unembed(fusion_token, (h, w))
+        return fusion_out
+
+
+class STNR(nn.Module):
+    def __init__(
+        self,
+        spatial_dims: int = 2,
+        init_filters: int = 16,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        conv_mode: str = "deepwise",
+        local_query_model="orignal_dinner",
+        dropout_prob: float | None = None,
+        act: tuple | str = ("RELU", {"inplace": True}),
+        norm: tuple | str = ("GROUP", {"num_groups": 8}),
+        norm_name: str = "",
+        num_groups: int = 8,
+        use_conv_final: bool = True,
+        blocks_down: tuple = (1, 2, 2, 4),
+        blocks_up: tuple = (1, 1, 1),
+        mode: str = "",
+        up_mode="ResMamba",
+        up_conv_mode="deepwise",
+        resdiual=False,
+        stage=4,
+        diff_abs="later",
+        mamba_act="silu",
+        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,
+    ):
+        super().__init__()
+
+        if spatial_dims not in (2, 3):
+            raise ValueError("`spatial_dims` can only be 2 or 3.")
+        self.mode = mode
+        self.stage = stage
+        self.up_conv_mode = up_conv_mode
+        self.mamba_act = mamba_act
+        self.resdiual = resdiual
+        self.up_mode = up_mode
+        self.diff_abs = diff_abs
+        self.conv_mode = conv_mode
+        self.local_query_model = local_query_model
+        self.spatial_dims = spatial_dims
+        self.init_filters = init_filters
+        self.channels_list = [self.init_filters, self.init_filters * 2, self.init_filters * 4, self.init_filters * 8]
+        self.in_channels = in_channels
+        self.blocks_down = blocks_down
+        self.blocks_up = blocks_up
+        self.dropout_prob = dropout_prob
+        self.act = act  # input options
+        self.act_mod = get_act_layer(act)
+        if norm_name:
+            if norm_name.lower() != "group":
+                raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.")
+            norm = ("group", {"num_groups": num_groups})
+        self.norm = norm
+        self.upsample_mode = UpsampleMode(upsample_mode)
+        self.use_conv_final = use_conv_final
+        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
+        self.srcm_encoder_layers = self._make_srcm_encoder_layers()
+        self.srcm_decoder_layers, self.up_samples = self._make_srcm_decoder_layers(up_mode=self.up_mode)
+        self.conv_final = self._make_final_conv(out_channels)
+        self.fusion_blocks = nn.ModuleList(
+            [CSI(self.channels_list[i]) for i in range(self.stage)]
+        )
+        self.cab_layers = nn.ModuleList([
+            CAB(ch) for ch in self.channels_list[::-1][1:]
+        ])
+        self.sab_layers = nn.ModuleList([
+            SAB(kernel_size=7) for _ in range(len(self.blocks_up))
+        ])
+        self.conv_down_layers = nn.ModuleList([
+            nn.Conv2d(ch * 2, ch, kernel_size=1, stride=1, padding=0) for ch in self.channels_list[::-1][1:]
+        ])
+        if dropout_prob is not None:
+            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
+
+    def _make_srcm_encoder_layers(self):
+        srcm_encoder_layers = nn.ModuleList()
+        blocks_down, spatial_dims, filters, norm, conv_mode = (
+            self.blocks_down, self.spatial_dims, self.init_filters, self.norm, self.conv_mode)
+        for i, item in enumerate(blocks_down):
+            layer_in_channels = filters * 2 ** i
+            downsample_mamba = (
+                get_srcm_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2, conv_mode=conv_mode)
+                if i > 0
+                else nn.Identity()
+            )
+            down_layer = nn.Sequential(
+                downsample_mamba,
+                *[SRCMBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act, conv_mode=conv_mode) for _ in range(item)]
+            )
+            srcm_encoder_layers.append(down_layer)
+        return srcm_encoder_layers
+
+    def _make_srcm_decoder_layers(self, up_mode):
+        srcm_decoder_layers, up_samples = nn.ModuleList(), nn.ModuleList()
+        upsample_mode, blocks_up, spatial_dims, filters, norm = (
+            self.upsample_mode,
+            self.blocks_up,
+            self.spatial_dims,
+            self.init_filters,
+            self.norm,
+        )
+        if up_mode == 'SRCM':
+            Block_up = SRCMBlock
+        else:
+            raise ValueError(f"Unsupported up_mode '{up_mode}'; only 'SRCM' is implemented.")
+        n_up = len(blocks_up)
+        for i in range(n_up):
+            sample_in_channels = filters * 2 ** (n_up - i)
+            srcm_decoder_layers.append(
+                nn.Sequential(
+                    *[
+                        Block_up(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act, conv_mode=self.up_conv_mode)
+                        for _ in range(blocks_up[i])
+                    ]
+                )
+            )
+            up_samples.append(
+                nn.Sequential(
+                    *[
+                        get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1),
+                        get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode),
+                    ]
+                )
+            )
+        return srcm_decoder_layers, up_samples
+
+    def _make_final_conv(self, out_channels: int):
+        return nn.Sequential(
+            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
+            self.act_mod,
+            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
+        )
+
+    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        x = self.convInit(x)
+        if self.dropout_prob is not None:
+            x = self.dropout(x)
+        down_x = []
+
+        for down in self.srcm_encoder_layers:
+            x = down(x)
+            down_x.append(x)
+
+        return x, down_x
+
+    def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor:
+        for i, (up, upl) in enumerate(zip(self.up_samples, self.srcm_decoder_layers)):
+            skip = down_x[i + 1]
+            x_up = up(x) + skip
+            x_cab = self.cab_layers[i](x_up) * x_up
+            x_sab = self.sab_layers[i](x_cab) * x_cab
+            x_srcm = upl(x_up)
+            combined_out = torch.cat([x_sab, x_srcm], dim=1)
+            final_out = self.conv_down_layers[i](combined_out)
+            x = final_out
+        if self.use_conv_final:
+            x = self.conv_final(x)
+        return x
+
+    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        b, c, h, w = x1.shape
+        x1, down_x1 = self.encode(x1)
+        x2, down_x2 = self.encode(x2)
+        down_x = []
+        for i in range(len(down_x1)):
+            x1_level, x2_level = down_x1[i], down_x2[i]
+            H_i, W_i = x1_level.shape[2], x1_level.shape[3]
+            if self.diff_abs == "later" and self.mode == "FUSION" and i < self.stage:
+                fusion = self.fusion_blocks[i](x1_level, x2_level, H_i, W_i)
+            else:
+                # fall back to a plain absolute difference of the two temporal features
+                fusion = torch.abs(x1_level - x2_level)
+            down_x.append(fusion)
+        down_x.reverse()
+        x = self.decode(down_x[0], down_x)
+        return x
+
+
+if __name__ == "__main__":
+    device = "cuda:0"
+    CDMamba = STNR(spatial_dims=2, in_channels=3, out_channels=2, init_filters=16, norm=("GROUP", {"num_groups": 8}),
+                   mode="FUSION", conv_mode='orignal', local_query_model="orignal_dinner",
+                   stage=4, mamba_act="silu", up_mode="SRCM", up_conv_mode='deepwise', blocks_down=(1, 2, 2, 4), blocks_up=(1, 1, 1),
+                   resdiual=False, diff_abs="later").to(device)
+    x = torch.randn(1, 3, 256, 256).to(device)
+    y = CDMamba(x, x)
+    print(y.shape)
models/__init__.py ADDED
@@ -0,0 +1,13 @@
+from .resnet import *
+import logging
+logger = logging.getLogger('base')
+
+
+def create_CD_model(opt):
+    # Our CDMamba model
+    from models.STNR import STNR as stnr
+
+    if opt['model']['name'] == 'STNR':
+        cd_model = stnr(spatial_dims=opt['model']['spatial_dims'], in_channels=opt['model']['in_channels'],
+                        init_filters=opt['model']['init_filters'], out_channels=opt['model']['n_classes'],
+                        mode=opt['model']['mode'], conv_mode=opt['model']['conv_mode'], up_mode=opt['model']['up_mode'],
+                        up_conv_mode=opt['model']['up_conv_mode'], norm=opt['model']['norm'],
+                        blocks_down=opt['model']['blocks_down'], blocks_up=opt['model']['blocks_up'],
+                        resdiual=opt['model']['resdiual'], diff_abs=opt['model']['diff_abs'], stage=opt['model']['stage'],
+                        mamba_act=opt['model']['mamba_act'], local_query_model=opt['model']['local_query_model'])
+        return cd_model
+    raise NotImplementedError(f"CD model [{opt['model']['name']}] is not implemented.")
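
`create_CD_model` reads every hyperparameter from a nested option dict. A hypothetical minimal `opt` for illustration (values mirror the `__main__` demo in models/STNR.py; the dict itself is not a shipped config):

    opt = {"model": {
        "name": "STNR", "spatial_dims": 2, "in_channels": 3, "n_classes": 2,
        "init_filters": 16, "mode": "FUSION", "conv_mode": "orignal",
        "up_mode": "SRCM", "up_conv_mode": "deepwise",
        "norm": ("GROUP", {"num_groups": 8}),
        "blocks_down": (1, 2, 2, 4), "blocks_up": (1, 1, 1),
        "resdiual": False, "diff_abs": "later", "stage": 4,
        "mamba_act": "silu", "local_query_model": "orignal_dinner",
    }}

    from models import create_CD_model
    model = create_CD_model(opt)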
models/loss.py ADDED
@@ -0,0 +1,155 @@
+import torch
+import torch.nn as nn
+from torch import Tensor, einsum
+import torch.nn.functional as F
+from misc.torchutils import class2one_hot, simplex
+from models.darnet_help.loss_help import FocalLoss, dernet_dice_loss
+
+
+def cross_entropy(input, target, weight=None, reduction='mean', ignore_index=255):
+    """
+    logSoftmax with loss
+    :param input: torch.Tensor, N*C*H*W
+    :param target: torch.Tensor, N*1*H*W or N*H*W
+    :param weight: torch.Tensor, C
+    :return: torch.Tensor [0]
+    """
+    target = target.long()
+    if target.dim() == 4:
+        target = torch.squeeze(target, dim=1)
+    if input.shape[-1] != target.shape[-1]:
+        input = F.interpolate(input, size=target.shape[1:], mode='bilinear', align_corners=True)
+
+    return F.cross_entropy(input=input, target=target, weight=weight,
+                           ignore_index=ignore_index, reduction=reduction)
+
+
+def dice_loss(predicts, target, weight=None):
+    idc = [0, 1]
+    probs = torch.softmax(predicts, dim=1)
+    target = class2one_hot(target, 7)
+    assert simplex(probs) and simplex(target)
+
+    pc = probs[:, idc, ...].type(torch.float32)
+    tc = target[:, idc, ...].type(torch.float32)
+    intersection: Tensor = einsum("bcwh,bcwh->bc", pc, tc)
+    union: Tensor = (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc))
+
+    divided: Tensor = torch.ones_like(intersection) - (2 * intersection + 1e-10) / (union + 1e-10)
+
+    loss = divided.mean()
+    return loss
+
+
+def ce_dice(input, target, weight=None):
+    ce_loss = cross_entropy(input, target)
+    dice_loss_ = dice_loss(input, target)
+    loss = 0.5 * ce_loss + 0.5 * dice_loss_
+    return loss
+
+
+def dice(input, target, weight=None):
+    dice_loss_ = dice_loss(input, target)
+    return dice_loss_
+
+
+def ce2_dice1(input, target, weight=None):
+    ce_loss = cross_entropy(input, target)
+    dice_loss_ = dice_loss(input, target)
+    loss = ce_loss + 0.5 * dice_loss_
+    return loss
+
+
+def ce1_dice2(input, target, weight=None):
+    ce_loss = cross_entropy(input, target)
+    dice_loss_ = dice_loss(input, target)
+    loss = 0.5 * ce_loss + dice_loss_
+    return loss
+
+
+def ce_scl(input, target, weight=None):
+    ce_loss = cross_entropy(input, target)
+    dice_loss_ = dice_loss(input, target)
+    loss = 0.5 * ce_loss + 0.5 * dice_loss_
+    return loss
+
+
+def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75):
+    logit = logit_pixel.view(-1)
+    truth = truth_pixel.view(-1)
+    assert logit.shape == truth.shape
+
+    loss = F.binary_cross_entropy_with_logits(logit.float(), truth.float(), reduction='none')
+
+    pos = (truth > 0.5).float()
+    neg = (truth < 0.5).float()
+    pos_num = pos.sum().item() + 1e-12
+    neg_num = neg.sum().item() + 1e-12
+    loss = (weight_pos * pos * loss / pos_num + weight_neg * neg * loss / neg_num).sum()
+
+    return loss
+
+
+class ChangeSimilarity(nn.Module):
+    """input: x1, x2 multi-class predictions, c = class_num
+    label_change: changed part
+    """
+
+    def __init__(self, reduction='mean'):
+        super(ChangeSimilarity, self).__init__()
+        self.loss_f = nn.CosineEmbeddingLoss(margin=0., reduction=reduction)
+
+    def forward(self, x1, x2, label_change):
+        b, c, h, w = x1.size()
+        x1 = F.softmax(x1, dim=1)
+        x2 = F.softmax(x2, dim=1)
+        x1 = x1.permute(0, 2, 3, 1)
+        x2 = x2.permute(0, 2, 3, 1)
+        x1 = torch.reshape(x1, [b * h * w, c])
+        x2 = torch.reshape(x2, [b * h * w, c])
+
+        label_unchange = ~label_change.bool()
+        target = label_unchange.float()
+        target = target - label_change.float()  # unchanged -> 1, changed -> -1
+        target = torch.reshape(target, [b * h * w])
+
+        loss = self.loss_f(x1, x2, target)
+        return loss
+
+
+def hybrid_loss(predictions, target, weight=(0.2, 0.2, 0.2, 0.2, 0.2, 0.2)):
+    """Weighted sum of CE + Dice over all prediction heads."""
+    loss = 0
+
+    # gamma=0, alpha=None --> CE
+    # focal = FocalLoss(gamma=0, alpha=None)
+    # ssim = SSIM()
+
+    for i, prediction in enumerate(predictions):
+        bce = cross_entropy(prediction, target)
+        dice = dice_loss(prediction, target)
+        # ssimloss = ssim(prediction, target)
+        loss += weight[i] * (bce + dice)  # - ssimloss
+
+    return loss
+
+
+class BCL(nn.Module):
+    """
+    Batch-balanced contrastive loss.
+    no-change: 1
+    change: -1
+    """
+
+    def __init__(self, margin=2.0):
+        super(BCL, self).__init__()
+        self.margin = margin
+
+    def forward(self, distance, label):
+        label[label == 1] = -1
+        label[label == 0] = 1
+
+        mask = (label != 255).float()
+        distance = distance * mask
+
+        pos_num = torch.sum((label == 1).float()) + 0.0001
+        neg_num = torch.sum((label == -1).float()) + 0.0001
+
+        loss_1 = torch.sum((1 + label) / 2 * torch.pow(distance, 2)) / pos_num
+        loss_2 = torch.sum((1 - label) / 2 *
+                           torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
+                           ) / neg_num
+        loss = loss_1 + loss_2
+        return loss
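
A small smoke test for the self-contained losses (a sketch, not part of the commit; it assumes the repo's `misc.torchutils` and `models.darnet_help` helpers are importable, since models/loss.py imports them at module load, and `dice_loss` hard-codes 7 classes via `class2one_hot`, so only `cross_entropy` and `BCL` are exercised here):

    import torch
    from models.loss import cross_entropy, BCL

    logits = torch.randn(2, 2, 64, 64)                # [B, C, H, W] class scores
    target = torch.randint(0, 2, (2, 64, 64))         # [B, H, W] class indices
    print(cross_entropy(logits, target).item())

    distance = torch.rand(2, 64, 64)                  # e.g. a feature-distance map
    label = torch.randint(0, 2, (2, 64, 64)).float()  # 0 = no change, 1 = change
    print(BCL(margin=2.0)(distance, label).item())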
models/mamba_customer.py ADDED
@@ -0,0 +1,569 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+ import numbers
3
+ from mamba_ssm.modules.mamba_simple import Mamba
4
+ import warnings
5
+ warnings.filterwarnings("ignore")
6
+
7
+ from timm.models.layers import DropPath, to_2tuple
8
+
9
+ import math
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch import Tensor
16
+
17
+ from einops import rearrange, repeat
18
+
19
+ try:
20
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
21
+ except ImportError:
22
+ causal_conv1d_fn, causal_conv1d_update = None
23
+
24
+ try:
25
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
26
+ except ImportError:
27
+ selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None
28
+
29
+ try:
30
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
31
+ except ImportError:
32
+ selective_state_update = None
33
+
34
+ try:
35
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
36
+ except ImportError:
37
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
38
+
39
+ class LightweightModel(nn.Module):
40
+ def __init__(self, in_channels, out_channels):
41
+ super(LightweightModel, self).__init__()
42
+ self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels)
43
+ self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
44
+
45
+ def forward(self, x):
46
+ x = self.depthwise_conv(x)
47
+ x = self.pointwise_conv(x)
48
+ return x
49
+
50
+
51
+ class ConvMamba(nn.Module):
52
+ def __init__(
53
+ self,
54
+ d_model,
55
+ d_state=16,
56
+ d_conv=4,
57
+ expand=2,
58
+ dt_rank="auto",
59
+ dt_min=0.001,
60
+ dt_max=0.1,
61
+ dt_init="random",
62
+ dt_scale=1.0,
63
+ dt_init_floor=1e-4,
64
+ conv_bias=True,
65
+ bias=False,
66
+ use_fast_path=True,
67
+ layer_idx=None,
68
+ device=None,
69
+ dtype=None,
70
+ bimamba_type="none",
71
+ conv_mode = "deepwise"
72
+ ):
73
+ factory_kwargs = {"device": device, "dtype": dtype}
74
+ super().__init__()
75
+ self.conv_mode = conv_mode
76
+ self.d_model = d_model
77
+ self.d_state = d_state
78
+ self.d_conv = d_conv
79
+ self.expand = expand
80
+ self.d_inner = int(self.expand * self.d_model)
81
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
82
+ self.use_fast_path = use_fast_path
83
+ self.layer_idx = layer_idx
84
+ self.bimamba_type = bimamba_type
85
+
86
+ if self.conv_mode == "orignal":
87
+ self.local_relation = nn.Sequential(
88
+ nn.Conv2d(in_channels=self.d_model, out_channels=self.d_model, kernel_size=3, stride=1, padding=1),
89
+ nn.SiLU(),
90
+ nn.Conv2d(in_channels=self.d_model, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
91
+ )
92
+ elif self.conv_mode == "orignal_1_5_dmodel":
93
+ self.local_relation = nn.Sequential(
94
+ nn.Conv2d(in_channels=self.d_model, out_channels=int(1.5*self.d_model), kernel_size=3, stride=1, padding=1),
95
+ nn.SiLU(),
96
+ nn.Conv2d(in_channels=int(1.5*self.d_model), out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
97
+ )
98
+ elif self.conv_mode == "orignal_dinner":
99
+ self.local_relation = nn.Sequential(
100
+ nn.Conv2d(in_channels=self.d_model, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
101
+ nn.SiLU(),
102
+ nn.Conv2d(in_channels=self.d_inner, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
103
+ )
104
+ elif self.conv_mode == "deepwise":
105
+ self.local_relation = nn.Sequential(
106
+ LightweightModel(in_channels=self.d_model, out_channels=self.d_model),
107
+ nn.SiLU(),
108
+ LightweightModel(in_channels=self.d_model, out_channels=self.d_inner),
109
+ )
110
+ elif self.conv_mode == "deepwise_dinner":
111
+ self.local_relation = nn.Sequential(
112
+ LightweightModel(in_channels=self.d_model, out_channels=self.d_inner),
113
+ nn.SiLU(),
114
+ LightweightModel(in_channels=self.d_inner, out_channels=self.d_inner),
115
+ )
116
+
117
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
118
+
119
+ self.conv1d = nn.Conv1d(
120
+ in_channels=self.d_inner,
121
+ out_channels=self.d_inner,
122
+ bias=conv_bias,
123
+ kernel_size=d_conv,
124
+ groups=self.d_inner,
125
+ padding=d_conv - 1,
126
+ **factory_kwargs,
127
+ )
128
+
129
+ self.activation = "silu"
130
+ self.act = nn.SiLU()
131
+
132
+ self.x_proj = nn.Linear(
133
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
134
+ )
135
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
136
+
137
+ # Initialize special dt projection to preserve variance at initialization
138
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
139
+ if dt_init == "constant":
140
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
141
+ elif dt_init == "random":
142
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
143
+ else:
144
+ raise NotImplementedError
145
+
146
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
147
+ dt = torch.exp(
148
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
149
+ + math.log(dt_min)
150
+ ).clamp(min=dt_init_floor)
151
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
152
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
153
+ with torch.no_grad():
154
+ self.dt_proj.bias.copy_(inv_dt)
155
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
156
+ self.dt_proj.bias._no_reinit = True
157
+
158
+ # S4D real initialization
159
+ A = repeat(
160
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
161
+ "n -> d n",
162
+ d=self.d_inner,
163
+ ).contiguous()
164
+ A_log = torch.log(A) # Keep A_log in fp32
165
+ self.A_log = nn.Parameter(A_log)
166
+ self.A_log._no_weight_decay = True
167
+
168
+ # D "skip" parameter
169
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
170
+ self.D._no_weight_decay = True
171
+
172
+ # bidirectional
173
+ assert bimamba_type == "v2"
174
+
175
+ A_b = repeat(
176
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
177
+ "n -> d n",
178
+ d=self.d_inner,
179
+ ).contiguous()
180
+ A_b_log = torch.log(A_b) # Keep A_b_log in fp32
181
+ self.A_b_log = nn.Parameter(A_b_log)
182
+ self.A_b_log._no_weight_decay = True
183
+
184
+ self.conv1d_b = nn.Conv1d(
185
+ in_channels=self.d_inner,
186
+ out_channels=self.d_inner,
187
+ bias=conv_bias,
188
+ kernel_size=d_conv,
189
+ groups=self.d_inner,
190
+ padding=d_conv - 1,
191
+ **factory_kwargs,
192
+ )
193
+
194
+ self.x_proj_b = nn.Linear(
195
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
196
+ )
197
+ self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
198
+
199
+ self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
200
+ self.D_b._no_weight_decay = True
201
+
202
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
203
+
204
+ def forward(self, hidden_states, inference_params=None):
205
+ """
206
+ hidden_states: (B, L, D)
207
+ Returns: same shape as hidden_states
208
+ """
209
+ batch, seqlen, dim = hidden_states.shape
210
+ h = int(math.sqrt(seqlen))
211
+
212
+ local_relation = self.local_relation(rearrange(hidden_states, "b (h w) d -> b d h w", h=h))
213
+ local_relation = rearrange(local_relation, "b d h w -> b d (h w)")
214
+
215
+ conv_state, ssm_state = None, None
216
+ if inference_params is not None:
217
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
218
+ if inference_params.seqlen_offset > 0:
219
+ # The states are updated inplace
220
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
221
+ return out
222
+
223
+ # We do matmul and transpose BLH -> HBL at the same time
224
+ xz = rearrange(
225
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
226
+ "d (b l) -> b d l",
227
+ l=seqlen,
228
+ )
229
+ if self.in_proj.bias is not None:
230
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
231
+
232
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
233
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
234
+ if self.use_fast_path and inference_params is None: # Doesn't support outputting the states
235
+ if self.bimamba_type == "v2":
236
+ A_b = -torch.exp(self.A_b_log.float())
237
+ out = mamba_inner_fn_no_out_proj(
238
+ xz,
239
+ self.conv1d.weight,
240
+ self.conv1d.bias,
241
+ self.x_proj.weight,
242
+ self.dt_proj.weight,
243
+ A,
244
+ None, # input-dependent B
245
+ None, # input-dependent C
246
+ self.D.float(),
247
+ delta_bias=self.dt_proj.bias.float(),
248
+ delta_softplus=True,
249
+ )
250
+ out_b = mamba_inner_fn_no_out_proj(
251
+ xz.flip([-1]),
252
+ self.conv1d_b.weight,
253
+ self.conv1d_b.bias,
254
+ self.x_proj_b.weight,
255
+ self.dt_proj_b.weight,
256
+ A_b,
257
+ None,
258
+ None,
259
+ self.D_b.float(),
260
+ delta_bias=self.dt_proj_b.bias.float(),
261
+ delta_softplus=True,
262
+ )
263
+ # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
264
+ out = F.linear(rearrange(out + out_b.flip([-1]) + local_relation, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
265
+ else:
266
+ out = mamba_inner_fn(
267
+ xz,
268
+ self.conv1d.weight,
269
+ self.conv1d.bias,
270
+ self.x_proj.weight,
271
+ self.dt_proj.weight,
272
+ self.out_proj.weight,
273
+ self.out_proj.bias,
274
+ A,
275
+ None, # input-dependent B
276
+ None, # input-dependent C
277
+ self.D.float(),
278
+ delta_bias=self.dt_proj.bias.float(),
279
+ delta_softplus=True,
280
+ )
281
+ else:
282
+ x, z = xz.chunk(2, dim=1)
283
+ # Compute short convolution
284
+ if conv_state is not None:
285
+ conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W)
286
+ if causal_conv1d_fn is None:
287
+ x = self.act(self.conv1d(x)[..., :seqlen])
288
+ else:
289
+ assert self.activation in ["silu", "swish"]
290
+ x = causal_conv1d_fn(
291
+ x,
292
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
293
+ self.conv1d.bias,
294
+ self.activation,
295
+ )
296
+
297
+ # We're careful here about the layout, to avoid extra transposes.
298
+ # We want dt to have d as the slowest moving dimension
299
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
300
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
301
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
302
+ dt = self.dt_proj.weight @ dt.t()
303
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
304
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
305
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
306
+ assert self.activation in ["silu", "swish"]
307
+ y = selective_scan_fn(
308
+ x,
309
+ dt,
310
+ A,
311
+ B,
312
+ C,
313
+ self.D.float(),
314
+ z=z,
315
+ delta_bias=self.dt_proj.bias.float(),
316
+ delta_softplus=True,
317
+ return_last_state=ssm_state is not None,
318
+ )
319
+ if ssm_state is not None:
320
+ y, last_state = y
321
+ ssm_state.copy_(last_state)
322
+ y = rearrange(y, "b d l -> b l d")
323
+ out = self.out_proj(y)
324
+ return out
325
+
326
+ def step(self, hidden_states, conv_state, ssm_state):
327
+ dtype = hidden_states.dtype
328
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
329
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
330
+ x, z = xz.chunk(2, dim=-1) # (B D)
331
+
332
+ # Conv step
333
+ if causal_conv1d_update is None:
334
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
335
+ conv_state[:, :, -1] = x
336
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
337
+ if self.conv1d.bias is not None:
338
+ x = x + self.conv1d.bias
339
+ x = self.act(x).to(dtype=dtype)
340
+ else:
341
+ x = causal_conv1d_update(
342
+ x,
343
+ conv_state,
344
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
345
+ self.conv1d.bias,
346
+ self.activation,
347
+ )
348
+
349
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
350
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
351
+ # Don't add dt_bias here
352
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
353
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
354
+
355
+ # SSM step
356
+ if selective_state_update is None:
357
+ # Discretize A and B
358
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
359
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
360
+ dB = torch.einsum("bd,bn->bdn", dt, B)
361
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
362
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
363
+ y = y + self.D.to(dtype) * x
364
+ y = y * self.act(z) # (B D)
365
+ else:
366
+ y = selective_state_update(
367
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
368
+ )
369
+
370
+ out = self.out_proj(y)
371
+ return out.unsqueeze(1), conv_state, ssm_state
372
+
373
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
374
+ device = self.out_proj.weight.device
375
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
376
+ conv_state = torch.zeros(
377
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
378
+ )
379
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
380
+ # ssm_dtype = torch.float32
381
+ ssm_state = torch.zeros(
382
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
383
+ )
384
+ return conv_state, ssm_state
385
+
386
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
387
+ assert self.layer_idx is not None
388
+ if self.layer_idx not in inference_params.key_value_memory_dict:
389
+ batch_shape = (batch_size,)
390
+ conv_state = torch.zeros(
391
+ batch_size,
392
+ self.d_model * self.expand,
393
+ self.d_conv,
394
+ device=self.conv1d.weight.device,
395
+ dtype=self.conv1d.weight.dtype,
396
+ )
397
+ ssm_state = torch.zeros(
398
+ batch_size,
399
+ self.d_model * self.expand,
400
+ self.d_state,
401
+ device=self.dt_proj.weight.device,
402
+ dtype=self.dt_proj.weight.dtype,
403
+ # dtype=torch.float32,
404
+ )
405
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
406
+ else:
407
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
408
+ # TODO: What if batch size changes between generation, and we reuse the same states?
409
+ if initialize_states:
410
+ conv_state.zero_()
411
+ ssm_state.zero_()
412
+ return conv_state, ssm_state
413
+
+ 
+ def to_3d(x):
+     return rearrange(x, 'b c h w -> b (h w) c')
+ 
+ 
+ def to_4d(x, h, w):
+     return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ 
+ 
+ class WithBias_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(WithBias_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ 
+     def forward(self, x):
+         mu = x.mean(-1, keepdim=True)
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
+ 
+ 
+ class BiasFree_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(BiasFree_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+ 
+     def forward(self, x):
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return x / torch.sqrt(sigma + 1e-5) * self.weight
+ 
+ 
+ class LayerNorm(nn.Module):
+     def __init__(self, dim, norm_type='with_bias'):
+         super(LayerNorm, self).__init__()
+         if norm_type == 'BiasFree':
+             self.body = BiasFree_LayerNorm(dim)
+         else:
+             self.body = WithBias_LayerNorm(dim)
+ 
+     def forward(self, x):
+         if len(x.shape) == 4:
+             h, w = x.shape[-2:]
+             return to_4d(self.body(to_3d(x)), h, w)
+         else:
+             return self.body(x)
+ 
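# Usage sketch for the LayerNorm wrapper above (torch is assumed imported, as
# at the top of this file): 4D feature maps are flattened, normalized over the
# channel dimension, and reshaped back; 3D token sequences pass through directly.
norm = LayerNorm(dim=32, norm_type='with_bias')
x4 = torch.randn(2, 32, 16, 16)     # [B, C, H, W]
x3 = torch.randn(2, 256, 32)        # [B, HW, C]
assert norm(x4).shape == x4.shape
assert norm(x3).shape == x3.shape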
+ 
+ class M3(nn.Module):
+     def __init__(self, dim):
+         super(M3, self).__init__()
+         self.multi_modal_mamba_block = Mamba(dim, bimamba_type="m3")
+         self.norm1 = LayerNorm(dim, 'with_bias')  # fusion
+         self.norm2 = LayerNorm(dim, 'with_bias')  # I2
+         self.norm3 = LayerNorm(dim, 'with_bias')  # I1
+         self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
+ 
+     def forward(self, I1, I2, fusion, test_h, test_w):
+         fusion = self.norm1(fusion)
+         I2 = self.norm2(I2)
+         I1 = self.norm3(I1)
+         global_f = self.multi_modal_mamba_block(fusion, extra_emb1=I2, extra_emb2=I1)  # [B, HW, C]
+         B, HW, C = global_f.shape
+         fusion = global_f.transpose(1, 2).view(B, C, test_h, test_w)
+         fusion = (self.dwconv(fusion) + fusion).flatten(2).transpose(1, 2)
+         return fusion, None
+ 
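# A hedged usage sketch for M3. It assumes the custom Mamba (bimamba_type="m3")
# from models/mamba_customer.py and its CUDA kernels are available, so it only
# runs when a GPU is present. I1/I2 are the two source token streams and
# `fused` the running fused stream, all [B, H*W, C]; test_h/test_w restore the
# spatial layout for the depthwise conv.
if torch.cuda.is_available():
    m3 = M3(dim=64).cuda()
    I1 = torch.randn(2, 16 * 16, 64, device='cuda')
    I2 = torch.randn(2, 16 * 16, 64, device='cuda')
    fused = torch.randn(2, 16 * 16, 64, device='cuda')
    out, _ = m3(I1, I2, fused, test_h=16, test_w=16)    # out: [2, 256, 64]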
+ 
+ class PatchEmbed(nn.Module):
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+         super(PatchEmbed, self).__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+         self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+         self.norm = norm_layer(embed_dim) if norm_layer is not None else None
+ 
+     def forward(self, x):
+         # x: [B, C, H, W]
+         x = x.flatten(2).transpose(1, 2)  # [B, N, C]
+         if self.norm is not None:
+             x = self.norm(x)
+         return x
+ 
+ 
+ class PatchUnEmbed(nn.Module):
+     def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+         super(PatchUnEmbed, self).__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+         self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+ 
+     def forward(self, x, x_size):
+         B, HW, C = x.shape
+         x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])
+         return x
+ 
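# Round-trip sketch: PatchEmbed flattens [B, C, H, W] into tokens [B, H*W, C]
# and PatchUnEmbed restores the map given the spatial size (sizes illustrative;
# with norm_layer=None the round trip is exact).
embed = PatchEmbed(embed_dim=96)
unembed = PatchUnEmbed(embed_dim=96)
feat = torch.randn(2, 96, 14, 14)
tokens = embed(feat)                                # [2, 196, 96]
assert torch.equal(unembed(tokens, (14, 14)), feat)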
+ 
+ 
+ class Block(nn.Module):
+     def __init__(
+         self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
+     ):
+         """
+         Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
+ 
+         This Block has a slightly different structure compared to a regular
+         prenorm Transformer block.
+         The standard block is: LN -> MHA/MLP -> Add.
+         [Ref: https://arxiv.org/abs/2002.04745]
+         Here we have: Add -> LN -> Mixer, returning both
+         the hidden_states (output of the mixer) and the residual.
+         This is purely for performance reasons, as we can fuse add and LayerNorm.
+         The residual needs to be provided (except for the very first block).
+         """
+         super().__init__()
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.mixer = mixer_cls(dim)
+         self.norm = norm_cls(dim)
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+ 
+     def forward(
+         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
+     ):
+         r"""Pass the input through the encoder layer.
+ 
+         Args:
+             hidden_states: the sequence to the encoder layer (required).
+             residual: hidden_states = Mixer(LN(residual))
+         """
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+             hidden_states, residual = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+             )
+         hidden_states = self.mixer(hidden_states, inference_params=inference_params)
+         return hidden_states, residual
+ 
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
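# A minimal sketch of the Add -> LN -> Mixer contract on the non-fused path.
# The mixer here is a toy linear layer standing in for Mamba, purely to show
# how (hidden_states, residual) thread through stacked Blocks.
class ToyMixer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, inference_params=None):    # matches the mixer call signature
        return self.proj(x)

blocks = [Block(64, mixer_cls=ToyMixer) for _ in range(2)]
hidden, residual = torch.randn(2, 10, 64), None
for blk in blocks:
    hidden, residual = blk(hidden, residual)        # residual carries the running sum
final = hidden + residual                           # callers typically add + norm once at the end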
models/resnet.py ADDED
@@ -0,0 +1,358 @@
+ import torch
+ import torch.nn as nn
+ from torch.hub import load_state_dict_from_url
+ 
+ 
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+            'wide_resnet50_2', 'wide_resnet101_2']
+ 
+ 
+ model_urls = {
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+ }
+ 
+ 
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=dilation, groups=groups, bias=False, dilation=dilation)
+ 
+ 
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+ 
+ 
+ class BasicBlock(nn.Module):
+     expansion = 1
+ 
+     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                  base_width=64, dilation=1, norm_layer=None):
+         super(BasicBlock, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         if groups != 1 or base_width != 64:
+             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+         if dilation > 1:
+             # BasicBlock does not support dilation; clamp it instead of raising.
+             dilation = 1
+             # raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = norm_layer(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = norm_layer(planes)
+         self.downsample = downsample
+         self.stride = stride
+ 
+     def forward(self, x):
+         identity = x
+ 
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+ 
+         out = self.conv2(out)
+         out = self.bn2(out)
+ 
+         if self.downsample is not None:
+             identity = self.downsample(x)
+ 
+         out += identity
+         out = self.relu(out)
+ 
+         return out
+ 
+ 
+ class Bottleneck(nn.Module):
+     # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
+     # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
+     # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
+     # This variant is also known as ResNet V1.5 and improves accuracy according to
+     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+ 
+     expansion = 4
+ 
+     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                  base_width=64, dilation=1, norm_layer=None):
+         super(Bottleneck, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         width = int(planes * (base_width / 64.)) * groups
+         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv1x1(inplanes, width)
+         self.bn1 = norm_layer(width)
+         self.conv2 = conv3x3(width, width, stride, groups, dilation)
+         self.bn2 = norm_layer(width)
+         self.conv3 = conv1x1(width, planes * self.expansion)
+         self.bn3 = norm_layer(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+ 
+     def forward(self, x):
+         identity = x
+ 
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+ 
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+ 
+         out = self.conv3(out)
+         out = self.bn3(out)
+ 
+         if self.downsample is not None:
+             identity = self.downsample(x)
+ 
+         out += identity
+         out = self.relu(out)
+ 
+         return out
+ 
+ 
+ class ResNet(nn.Module):
+ 
+     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                  norm_layer=None, strides=None):
+         super(ResNet, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         self._norm_layer = norm_layer
+ 
+         # Strides for [conv1, maxpool, layer2, layer3, layer4]; defaults match torchvision.
+         self.strides = strides
+         if self.strides is None:
+             self.strides = [2, 2, 2, 2, 2]
+ 
+         self.inplanes = 64
+         self.dilation = 1
+         if replace_stride_with_dilation is None:
+             # each element in the tuple indicates if we should replace
+             # the 2x2 stride with a dilated convolution instead
+             replace_stride_with_dilation = [False, False, False]
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError("replace_stride_with_dilation should be None "
+                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+         self.groups = groups
+         self.base_width = width_per_group
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3,
+                                bias=False)
+         self.bn1 = norm_layer(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2],
+                                        dilate=replace_stride_with_dilation[0])
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3],
+                                        dilate=replace_stride_with_dilation[1])
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4],
+                                        dilate=replace_stride_with_dilation[2])
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+ 
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+ 
+         # Zero-initialize the last BN in each residual branch,
+         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, Bottleneck):
+                     nn.init.constant_(m.bn3.weight, 0)
+                 elif isinstance(m, BasicBlock):
+                     nn.init.constant_(m.bn2.weight, 0)
+ 
+     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+         norm_layer = self._norm_layer
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 norm_layer(planes * block.expansion),
+             )
+ 
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                             self.base_width, previous_dilation, norm_layer))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes, groups=self.groups,
+                                 base_width=self.base_width, dilation=self.dilation,
+                                 norm_layer=norm_layer))
+ 
+         return nn.Sequential(*layers)
+ 
+     def _forward_impl(self, x):
+         # See note [TorchScript super()]
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+ 
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+ 
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         x = self.fc(x)
+ 
+         return x
+ 
+     def forward(self, x):
+         return self._forward_impl(x)
+ 
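# A brief usage sketch of the `strides` hook this variant adds over torchvision:
# the five entries apply to [conv1, maxpool, layer2, layer3, layer4], so
# [1, 2, 2, 2, 2] keeps full resolution at the stem (values illustrative).
net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10, strides=[1, 2, 2, 2, 2])
x = torch.randn(1, 3, 224, 224)
print(net(x).shape)    # torch.Size([1, 10])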
+ 
+ 
+ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+     model = ResNet(block, layers, **kwargs)
+     if pretrained:
+         state_dict = load_state_dict_from_url(model_urls[arch],
+                                               progress=progress)
+         model.load_state_dict(state_dict)
+     return model
+ 
+ 
+ def resnet18(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-18 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                    **kwargs)
+ 
+ 
+ def resnet34(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-34 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                    **kwargs)
+ 
+ 
+ def resnet50(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-50 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                    **kwargs)
+ 
+ 
+ def resnet101(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-101 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                    **kwargs)
+ 
+ 
+ def resnet152(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-152 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                    **kwargs)
+ 
+ 
+ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+     r"""ResNeXt-50 32x4d model from
+     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['groups'] = 32
+     kwargs['width_per_group'] = 4
+     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                    pretrained, progress, **kwargs)
+ 
+ 
+ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+     r"""ResNeXt-101 32x8d model from
+     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['groups'] = 32
+     kwargs['width_per_group'] = 8
+     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                    pretrained, progress, **kwargs)
+ 
+ 
+ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+     r"""Wide ResNet-50-2 model from
+     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+ 
+     The model is the same as ResNet except that the number of channels in the
+     bottleneck is twice as large in every block. The number of channels in the
+     outer 1x1 convolutions is unchanged, e.g. the last block in ResNet-50 has
+     2048-512-2048 channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['width_per_group'] = 64 * 2
+     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                    pretrained, progress, **kwargs)
+ 
+ 
+ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+     r"""Wide ResNet-101-2 model from
+     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+ 
+     The model is the same as ResNet except that the number of channels in the
+     bottleneck is twice as large in every block. The number of channels in the
+     outer 1x1 convolutions is unchanged, e.g. the last block in ResNet-50 has
+     2048-512-2048 channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
+ 
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['width_per_group'] = 64 * 2
+     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                    pretrained, progress, **kwargs)
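# Quick sanity check of the doubled bottleneck width described in the
# docstrings above (no weights downloaded): with width_per_group=128, the last
# layer4 block computes with 1024 internal channels while its 1x1 output stays 2048.
w50 = wide_resnet50_2(pretrained=False)
last = w50.layer4[-1]
print(last.conv2.in_channels, last.conv2.out_channels)    # 1024 1024
print(last.conv3.out_channels)                            # 2048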