InPeerReview committed on
Commit
8266e5f
·
verified ·
1 Parent(s): a050a18

Delete model/decoder.py

Browse files
Files changed (1) hide show
  1. model/decoder.py +0 -309
model/decoder.py DELETED
@@ -1,309 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from einops import rearrange
5
- from model.utils import weight_init
6
-
7
-
8
-
9
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One keep/drop decision is drawn per index of the first dimension of
    ``x`` and broadcast over all remaining dimensions. Kept samples are
    divided by the keep probability ``1 - drop_prob``.

    Args:
        x: Input tensor; its first dimension is treated as the sample axis.
        drop_prob: Probability of dropping a sample.
        training: Dropping only happens when this is true.

    Returns:
        ``x`` itself when ``drop_prob`` is 0 or ``training`` is false,
        otherwise a new tensor with the same shape as ``x``.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Dividing by a zero keep probability would produce inf * 0 = NaN;
        # dropping every sample is simply an all-zero tensor.
        return torch.zeros_like(x)
    # Shape (B, 1, ..., 1): one random draw per sample.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # values in [keep_prob, 1 + keep_prob) become 0 or 1
    return x.div(keep_prob) * mask
18
-
19
-
20
class DropPath(nn.Module):
    """Module wrapper around ``drop_path``.

    Args:
        drop_prob: Probability of dropping a sample. ``None`` (the default)
            is treated as 0, i.e. nothing is dropped.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` using this module's training flag."""
        # The default ``None`` would otherwise reach ``1 - None`` inside
        # ``drop_path`` in training mode and raise a TypeError.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)
27
-
28
-
29
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> linear.

    The same dropout module is applied after the activation and after the
    second linear layer.

    Args:
        in_features: Width of the input's last dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Width of the output; falls back to ``in_features``
            when falsy.
        act_layer: Zero-argument callable returning the activation module.
        drop: Dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Return the transformed tensor; only the last dimension changes."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
