InPeerReview commited on
Commit
2317bc0
·
verified ·
1 Parent(s): a201e30

Upload 2 files

Browse files
rscd/models/decoderheads/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rscd.models.decoderheads.stnet import STNet
2
+ from rscd.models.decoderheads.cdmask import CDMask
3
+ from rscd.models.decoderheads.DDLNet import DDLNet
4
+
5
+ from rscd.models.decoderheads.BIThead import BASE_Transformer
6
+ from rscd.models.decoderheads.SNUNet_ECAM import ECAM
7
+ from rscd.models.decoderheads.CFde import ChangeFormer_DE
8
+ from rscd.models.decoderheads.lgpnet_b import LGPNet_b
9
+ from rscd.models.decoderheads.SARASNet import Change_detection
10
+ from rscd.models.decoderheads.AFCF3D_de import AFCD3D_decoder
11
+ from rscd.models.decoderheads.USSFCNet import USSFCNet_decoder
12
+
13
+ from rscd.models.decoderheads.mamba_cttf import CTTF
14
+ from rscd.models.decoderheads.xformer3 import CDXLSTM3
15
+ from rscd.models.decoderheads.detector import changedetector
16
+ from rscd.models.decoderheads.ChangeMambaDecoder import CMDecoder
17
+
18
+ from rscd.models.decoderheads.none import none_class
19
+
20
+ from rscd.models.decoderheads.a2net import A2Net
21
+
22
+ from rscd.models.decoderheads.nnUNetTrainer_WNet2D import WNet2D
23
+ from rscd.models.decoderheads.nnUNetTrainer_WNet2D_L import WNet2D_L
24
+ from rscd.models.decoderheads.Sea_WNet2D import Sea_WNet
25
+
26
+ from rscd.models.decoderheads.fc_ef import FC_ef
27
+ from rscd.models.decoderheads.fc_siam_conc import FC_siam_conc
28
+ from rscd.models.decoderheads.fc_sima_diff import FC_siam_diff
29
+
30
+ from rscd.models.decoderheads.IFnet import DSIFN
31
+ from rscd.models.decoderheads.LCD import LCD_Net
32
+
33
+ from rscd.models.decoderheads.acabfnet import CrossNet
34
+ from rscd.models.decoderheads.paformer import Paformer
rscd/models/decoderheads/mamba_cttf.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.cuda.amp import autocast
4
+ from rscd.models.decoderheads.vision_lstm import ViLBlock, SequenceTraversal
5
+ from torch.nn import functional as F
6
+ from functools import partial
7
+ from rscd.models.backbones.lib_mamba.vmambanew import SS2D
8
+ import pywt
9
+
10
+ class PA(nn.Module):
11
+ def __init__(self, dim, norm_layer, act_layer):
12
+ super().__init__()
13
+ self.p_conv = nn.Sequential(
14
+ nn.Conv2d(dim, dim*4, 1, bias=False),
15
+ norm_layer(dim*4),
16
+ act_layer(),
17
+ nn.Conv2d(dim*4, dim, 1, bias=False)
18
+ )
19
+ self.gate_fn = nn.Sigmoid()
20
+
21
+ def forward(self, x):
22
+ att = self.p_conv(x)
23
+ x = x * self.gate_fn(att)
24
+
25
+ return x
26
+
27
+ class Mish(nn.Module):
28
+ def __init__(self):
29
+ super().__init__()
30
+
31
+ def forward(self, x):
32
+ return x * torch.tanh(F.softplus(x))
33
+
34
+
35
+ class _ScaleModule(nn.Module):
36
+ def __init__(self, dims, init_scale=1.0):
37
+ super().__init__()
38
+ self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
39
+
40
+ def forward(self, x):
41
+ return torch.mul(self.weight, x)
42
+
43
+ def create_wavelet_filter(wave, in_size, out_size, dtype=torch.float):
44
+ w = pywt.Wavelet(wave)
45
+ dec_hi = torch.tensor(w.dec_hi[::-1], dtype=dtype)
46
+ dec_lo = torch.tensor(w.dec_lo[::-1], dtype=dtype)
47
+ dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
48
+ dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
49
+ dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
50
+ dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)
51
+ dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)
52
+ rec_hi = torch.tensor(w.rec_hi[::-1], dtype=dtype).flip(dims=[0])
53
+ rec_lo = torch.tensor(w.rec_lo[::-1], dtype=dtype).flip(dims=[0])
54
+ rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
55
+ rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
56
+ rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
57
+ rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)
58
+ rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)
59
+ return dec_filters, rec_filters
60
+
61
+ def wavelet_transform(x, filters):
62
+ b, c, h, w = x.shape
63
+ pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
64
+ x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)
65
+ x = x.reshape(b, c, 4, h // 2, w // 2)
66
+ return x
67
+
68
+ def inverse_wavelet_transform(x, filters):
69
+ b, c, _, h_half, w_half = x.shape
70
+ pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
71
+ x = x.reshape(b, c * 4, h_half, w_half)
72
+ x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)
73
+ return x
74
+
75
+ class MBWTConv2d(nn.Module):
76
+ def __init__(self, in_channels, kernel_size=5, wt_levels=1, wt_type='db1', ssm_ratio=1, forward_type="v05"):
77
+ super().__init__()
78
+ assert in_channels == in_channels
79
+ self.wt_levels = wt_levels
80
+ self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels)
81
+ self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
82
+ self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
83
+ self.wt_function = partial(wavelet_transform, filters=self.wt_filter)
84
+ self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)
85
+ self.global_atten = SS2D(d_model=in_channels, d_state=1, ssm_ratio=ssm_ratio, initialize="v2",
86
+ forward_type=forward_type, channel_first=True, k_group=2)
87
+ self.base_scale = _ScaleModule([1, in_channels, 1, 1])
88
+ self.wavelet_convs = nn.ModuleList([
89
+ nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', groups=in_channels * 4)
90
+ for _ in range(wt_levels)
91
+ ])
92
+ self.wavelet_scale = nn.ModuleList([
93
+ _ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1)
94
+ for _ in range(wt_levels)
95
+ ])
96
+
97
+ def forward(self, x):
98
+ x_ll_in_levels, x_h_in_levels, shapes_in_levels = [], [], []
99
+ curr_x_ll = x
100
+ for i in range(self.wt_levels):
101
+ curr_shape = curr_x_ll.shape
102
+ shapes_in_levels.append(curr_shape)
103
+ if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
104
+ curr_x_ll = F.pad(curr_x_ll, (0, curr_shape[3] % 2, 0, curr_shape[2] % 2))
105
+ curr_x = self.wt_function(curr_x_ll)
106
+ curr_x_ll = curr_x[:, :, 0, :, :]
107
+ shape_x = curr_x.shape
108
+ curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
109
+ curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag)).reshape(shape_x)
110
+ x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
111
+ x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
112
+ next_x_ll = 0
113
+ for i in range(self.wt_levels - 1, -1, -1):
114
+ curr_x_ll = x_ll_in_levels.pop() + next_x_ll
115
+ curr_x = torch.cat([curr_x_ll.unsqueeze(2), x_h_in_levels.pop()], dim=2)
116
+ next_x_ll = self.iwt_function(curr_x)
117
+ next_x_ll = next_x_ll[:, :, :shapes_in_levels[i][2], :shapes_in_levels[i][3]]
118
+ x_tag = next_x_ll
119
+ x = self.base_scale(self.global_atten(x)) + x_tag
120
+ return x
121
+
122
+ class ChannelAttention(nn.Module):
123
+ def __init__(self, in_planes, ratio=16):
124
+ super(ChannelAttention, self).__init__()
125
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
126
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
127
+
128
+ self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
129
+ self.relu1 = nn.ReLU()
130
+ self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
131
+ self.sigmoid = nn.Sigmoid()
132
+
133
+ def forward(self, x):
134
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
135
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
136
+ out = avg_out + max_out
137
+ return self.sigmoid(out)
138
+
139
+ # 空间注意力模块
140
+ class SpatialAttention(nn.Module):
141
+ def __init__(self, kernel_size=7):
142
+ super(SpatialAttention, self).__init__()
143
+
144
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
145
+ padding = 3 if kernel_size == 7 else 1
146
+
147
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
148
+ self.sigmoid = nn.Sigmoid()
149
+
150
+ def forward(self, x):
151
+ avg_out = torch.mean(x, dim=1, keepdim=True)
152
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
153
+ x = torch.cat([avg_out, max_out], dim=1)
154
+ x = self.conv1(x)
155
+ return self.sigmoid(x)
156
+
157
+ # CBAM 注意力模块
158
+ class CBAM(nn.Module):
159
+ def __init__(self, in_planes):
160
+ super(CBAM, self).__init__()
161
+ self.ca = ChannelAttention(in_planes)
162
+ self.sa = SpatialAttention()
163
+
164
+ def forward(self, x):
165
+ x = self.ca(x) * x
166
+ x = self.sa(x) * x
167
+ return x
168
+
169
+ class DynamicConv2d(nn.Module):
170
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False, num_experts=4):
171
+ super(DynamicConv2d, self).__init__()
172
+ self.in_channels = in_channels
173
+ self.out_channels = out_channels
174
+ self.kernel_size = kernel_size
175
+ self.stride = stride
176
+ self.padding = padding
177
+ self.groups = groups
178
+ self.bias = bias
179
+ self.num_experts = num_experts
180
+
181
+ self.experts = nn.ModuleList([
182
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=bias)
183
+ for _ in range(num_experts)
184
+ ])
185
+ self.gating = nn.Sequential(
186
+ nn.AdaptiveAvgPool2d(1),
187
+ nn.Conv2d(in_channels, num_experts, 1, bias=False),
188
+ nn.Softmax(dim=1)
189
+ )
190
+
191
+ def forward(self, x):
192
+ gates = self.gating(x)
193
+ gates = gates.view(x.size(0), self.num_experts, 1, 1, 1)
194
+ outputs = []
195
+ for i, expert in enumerate(self.experts):
196
+ outputs.append(expert(x).unsqueeze(1))
197
+ outputs = torch.cat(outputs, dim=1)
198
+ out = (gates * outputs).sum(dim=1)
199
+ return out
200
+
201
+ class DWConv2d_BN_ReLU(nn.Sequential):
202
+ def __init__(self, in_channels, out_channels, kernel_size=3):
203
+ super().__init__()
204
+ self.add_module('dwconv3x3', DynamicConv2d(in_channels, in_channels, kernel_size=kernel_size,
205
+ stride=1, padding=kernel_size // 2, groups=in_channels, bias=False))
206
+ self.add_module('bn1', nn.BatchNorm2d(in_channels))
207
+ self.add_module('relu', Mish())
208
+ self.add_module('dwconv1x1', nn.Conv2d(in_channels, out_channels, kernel_size=1,
209
+ stride=1, padding=0, groups=in_channels, bias=False))
210
+ self.add_module('bn2', nn.BatchNorm2d(out_channels))
211
+
212
+ class Conv2d_BN(nn.Sequential):
213
+ def __init__(self, a, b, ks=1, stride=1, pad=0, groups=1):
214
+ super().__init__()
215
+ self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, groups=groups, bias=False))
216
+ self.add_module('bn', nn.BatchNorm2d(b))
217
+
218
+ class FFN(nn.Module):
219
+ def __init__(self, ed, h):
220
+ super().__init__()
221
+ self.pw1 = Conv2d_BN(ed, h)
222
+ self.act = Mish()
223
+ self.pw2 = Conv2d_BN(h, ed)
224
+
225
+ def forward(self, x):
226
+ return self.pw2(self.act(self.pw1(x)))
227
+
228
+ class StochasticDepth(nn.Module):
229
+ def __init__(self, survival_prob=0.8):
230
+ super().__init__()
231
+ self.survival_prob = survival_prob
232
+
233
+ def forward(self, x):
234
+ if not self.training:
235
+ return x
236
+ batch_size = x.shape[0]
237
+ random_tensor = self.survival_prob + torch.rand([batch_size, 1, 1, 1], dtype=x.dtype, device=x.device)
238
+ binary_tensor = torch.floor(random_tensor)
239
+ return x * binary_tensor / self.survival_prob
240
+
241
+ class Residual(nn.Module):
242
+ def __init__(self, m, survival_prob=0.8):
243
+ super().__init__()
244
+ self.m = m
245
+ self.stochastic_depth = StochasticDepth(survival_prob)
246
+
247
+ def forward(self, x):
248
+ return x + self.stochastic_depth(self.m(x))
249
+
250
+ class GLP_block(nn.Module):
251
+ def __init__(self, dim, global_ratio=0.25, local_ratio=0.25,pa_ratio = 0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
252
+ super().__init__()
253
+ self.dim = dim
254
+ self.global_channels = int(global_ratio * dim)
255
+ self.local_channels = int(local_ratio * dim)
256
+ self.pa_channels = int(pa_ratio * dim)
257
+ self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels
258
+ self.local_op = nn.ModuleList([
259
+ DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k)
260
+ for k in [3, 5, 7]
261
+ ]) if self.local_channels > 0 else nn.Identity()
262
+ self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
263
+ ssm_ratio=ssm_ratio, forward_type=forward_type) \
264
+ if self.global_channels > 0 else nn.Identity()
265
+ self.cbam = CBAM(dim)
266
+ self.proj = nn.Sequential(
267
+ Mish(),
268
+ Conv2d_BN(dim, dim),
269
+ CBAM(dim)
270
+ )
271
+
272
+ self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU) \
273
+ if self.pa_channels > 0 else nn.Identity()
274
+
275
+ def forward(self, x):
276
+ x1, x2, x3, x4 = torch.split(x, [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels], dim=1)
277
+ if isinstance(self.local_op, nn.ModuleList):
278
+ local_features = [op(x2) for op in self.local_op]
279
+ local_features = torch.cat(local_features, dim=1)
280
+ local_features = torch.mean(local_features, dim=1, keepdim=True)
281
+ local_features = local_features.expand(-1, self.local_channels, -1, -1)
282
+ else:
283
+ local_features = self.local_op(x2)
284
+ out = torch.cat([self.global_op(x1), local_features, x3, self.pa_op(x4)], dim=1)
285
+ return self.proj(out)
286
+
287
+
288
+
289
+
290
+
291
+ class SASF(nn.Module):
292
+ def __init__(self, dim, global_ratio=0.25, local_ratio=0.25,pa_ratio = 0.1, kernels=3, ssm_ratio=1, forward_type="v052d"):
293
+ super().__init__()
294
+ self.dim = dim
295
+ self.global_channels = int(global_ratio * dim)
296
+ self.local_channels = int(local_ratio * dim)
297
+ self.pa_channels = int(pa_ratio * dim)
298
+ self.identity_channels = dim - self.global_channels - self.local_channels - self.pa_channels
299
+ self.local_op = nn.ModuleList([
300
+ DWConv2d_BN_ReLU(self.local_channels, self.local_channels, k)
301
+ for k in [3, 5, 7]
302
+ ]) if self.local_channels > 0 else nn.Identity()
303
+ self.global_op = MBWTConv2d(self.global_channels, kernel_size=kernels,
304
+ ssm_ratio=ssm_ratio, forward_type=forward_type) \
305
+ if self.global_channels > 0 else nn.Identity()
306
+ self.cbam = CBAM(dim)
307
+ self.proj = nn.Sequential(
308
+ Mish(),
309
+ Conv2d_BN(dim, dim),
310
+ CBAM(dim)
311
+ )
312
+
313
+ self.pa_op = PA(self.pa_channels, norm_layer=nn.BatchNorm2d, act_layer=nn.GELU) \
314
+ if self.pa_channels > 0 else nn.Identity()
315
+
316
+ def forward(self, x):
317
+ x1, x2, x3, x4 = torch.split(x, [self.global_channels, self.local_channels, self.identity_channels, self.pa_channels], dim=1)
318
+ if isinstance(self.local_op, nn.ModuleList):
319
+ local_features = [op(x2) for op in self.local_op]
320
+ local_features = torch.cat(local_features, dim=1)
321
+ local_features = torch.mean(local_features, dim=1, keepdim=True)
322
+ local_features = local_features.expand(-1, self.local_channels, -1, -1)
323
+ else:
324
+ local_features = self.local_op(x2)
325
+ out = torch.cat([self.global_op(x1), local_features, x3, self.pa_op(x4)], dim=1)
326
+ return self.proj(out)
327
+
328
+
329
+ class ViLLayer(nn.Module):
330
+ def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2):
331
+ super().__init__()
332
+ self.dim = dim
333
+ self.norm = nn.LayerNorm(dim)
334
+ self.vil = ViLBlock(
335
+ dim= self.dim,
336
+ direction=SequenceTraversal.ROWWISE_FROM_TOP_LEFT
337
+ )
338
+
339
+ @autocast(enabled=False)
340
+ def forward(self, x):
341
+ if x.dtype == torch.float16:
342
+ x = x.type(torch.float32)
343
+ B, C = x.shape[:2]
344
+ assert C == self.dim
345
+ n_tokens = x.shape[2:].numel()
346
+ img_dims = x.shape[2:]
347
+ x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
348
+ x_vil = self.vil(x_flat)
349
+ out = x_vil.transpose(-1, -2).reshape(B, C, *img_dims)
350
+
351
+ return out
352
+
353
+ def dsconv_3x3(in_channel, out_channel):
354
+ return nn.Sequential(
355
+ nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
356
+ nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
357
+ nn.BatchNorm2d(out_channel),
358
+ nn.ReLU(inplace=True)
359
+ )
360
+
361
+ def conv_1x1(in_channel, out_channel):
362
+ return nn.Sequential(
363
+ nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
364
+ nn.BatchNorm2d(out_channel),
365
+ nn.ReLU(inplace=True)
366
+ )
367
+
368
+ class SqueezeAxialPositionalEmbedding(nn.Module):
369
+ def __init__(self, dim, shape):
370
+ super().__init__()
371
+
372
+ self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]))
373
+
374
+ def forward(self, x):
375
+ B, C, N = x.shape
376
+ x = x + F.interpolate(self.pos_embed, size=(N), mode='linear', align_corners=False)
377
+
378
+ return x
379
+
380
+ class SEBlock(nn.Module):
381
+ def __init__(self, channels, r=16):
382
+ super().__init__()
383
+ self.fc = nn.Sequential(
384
+ nn.AdaptiveAvgPool2d(1),
385
+ nn.Conv2d(channels, channels//r, 1),
386
+ nn.ReLU(inplace=True),
387
+ nn.Conv2d(channels//r, channels, 1),
388
+ nn.Sigmoid()
389
+ )
390
+ def forward(self, x):
391
+ w = self.fc(x) # (B, C, 1, 1)
392
+ return x * w
393
+ class CTTF1(nn.Module):
394
+ def __init__(self, in_channel, out_channel,global_ratio=0.2, local_ratio=0.2, pa_ratio = 0.2 ,kernels=5, ssm_ratio=2.0, forward_type="v052d"):
395
+ super().__init__()
396
+ self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
397
+ self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
398
+ self.catconv = dsconv_3x3(in_channel * 2, out_channel)
399
+ self.convA = nn.Conv2d(in_channel, 1, 1)
400
+ self.convB = nn.Conv2d(in_channel, 1, 1)
401
+ self.sigmoid = nn.Sigmoid()
402
+
403
+ self.mixer = Residual(GLP_block(in_channel, global_ratio, local_ratio,pa_ratio, kernels, ssm_ratio, forward_type))
404
+ self.mixer2 = Residual(
405
+ SASF(in_channel, global_ratio = 0, local_ratio = 0.1, pa_ratio = 0, kernels = 5, ssm_ratio = 1, forward_type = "v052d"))
406
+
407
+ self.fuse = nn.Sequential(
408
+ nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
409
+ nn.ReLU(inplace=True)
410
+ )
411
+ self.cbam = CBAM(in_channel * 3)
412
+
413
+ self.act = nn.SiLU()
414
+ def forward(self, xA, xB):
415
+ x_diffA = self.mixer(xA)
416
+ x_diffB = self.mixer(xB)
417
+
418
+ f1 = x_diffA
419
+ f2 = x_diffB
420
+ diff_signed = f1 - f2
421
+ diff_abs = torch.abs(diff_signed)
422
+ sum_feat = f1 + f2
423
+
424
+ diff_signed = self.mixer2(diff_signed)
425
+ diff_abs = self.mixer2(diff_abs)
426
+ sum_feat = self.mixer2(sum_feat)
427
+ # 将多路特征在通道维度拼接
428
+ f_fuse = torch.cat([diff_signed, diff_abs, sum_feat], dim=1) # (B, 4C, H, W)
429
+ # 再接一个 1x1 卷积降维或提炼信息
430
+ f_fuse = self.cbam(f_fuse)
431
+ x_diff = self.fuse(f_fuse)
432
+
433
+ return x_diff
434
+
435
+ class CTTF2(nn.Module):
436
+ def __init__(self, in_channel, out_channel, global_ratio=0.25, local_ratio=0.25, pa_ratio=0, kernels=7,
437
+ ssm_ratio=2.0, forward_type="v052d"):
438
+ super().__init__()
439
+ self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
440
+ self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
441
+ self.catconv = dsconv_3x3(in_channel * 2, out_channel)
442
+ self.convA = nn.Conv2d(in_channel, 1, 1)
443
+ self.convB = nn.Conv2d(in_channel, 1, 1)
444
+ self.sigmoid = nn.Sigmoid()
445
+
446
+
447
+ self.mixer = Residual(
448
+ GLP_block(in_channel, global_ratio, local_ratio, pa_ratio, kernels, ssm_ratio, forward_type))
449
+ self.mixer2 = Residual(
450
+ SASF(in_channel, global_ratio=0, local_ratio=0.1, pa_ratio=0, kernels=5, ssm_ratio=1,
451
+ forward_type="v052d"))
452
+
453
+ self.fuse = nn.Sequential(
454
+ nn.Conv2d(in_channel * 3, in_channel, kernel_size=1),
455
+ nn.ReLU(inplace=True)
456
+ )
457
+ self.cbam = CBAM(in_channel * 3)
458
+
459
+ self.act = nn.SiLU()
460
+
461
+ def forward(self, xA, xB):
462
+ x_diffA = self.mixer(xA)
463
+ x_diffB = self.mixer(xB)
464
+
465
+ f1 = x_diffA
466
+ f2 = x_diffB
467
+ diff_signed = f1 - f2
468
+ diff_abs = torch.abs(diff_signed)
469
+ sum_feat = f1 + f2
470
+
471
+ diff_signed = self.mixer2(diff_signed)
472
+ diff_abs = self.mixer2(diff_abs)
473
+ sum_feat = self.mixer2(sum_feat)
474
+ f_fuse = torch.cat([diff_signed, diff_abs, sum_feat], dim=1) # (B, 4C, H, W)
475
+ f_fuse = self.cbam(f_fuse)
476
+ x_diff = self.fuse(f_fuse)
477
+
478
+ return x_diff
479
+
480
+ class Mlp(nn.Module):
481
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=True):
482
+ super().__init__()
483
+ out_features = out_features or in_features
484
+ hidden_features = hidden_features or in_features
485
+
486
+ Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
487
+ self.fc1 = Linear(in_features, hidden_features)
488
+ self.act = act_layer()
489
+ self.fc2 = Linear(hidden_features, out_features)
490
+ self.drop = nn.Dropout(drop)
491
+
492
+ def forward(self, x):
493
+ x = self.fc1(x)
494
+ x = self.act(x)
495
+ x = self.drop(x)
496
+ x = self.fc2(x)
497
+ x = self.drop(x)
498
+ return x
499
+
500
+ class LHBlock(nn.Module):
501
+ def __init__(self, channels_l, channels_h):
502
+ super().__init__()
503
+ self.channels_l = channels_l
504
+ self.channels_h = channels_h
505
+ self.cross_size = 12
506
+ self.cross_kv = nn.Sequential(
507
+ nn.BatchNorm2d(channels_l),
508
+ nn.AdaptiveMaxPool2d(output_size=(self.cross_size, self.cross_size)),
509
+ nn.Conv2d(channels_l, 2 * channels_h, 1, 1, 0)
510
+ )
511
+
512
+ self.conv = conv_1x1(channels_l, channels_h)
513
+ self.norm = nn.BatchNorm2d(channels_h)
514
+
515
+ self.mlp_l = Mlp(in_features=channels_l, out_features=channels_l)
516
+ self.mlp_h = Mlp(in_features=channels_h, out_features=channels_h)
517
+
518
+ def _act_sn(self, x):
519
+ _, _, H, W = x.shape
520
+ inner_channel = self.cross_size * self.cross_size
521
+ x = x.reshape([-1, inner_channel, H, W]) * (inner_channel**-0.5)
522
+ x = F.softmax(x, dim=1)
523
+ x = x.reshape([1, -1, H, W])
524
+ return x
525
+
526
+ def attn_h(self, x_h, cross_k, cross_v):
527
+ B, _, H, W = x_h.shape
528
+ x_h = self.norm(x_h)
529
+ x_h = x_h.reshape([1, -1, H, W]) # n,c_in,h,w -> 1,n*c_in,h,w
530
+ x_h = F.conv2d(x_h, cross_k, bias=None, stride=1, padding=0,
531
+ groups=B) # 1,n*c_in,h,w -> 1,n*144,h,w (group=B)
532
+ x_h = self._act_sn(x_h)
533
+ x_h = F.conv2d(x_h, cross_v, bias=None, stride=1, padding=0,
534
+ groups=B) # 1,n*144,h,w -> 1, n*c_in,h,w (group=B)
535
+ x_h = x_h.reshape([-1, self.channels_h, H,
536
+ W]) # 1, n*c_in,h,w -> n,c_in,h,w (c_in = c_out)
537
+
538
+ return x_h
539
+
540
+ def forward(self, x_l, x_h):
541
+ x_l = x_l + self.mlp_l(x_l)
542
+ x_l_conv = self.conv(x_l)
543
+ x_h = x_h + F.interpolate(x_l_conv, size=x_h.shape[2:], mode='bilinear')
544
+
545
+ cross_kv = self.cross_kv(x_l)
546
+ cross_k, cross_v = cross_kv.split(self.channels_h, 1)
547
+ cross_k = cross_k.permute(0, 2, 3, 1).reshape([-1, self.channels_h, 1, 1]) # n*144,channels_h,1,1
548
+ cross_v = cross_v.reshape([-1, self.cross_size * self.cross_size, 1, 1]) # n*channels_h,144,1,1
549
+
550
+ x_h = x_h + self.attn_h(x_h, cross_k, cross_v) # [4, 40, 128, 128]
551
+ x_h = x_h + self.mlp_h(x_h)
552
+
553
+ return x_h
554
+
555
+
556
+ class CTTF(nn.Module):
557
+ def __init__(self, channels=[40, 80, 192, 384]):
558
+ super().__init__()
559
+ self.channels = channels
560
+ self.fusion0 = CTTF1(channels[0], channels[0])
561
+ self.fusion1 = CTTF1(channels[1], channels[1])
562
+ self.fusion2 = CTTF2(channels[2], channels[2])
563
+ self.fusion3 = CTTF2(channels[3], channels[3])
564
+
565
+ self.LHBlock1 = LHBlock(channels[1], channels[0])
566
+ self.LHBlock2 = LHBlock(channels[2], channels[0])
567
+ self.LHBlock3 = LHBlock(channels[3], channels[0])
568
+
569
+ self.mlp1 = Mlp(in_features=channels[0], out_features=channels[0])
570
+ self.mlp2 = Mlp(in_features=channels[0], out_features=2)
571
+ self.dwc = dsconv_3x3(channels[0], channels[0])
572
+
573
+ def forward(self, inputs):
574
+ featuresA, featuresB = inputs
575
+ # fA_0, fA_1, fA_2, fA_3 = featuresA
576
+ # fB_0, fB_1, fB_2, fB_3 = featuresB
577
+ x_diff_0 = self.fusion0(featuresA[0], featuresB[0]) # [4, 40, 128, 128]
578
+ x_diff_1 = self.fusion1(featuresA[1], featuresB[1]) # [4, 80, 64, 64]
579
+ # x_diff_2 = featuresA[2] - featuresB[2]
580
+ # x_diff_3 = featuresA[3] - featuresB[3]
581
+ x_diff_2 = self.fusion2(featuresA[2], featuresB[2]) # [4, 192, 32, 32]
582
+ x_diff_3 = self.fusion3(featuresA[3], featuresB[3]) # [4, 384, 16, 16]
583
+
584
+ x_h = x_diff_0
585
+ x_h = self.LHBlock1(x_diff_1, x_h) # [4, 40, 128, 128]
586
+ x_h = self.LHBlock2(x_diff_2, x_h)
587
+ x_h = self.LHBlock3(x_diff_3, x_h)
588
+
589
+ out = self.mlp2(self.dwc(x_h) + self.mlp1(x_h))
590
+
591
+ out = F.interpolate(
592
+ out,
593
+ scale_factor=(4, 4),
594
+ mode="bilinear",
595
+ align_corners=False,
596
+ )
597
+ return out
598
+