46
-
47
-
48
-
49
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``y``.

    Args:
        dim1: Channel width of the query input and of the output.
        dim2: Channel width of the key/value input.
        num_heads: Number of attention heads; must divide ``dim1``.
        qkv_bias: Whether the query and key/value projections use a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim1`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim1, dim2, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim1 % num_heads != 0:
            # Otherwise the per-head reshape in forward() can never succeed
            # and fails there with an opaque shape error.
            raise ValueError(
                f"dim1 ({dim1}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim1 // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim1, dim1, bias=qkv_bias)
        # Keys and values come from one projection and are split in forward().
        self.kv = nn.Linear(dim2, dim1 * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim1, dim1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        """Attend from ``x`` over ``y``.

        Args:
            x: Query tokens of shape ``(B, N1, dim1)``.
            y: Key/value tokens of shape ``(B, N2, dim2)``.

        Returns:
            Tensor of shape ``(B, N1, dim1)``.
        """
        B1, N1, C1 = x.shape
        B2, N2, _ = y.shape
        head_dim = C1 // self.num_heads

        # (B, heads, N1, head_dim)
        q = self.q(x).reshape(B1, N1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        # (2, B, heads, N2, head_dim): index 0 holds keys, index 1 values.
        kv = self.kv(y).reshape(B2, N2, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel dimension.
        out = (attn @ v).transpose(1, 2).reshape(B1, N1, C1)
        out = self.proj(out)
        return self.proj_drop(out)
82
-
83
-
84
-
85
class Block(nn.Module):
    """Pre-norm cross-attention block followed by an MLP, both residual.

    Args:
        dim1: Channel width of the query stream ``x`` (and of the output).
        dim2: Channel width of the key/value stream ``y``.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim1``.
        qkv_bias: Passed to ``CrossAttention``.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Stochastic-depth probability; an identity module is used
            when it is not greater than 0.
        act_layer: Activation constructor for the MLP.
        norm_layer: Normalisation constructor taking the channel width.
    """

    def __init__(self, dim1, dim2, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim1)
        self.norm2 = norm_layer(dim2)
        self.attn = CrossAttention(
            dim1, dim2, num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm3 = norm_layer(dim1)
        self.mlp = Mlp(
            in_features=dim1, hidden_features=int(dim1 * mlp_ratio),
            act_layer=act_layer, drop=drop)

    def forward(self, x, y):
        """Update ``x`` with information attended from ``y``."""
        attended = self.attn(self.norm1(x), self.norm2(y))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm3(x))
        return x + self.drop_path(transformed)
101
-
102
-
103
-
104
class ContentAwareAggregation(nn.Module):
    """Fuse two feature maps with a learned per-element sigmoid gate.

    ``high_feat`` is resized to ``low_feat``'s spatial size and projected to
    ``low_dim`` channels. A gate computed from the sum of both maps scales
    ``low_feat`` before the projected map is added back.

    Args:
        low_dim: Channel count of ``low_feat`` and of the output.
        high_dim: Channel count of ``high_feat``.
    """

    def __init__(self, low_dim, high_dim):
        super().__init__()
        projection_layers = [
            nn.Conv2d(high_dim, low_dim, kernel_size=1),
            nn.BatchNorm2d(low_dim),
            nn.ReLU(inplace=True),
        ]
        self.project = nn.Sequential(*projection_layers)

        gate_layers = [
            # Depthwise 3x3 (groups == channels), then pointwise 1x1.
            nn.Conv2d(low_dim, low_dim, kernel_size=3, padding=1, groups=low_dim),
            nn.BatchNorm2d(low_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(low_dim, low_dim, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.attn_gen = nn.Sequential(*gate_layers)

    def forward(self, low_feat, high_feat):
        """Return ``gate * low_feat + projected_high`` at ``low_feat``'s size."""
        target_hw = low_feat.shape[2:]
        resized = F.interpolate(high_feat, size=target_hw, mode='bilinear', align_corners=False)
        projected = self.project(resized)
        gate = self.attn_gen(low_feat + projected)
        return gate * low_feat + projected
127
-
128
-
129
-
130
class FeatureInjector(nn.Module):
    """Inject c2/c3/c4 features into c5 via cross-attention, then cross-fuse.

    For each input pyramid, c5 tokens act as queries against c2, c3 and c4
    in turn; the three results are concatenated and reduced by a 1x1 conv.
    The enhanced c5 maps of the two inputs are then mixed through a shared
    ``ContentAwareAggregation`` module.

    Args:
        dim1: Channel width of c5.
        dim2: Channel widths of (c2, c3, c4); any indexable of length 3.
        num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path,
        act_layer, norm_layer: Forwarded unchanged to each ``Block``.
    """

    def __init__(self, dim1=384, dim2=(64, 128, 256), num_heads=8, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
        super().__init__()

        # The three blocks differ only in their key/value channel width.
        block_args = (num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer, norm_layer)
        self.c2_c5 = Block(dim1, dim2[0], *block_args)
        self.c3_c5 = Block(dim1, dim2[1], *block_args)
        self.c4_c5 = Block(dim1, dim2[2], *block_args)

        self.fuse = nn.Conv2d(dim1 * 3, dim1, 1, bias=False)
        self.caa = ContentAwareAggregation(dim1, dim1)

        # Project helper from model.utils; presumably (re)initialises the
        # parameters of this module -- confirm against its definition.
        weight_init(self)

    def base_forward(self, c2, c3, c4, c5):
        """Enhance ``c5`` with ``c2``, ``c3`` and ``c4`` of the same pyramid.

        All inputs are ``(B, C, H, W)`` maps; the result has c5's spatial
        size and ``dim1`` channels.
        """
        H, W = c5.shape[2:]
        queries = rearrange(c5, 'b c h w -> b (h w) c')

        injected = []
        for block, feat in ((self.c2_c5, c2), (self.c3_c5, c3), (self.c4_c5, c4)):
            tokens = rearrange(feat, 'b c h w -> b (h w) c')
            enhanced = block(queries, tokens)
            injected.append(rearrange(enhanced, 'b (h w) c -> b c h w', h=H, w=W))

        return self.fuse(torch.cat(injected, dim=1))

    def forward(self, fx, fy):
        """Return the enhanced c5 maps for the two feature pyramids.

        Args:
            fx: Indexable holding (c2, c3, c4, c5) of the first input.
            fy: Indexable holding (c2, c3, c4, c5) of the second input.
        """
        _c5x = self.base_forward(fx[0], fx[1], fx[2], fx[3])
        _c5y = self.base_forward(fy[0], fy[1], fy[2], fy[3])

        # NOTE(review): the second call receives the already-updated _c5x,
        # so the two branches are not symmetric. Kept as-is because changing
        # it would alter the outputs of trained weights; confirm intent.
        _c5x = self.caa(_c5x, _c5y)
        _c5y = self.caa(_c5y, _c5x)

        return _c5x, _c5y
174
-
175
-
176
class DualAttentionGate(nn.Module):
    """Gate a feature map with both a channel and a spatial attention map.

    Args:
        channels: Channel count of the input feature map.
        ratio: Reduction factor of the channel-attention bottleneck.
    """

    def __init__(self, channels, ratio=8):
        super().__init__()
        bottleneck = channels // ratio

        # Channel attention branch: global pooling + bottleneck 1x1 convs.
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                          # [B, C, 1, 1]
            nn.Conv2d(channels, bottleneck, 1, bias=False),   # [B, C/ratio, 1, 1]
            nn.ReLU(),
            nn.Conv2d(bottleneck, channels, 1, bias=False),   # [B, C, 1, 1]
            nn.Sigmoid(),
        )

        # Spatial attention branch: 7x7 conv over two statistic maps.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3, bias=False),  # 2 input channels (mean + std)
            nn.Sigmoid(),                               # output [B, 1, H, W]
        )

    def forward(self, x):
        """
        Input:  x [B, C, H, W]
        Output: gated features [B, C, H, W]
        """
        channel_gate = self.channel_att(x)  # [B, C, 1, 1]

        # Per-pixel mean and standard deviation across the channel axis.
        channel_stats = torch.cat(
            [x.mean(dim=1, keepdim=True), x.std(dim=1, keepdim=True)],
            dim=1,
        )  # [B, 2, H, W]
        spatial_gate = self.spatial_att(channel_stats)  # [B, 1, H, W]

        # Both gates are applied by broadcasting element-wise products.
        return x * channel_gate * spatial_gate
209
-
210
-
211
class SimplifiedFGFM(nn.Module):
    """Flow-guided fusion of a low-resolution and a high-resolution map.

    A 4-channel flow field (two channels per input) is predicted from both
    maps; each map is resampled with its flow and the sum of the two warped
    maps is passed through a ``DualAttentionGate``.

    Args:
        in_channels: Channel count of the low-resolution input.
        out_channels: Channel count of the high-resolution input and output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.flow_make = nn.Conv2d(out_channels * 2, 4, 3, padding=1, bias=False)
        self.dual_att = DualAttentionGate(out_channels)

    def flow_warp(self, input, flow, size):
        """Resample ``input`` on a regular grid of ``size`` shifted by ``flow``.

        Args:
            input: 4-D feature map ``(N, C, H, W)``.
            flow: Offsets ``(N, 2, out_h, out_w)``; channel 0 is divided by
                ``out_w`` and channel 1 by ``out_h`` before being added to
                the normalised ``[-1, 1]`` sampling grid.
            size: ``(out_h, out_w)`` of the output.
        """
        out_h, out_w = size
        n, _, _, _ = input.size()

        rows, cols = torch.meshgrid(
            torch.linspace(-1.0, 1.0, out_h),
            torch.linspace(-1.0, 1.0, out_w),
            indexing='ij',
        )
        # grid_sample expects (x, y) order in the last dimension.
        base_grid = torch.stack((cols, rows), 2).repeat(n, 1, 1, 1).type_as(input)
        scale = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device)
        sampling_grid = base_grid + flow.permute(0, 2, 3, 1) / scale

        return F.grid_sample(input, sampling_grid, align_corners=True)

    def forward(self, lowres_feature, highres_feature):
        """Return the fused map at ``highres_feature``'s spatial size."""
        target_hw = highres_feature.shape[2:]

        # 1. Flow-based alignment.
        low = self.down(lowres_feature)
        low_up = F.interpolate(low, size=target_hw, mode='bilinear', align_corners=True)

        flow = self.flow_make(torch.cat([low_up, highres_feature], dim=1))
        flow_low, flow_high = flow[:, :2, :, :], flow[:, 2:, :, :]

        # The un-upsampled low-res map is warped; sampling on the high-res
        # grid brings it to the target size.
        warped_low = self.flow_warp(low, flow_low, target_hw)
        warped_high = self.flow_warp(highres_feature, flow_high, target_hw)

        # 2. Dual-attention fusion.
        return self.dual_att(warped_low + warped_high)
248
-
249
-
250
- # Decoder 模块
251
class Decoder(nn.Module):
    """Change-mask decoder over two four-level feature pyramids.

    The c5 features of both inputs are enhanced by ``FeatureInjector``; at
    each level the pair is turned into a difference feature, and the levels
    are merged top-down with ``SimplifiedFGFM`` before classification.

    Args:
        in_dim: Channel widths of (c2, c3, c4, c5).
        decay: Channel reduction factor inside the difference blocks.
        num_class: Number of output channels of the classifier.
    """

    def __init__(self, in_dim=(64, 128, 256, 384), decay=4, num_class=1):
        super().__init__()
        c2_channel, c3_channel, c4_channel, c5_channel = in_dim

        # Pass the real c2-c4 widths so a non-default ``in_dim`` does not
        # silently fall back to FeatureInjector's own default widths.
        self.structure_enhance = FeatureInjector(
            dim1=c5_channel, dim2=[c2_channel, c3_channel, c4_channel])

        # SimplifiedFGFM modules take the place of plain upsampling.
        self.fgfm_c4 = SimplifiedFGFM(in_channels=c5_channel, out_channels=c4_channel)
        self.fgfm_c3 = SimplifiedFGFM(in_channels=c4_channel, out_channels=c3_channel)
        self.fgfm_c2 = SimplifiedFGFM(in_channels=c3_channel, out_channels=c2_channel)

        # Final classifier: stride-2 transposed conv followed by a 3x3 conv.
        # (The attribute name's spelling is kept for checkpoint compatibility.)
        self.classfier = nn.Sequential(
            nn.ConvTranspose2d(c2_channel, c2_channel, kernel_size=4, stride=2, padding=1),
            nn.Conv2d(c2_channel, num_class, 3, 1, padding=1, bias=False)
        )

        # One difference-modeling block per pyramid level.
        self.mlp = nn.ModuleList([
            self._make_difference_block(dim, decay) for dim in in_dim
        ])

    @staticmethod
    def _make_difference_block(dim, decay):
        """Build the conv stack mapping ``3 * dim`` channels back to ``dim``."""
        mid = dim // decay
        return nn.Sequential(
            nn.Conv2d(dim * 3, mid, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, mid, 3, 1, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(mid, mid, 3, 1, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(mid, dim, 3, 1, padding=1, bias=False)
        )

    def difference_modeling(self, x, y, block):
        """Apply ``block`` to ``[x, y, |x - y|]`` concatenated on channels."""
        f = torch.cat([x, y, torch.abs(x - y)], dim=1)
        return block(f)

    def forward(self, fx, fy):
        """Predict the change mask.

        Args:
            fx: Sequence of four feature maps (c2, c3, c4, c5), first input.
            fy: Sequence of four feature maps (c2, c3, c4, c5), second input.

        Returns:
            Sigmoid of the classifier output.
        """
        c2x, c3x, c4x = fx[:-1]
        c2y, c3y, c4y = fy[:-1]

        # Enhanced high-level semantic features (c5).
        c5x, c5y = self.structure_enhance(fx, fy)

        # Per-level difference features.
        c2 = self.difference_modeling(c2x, c2y, self.mlp[0])
        c3 = self.difference_modeling(c3x, c3y, self.mlp[1])
        c4 = self.difference_modeling(c4x, c4y, self.mlp[2])
        c5 = self.difference_modeling(c5x, c5y, self.mlp[3])

        # Top-down flow-guided fusion.
        c4f = self.fgfm_c4(c5, c4)
        c3f = self.fgfm_c3(c4f, c3)
        c2f = self.fgfm_c2(c3f, c2)

        # Output change mask.
        pred = self.classfier(c2f)
        pred_mask = torch.sigmoid(pred)

        return pred_mask