InPeerReview committed
Commit 3aafbf3 · verified · 1 parent: af36ebf

Upload 4 files

Files changed (4)
  1. model/DCCS.py +158 -0
  2. model/LaSEA.py +243 -0
  3. model/auxiliary.py +701 -0
  4. model/loss.py +123 -0
model/DCCS.py ADDED
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn
from model.auxiliary import VSSM
from model.LaSEA import *


class ChannelAttention(nn.Module):
    """Channel attention (CBAM-style): avg/max squeeze, shared 1x1-conv MLP, sigmoid gate."""

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial attention (CBAM-style): pool over channels, then a k x k conv yields a spatial gate."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class ResNet(nn.Module):
    """Residual block with channel and spatial attention applied before the skip connection."""

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        out += residual
        out = self.relu(out)
        return out


class DCCS(nn.Module):
    """U-shaped CNN encoder-decoder whose encoder features are fused stage-by-stage
    with the VSSM (Mamba) branch, with a GLFA block at the bottleneck."""

    def __init__(self, input_channels, block=ResNet):
        super().__init__()
        param_channels = [16, 32, 64, 128, 256]
        param_blocks = [2, 2, 2, 2]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)  # unused in forward
        self.conv_init = nn.Conv2d(input_channels, param_channels[0], 1, 1)
        self.encoder_0 = self._make_layer(param_channels[0], param_channels[0], block)
        self.encoder_1 = self._make_layer(param_channels[0], param_channels[1], block, param_blocks[0])
        self.encoder_2 = self._make_layer(param_channels[1], param_channels[2], block, param_blocks[1])
        self.encoder_3 = self._make_layer(param_channels[2], param_channels[3], block, param_blocks[2])

        self.middle_layer = self._make_layer(param_channels[3], param_channels[4], block, param_blocks[3])

        self.decoder_3 = self._make_layer(param_channels[3] + param_channels[4], param_channels[3], block,
                                          param_blocks[2])
        self.decoder_2 = self._make_layer(param_channels[2] + param_channels[3], param_channels[2], block,
                                          param_blocks[1])
        self.decoder_1 = self._make_layer(param_channels[1] + param_channels[2], param_channels[1], block,
                                          param_blocks[0])
        self.decoder_0 = self._make_layer(param_channels[0] + param_channels[1], param_channels[0], block)

        self.output_0 = nn.Conv2d(param_channels[0], 1, 1)
        self.output_1 = nn.Conv2d(param_channels[1], 1, 1)
        self.output_2 = nn.Conv2d(param_channels[2], 1, 1)
        self.output_3 = nn.Conv2d(param_channels[3], 1, 1)
        self.final = nn.Conv2d(4, 1, 3, 1, 1)
        self.VSSM = VSSM()
        # 1x1 convs that fuse the concatenated CNN/VSSM features back to the encoder width
        self.post_fuse3 = nn.Conv2d(param_channels[3] * 2, param_channels[3], kernel_size=1)
        self.post_fuse2 = nn.Conv2d(param_channels[2] * 2, param_channels[2], kernel_size=1)
        self.post_fuse1 = nn.Conv2d(param_channels[1] * 2, param_channels[1], kernel_size=1)
        self.post_fuse0 = nn.Conv2d(param_channels[0] * 2, param_channels[0], kernel_size=1)
        self.GLFA = GLFA(in_channels=256)

    def _make_layer(self, in_channels, out_channels, block, block_num=1):
        layer = []
        layer.append(block(in_channels, out_channels))
        for _ in range(block_num - 1):
            layer.append(block(out_channels, out_channels))
        return nn.Sequential(*layer)

    def forward(self, x, warm_flag):
        # Mamba branch: VSSM returns per-stage feature maps in NHWC; convert to NCHW
        outputs = self.VSSM(x)
        x_e0f = outputs[0].permute(0, 3, 1, 2).contiguous()
        x_e1f = outputs[1].permute(0, 3, 1, 2).contiguous()
        x_e2f = outputs[2].permute(0, 3, 1, 2).contiguous()
        x_e3f = outputs[3].permute(0, 3, 1, 2).contiguous()
        # CNN encoder, fusing the matching-scale VSSM feature at every stage
        x_e0z = self.encoder_0(self.conv_init(x))
        x_e0 = torch.cat([x_e0z, x_e0f], dim=1)
        x_e0z = self.post_fuse0(x_e0)
        x_e1z = self.encoder_1(self.pool(x_e0z))
        x_e1_fused = torch.cat([x_e1z, x_e1f], dim=1)
        x_e1z = self.post_fuse1(x_e1_fused)
        x_e2z = self.encoder_2(self.pool(x_e1z))
        x_e2_fused = torch.cat([x_e2z, x_e2f], dim=1)
        x_e2z = self.post_fuse2(x_e2_fused)
        x_e3z = self.encoder_3(self.pool(x_e2z))
        x_e3_fused = torch.cat([x_e3z, x_e3f], dim=1)
        x_e3z = self.post_fuse3(x_e3_fused)
        x_m = self.middle_layer(self.pool(x_e3z))
        x_m = self.GLFA(x_m)
        # Decoder with skip connections from the fused encoder features
        x_d3 = self.decoder_3(torch.cat([x_e3z, self.up(x_m)], 1))
        x_d2 = self.decoder_2(torch.cat([x_e2z, self.up(x_d3)], 1))
        x_d1 = self.decoder_1(torch.cat([x_e1z, self.up(x_d2)], 1))
        x_d0 = self.decoder_0(torch.cat([x_e0z, self.up(x_d1)], 1))

        if warm_flag:
            # Deep supervision: one mask per decoder stage plus a fused final mask
            mask0 = self.output_0(x_d0)
            mask1 = self.output_1(x_d1)
            mask2 = self.output_2(x_d2)
            mask3 = self.output_3(x_d3)
            output = self.final(torch.cat([mask0, self.up(mask1), self.up_4(mask2), self.up_8(mask3)], dim=1))
            return [mask0, mask1, mask2, mask3], output
        else:
            output = self.output_0(x_d0)
            return [], output
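For orientation, here is a minimal smoke test of DCCS (a hypothetical sketch, not part of the upload). It assumes the default VSSM configuration above (patch_size=1, dims=[16, 32, 64, 128]), an input whose side length is divisible by 16, and a working mamba_ssm selective-scan kernel on a CUDA device; the shapes follow from the code in this file.

# Hypothetical smoke test for DCCS; requires mamba_ssm, so run on a CUDA device.
import torch
from model.DCCS import DCCS

device = torch.device("cuda")
model = DCCS(input_channels=3).to(device).eval()
x = torch.randn(1, 3, 256, 256, device=device)  # H and W must be divisible by 16

with torch.no_grad():
    # Warm-up mode: per-stage masks for deep supervision plus the fused map.
    # masks[i] has shape [1, 1, 256 / 2**i, 256 / 2**i]; fused is [1, 1, 256, 256].
    masks, fused = model(x, warm_flag=True)
    # After warm-up: only the full-resolution prediction from decoder stage 0.
    _, pred = model(x, warm_flag=False)
print(fused.shape, pred.shape)  # torch.Size([1, 1, 256, 256]) for both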
model/LaSEA.py ADDED
@@ -0,0 +1,243 @@
import math
from typing import Optional, Callable, Union, Any

import numpy as np
import torch
from torch import nn, Tensor


def makeDivisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Round v to the nearest multiple of divisor, never dropping below 90% of v."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def callMethod(self, ElementName):
    return getattr(self, ElementName)


def setMethod(self, ElementName, ElementValue):
    return setattr(self, ElementName, ElementValue)


def shuffleTensor(Feature: Tensor, Mode: int = 1) -> list:
    """Randomly permute spatial positions (Mode 1: flattened pixels; otherwise rows and columns)."""
    if isinstance(Feature, Tensor):
        Feature = [Feature]
    Indexs = None
    Output = []
    for f in Feature:
        B, C, H, W = f.shape
        if Mode == 1:
            f = f.flatten(2)
            if Indexs is None:
                Indexs = torch.randperm(f.shape[-1], device=f.device)
            f = f[:, :, Indexs.to(f.device)]
            f = f.reshape(B, C, H, W)
        else:
            if Indexs is None:
                Indexs = [torch.randperm(H, device=f.device),
                          torch.randperm(W, device=f.device)]
            f = f[:, :, Indexs[0].to(f.device)]
            f = f[:, :, :, Indexs[1].to(f.device)]
        Output.append(f)
    return Output


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    def __init__(self, output_size: Union[int, tuple] = 1):
        super(AdaptiveAvgPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        Output = self.forward(Input)
        return Output, 0.0, 0.0


class AdaptiveMaxPool2d(nn.AdaptiveMaxPool2d):
    def __init__(self, output_size: Union[int, tuple] = 1):
        super(AdaptiveMaxPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        Output = self.forward(Input)
        return Output, 0.0, 0.0


NormLayerTuple = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.SyncBatchNorm,
    nn.LayerNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.GroupNorm,
    nn.BatchNorm3d,
)


def initWeight(Module):
    """Kaiming init for conv/linear layers, ones/zeros for norm layers, applied recursively."""
    if Module is None:
        return
    elif isinstance(Module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, NormLayerTuple):
        if Module.weight is not None:
            nn.init.ones_(Module.weight)
        if Module.bias is not None:
            nn.init.zeros_(Module.bias)
    elif isinstance(Module, nn.Linear):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, (nn.Sequential, nn.ModuleList)):
        for m in Module:
            initWeight(m)
    elif list(Module.children()):
        for m in Module.children():
            initWeight(m)


class BaseConv2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: Optional[int] = 1,
        padding: Optional[int] = None,
        groups: Optional[int] = 1,
        bias: Optional[bool] = None,
        BNorm: bool = False,
        ActLayer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
        Momentum: Optional[float] = 0.1,
        **kwargs: Any
    ) -> None:
        super(BaseConv2d, self).__init__()
        if padding is None:
            padding = int((kernel_size - 1) // 2 * dilation)
        if bias is None:
            bias = not BNorm

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias

        self.Conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, dilation, groups, bias, **kwargs)
        self.Bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=Momentum) if BNorm else nn.Identity()

        if ActLayer is not None:
            # nn.Sigmoid takes no `inplace` argument, so instantiate it without one
            if isinstance(list(ActLayer().named_modules())[0][1], nn.Sigmoid):
                self.Act = ActLayer()
            else:
                self.Act = ActLayer(inplace=True)
        else:
            self.Act = ActLayer

        self.apply(initWeight)

    def forward(self, x: Tensor) -> Tensor:
        x = self.Conv(x)
        x = self.Bn(x)
        if self.Act is not None:
            x = self.Act(x)
        return x


class Attention(nn.Module):
    """Monte-Carlo squeeze-and-excitation: during training the squeeze is an average pool
    at a randomly chosen resolution over an optionally shuffled feature map."""

    def __init__(
        self,
        InChannels: int,
        HidChannels: int = None,
        SqueezeFactor: int = 4,
        PoolRes: list = [1, 2, 3],
        Act: Callable[..., nn.Module] = nn.ReLU,
        ScaleAct: Callable[..., nn.Module] = nn.Sigmoid,
        MoCOrder: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if HidChannels is None:
            HidChannels = max(makeDivisible(InChannels // SqueezeFactor, 8), 32)

        AllPoolRes = PoolRes + [1] if 1 not in PoolRes else PoolRes
        for k in AllPoolRes:
            Pooling = AdaptiveAvgPool2d(k)
            setMethod(self, 'Pool%d' % k, Pooling)

        self.SELayer = nn.Sequential(
            BaseConv2d(InChannels, HidChannels, 1, ActLayer=Act),
            BaseConv2d(HidChannels, InChannels, 1, ActLayer=ScaleAct),
        )

        self.PoolRes = PoolRes
        self.MoCOrder = MoCOrder

    def RandomSample(self, x: Tensor) -> Tensor:
        if self.training:
            PoolKeep = np.random.choice(self.PoolRes)
            x1 = shuffleTensor(x)[0] if self.MoCOrder else x
            AttnMap: Tensor = callMethod(self, 'Pool%d' % PoolKeep)(x1)
            if AttnMap.shape[-1] > 1:
                # keep a single random spatial position, then restore the two squeezed dims
                AttnMap = AttnMap.flatten(2)
                AttnMap = AttnMap[:, :, torch.randperm(AttnMap.shape[-1])[0]]
                AttnMap = AttnMap[:, :, None, None]
        else:
            AttnMap: Tensor = callMethod(self, 'Pool%d' % 1)(x)

        return AttnMap

    def forward(self, x: Tensor) -> Tensor:
        AttnMap = self.RandomSample(x)
        return x * self.SELayer(AttnMap)


def channel_shuffle(x, groups):
    """ShuffleNet-style channel shuffle for NCHW tensors."""
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class GLFA(nn.Module):
    """Global-local feature aggregation: parallel dilated 3x3 convs (dilations 1-4),
    channel shuffle, 1x1 fusion, Monte-Carlo attention, and a residual connection."""

    def __init__(self, in_channels):
        super(GLFA, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=1, kernel_size=3, dilation=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=2, kernel_size=3, dilation=2),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=3, kernel_size=3, dilation=3),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_4 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=4, kernel_size=3, dilation=4),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.mca = Attention(InChannels=in_channels, HidChannels=16)

    def forward(self, x):
        d = x
        c1 = self.conv_1(x)
        c2 = self.conv_2(x)
        c3 = self.conv_3(x)
        c4 = self.conv_4(x)
        cat = torch.cat([c1, c2, c3, c4], dim=1)
        cat = channel_shuffle(cat, groups=4)
        M = self.fuse(cat)
        O = self.mca(M)
        return O + d
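As a quick sanity check (an illustrative sketch, not part of the upload), GLFA is shape-preserving and can be exercised on CPU; in train mode the Monte-Carlo attention samples a random pooling resolution over a shuffled map, while eval mode uses the plain 1x1 average-pool squeeze.

# Hypothetical usage sketch for GLFA; shapes follow from the definitions above.
import torch
from model.LaSEA import GLFA

glfa = GLFA(in_channels=256)
x = torch.randn(2, 256, 16, 16)

glfa.train()            # random pooling resolution + shuffled squeeze
y_train = glfa(x)
glfa.eval()             # deterministic 1x1 average-pool squeeze
y_eval = glfa(x)
assert y_train.shape == y_eval.shape == x.shape  # GLFA preserves [B, C, H, W]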
model/auxiliary.py ADDED
@@ -0,0 +1,701 @@
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# The CUDA selective-scan kernels are optional at import time;
# forward_corev0/forward_corev1 will fail at call time if neither package is installed.
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except ImportError:
    pass
try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
    from selective_scan import selective_scan_ref as selective_scan_ref_v1
except ImportError:
    pass

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """FLOPs estimate for the reference selective scan.

    u: r(B D L), delta: r(B D L), A: r(D N), B: r(B N L), C: r(B N L),
    D: r(D), z: r(B D L), delta_bias: r(D), fp32

    ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    # adapted from fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MACs (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0
    # discretization: deltaA = exp(einsum('bdl,dn->bdln', delta, A))
    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    # deltaB_u = einsum over delta, B, u (grouped or per-channel B)
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    # sequential recurrence over L steps: state update plus output projection per step
    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops


class PatchEmbed2D(nn.Module):
    """Image to patch embedding via a strided conv; output in NHWC layout."""

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    """2x2 patch merging: concatenate the four neighbours (4C), normalize, project to 2C."""

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):  # x: [B, H, W, C]
        B, H, W, C = x.shape
        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning: x.shape {x.shape} is not even.", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)
        x = self.norm(x)
        x = self.reduction(x)
        return x


class PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim * 2
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // self.dim_scale)
        x = self.norm(x)
        return x


class Final_PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // self.dim_scale)
        x = self.norm(x)
        return x


def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    """Channel shuffle for NHWC tensors (channels last)."""
    batch_size, height, width, num_channels = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batch_size, height, width, groups, channels_per_group)
    x = torch.transpose(x, 3, 4).contiguous()
    x = x.view(batch_size, height, width, -1)
    return x


class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels, reduction=4):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SS2D(nn.Module):
    """2D selective-scan (Mamba) block with four scan directions and a channel-attention gate."""

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4 * D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)

        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None
        self.ChannelAttentionModule = ChannelAttentionModule(in_channels=self.d_inner)

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # sample dt log-uniformly in [dt_min, dt_max], then store the softplus inverse as the bias
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # keep in fp32
        D._no_weight_decay = True
        return D

    def forward_corev0(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn

        B, C, H, W = x.shape
        L = H * W
        K = 4

        # four scan directions: row-major, column-major, and their reverses
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # undo the flips/transposes so all four outputs are row-major again
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward_corev1(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn_v1

        B, C, H, W = x.shape
        L = H * W
        K = 4
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        xs = xs.float().view(B, -1, L)
        dts = dts.contiguous().float().view(B, -1, L)
        Bs = Bs.float().view(B, K, -1, L)
        Cs = Cs.float().view(B, K, -1, L)
        Ds = self.Ds.float().view(-1)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, a: torch.Tensor, **kwargs):
        B, H, W, C = a.shape

        xz = self.in_proj(a)
        x, z = xz.chunk(2, dim=-1)
        # gate branch z passes through channel attention before gating
        z = z.permute(0, 3, 1, 2)
        z = self.ChannelAttentionModule(z) * z
        z = z.permute(0, 2, 3, 1).contiguous()
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * torch.nn.functional.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out + a


class SS_Conv_SSM(nn.Module):
    """Split-and-fuse block: one channel half goes through SS2D, the other through a
    conv stack gated by channel attention; halves are re-joined with a channel shuffle."""

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim // 2)
        self.self_attention = SS2D(d_model=hidden_dim // 2, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

        self.conv33conv33conv11 = nn.Sequential(
            nn.BatchNorm2d(hidden_dim // 2),
            nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=1, stride=1),
            nn.ReLU()
        )
        self.ChannelAttentionModule = ChannelAttentionModule(in_channels=hidden_dim // 2)

    def forward(self, input: torch.Tensor):
        input_left, input_right = input.chunk(2, dim=-1)
        input_right = self.ln_1(input_right)
        input_left = self.ln_1(input_left)
        x = self.drop_path(self.self_attention(input_right))
        b0 = input_left.permute(0, 3, 1, 2).contiguous()
        b1 = self.conv33conv33conv11(b0)
        b2 = self.ChannelAttentionModule(b0)
        b1 = b1.permute(0, 2, 3, 1).contiguous()
        b2 = b2.permute(0, 2, 3, 1).contiguous()
        input_left = b1 * b2
        output1 = torch.cat((input_left, x), dim=-1)
        output = channel_shuffle(output1, groups=2)
        return output + input


class VSSLayer(nn.Module):
    """ A basic VSS layer for one stage, built from SS_Conv_SSM blocks.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            SS_Conv_SSM(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        # NOTE: applied here, but overridden later by VSSM._init_weights
        def _init_weights(module: nn.Module):
            for name, p in module.named_parameters():
                if name in ["out_proj.weight-881-1KESHIHUA QUANZHONG"]:
                    p = p.clone().detach_()  # fake init, just to keep the RNG state consistent
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x


class VSSLayer_up(nn.Module):
    """ A basic VSS decoder layer for one stage, built from SS_Conv_SSM blocks.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (nn.Module | None, optional): Upsample layer at the start of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        upsample=None,
        use_checkpoint=False,
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            SS_Conv_SSM(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        # NOTE: same fake init as VSSLayer; overridden later by VSSM._init_weights
        def _init_weights(module: nn.Module):
            for name, p in module.named_parameters():
                if name in ["out_proj.weight-881-1KESHIHUA QUANZHONG"]:
                    p = p.clone().detach_()
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        self.apply(_init_weights)

        if upsample is not None:
            self.upsample = upsample(dim=dim, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        if self.upsample is not None:
            x = self.upsample(x)
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x


class VSSM(nn.Module):
    """Multi-stage VSS backbone; forward_backbone collects the patch embedding and
    each stage's output so DCCS can fuse them scale-by-scale."""

    def __init__(self, patch_size=1, in_chans=3, num_classes=1, depths=[2, 2, 2, 2],
                 dims=[16, 32, 64, 128], d_state=16, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims
        self.layer_outputs = []
        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_backbone(self, x):
        self.layer_outputs = []
        x = self.patch_embed(x)
        self.layer_outputs.append(x)

        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
            self.layer_outputs.append(x)
        return self.layer_outputs

    def forward(self, x, i=None):
        outputs = self.forward_backbone(x)
        if i is not None:
            x = outputs[i]
            x = x.permute(0, 3, 1, 2).contiguous()
            return x
        return outputs
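A brief sketch of how the backbone is consumed (hypothetical, and dependent on a working selective-scan kernel on CUDA): with the defaults above, forward returns the patch embedding plus one NHWC feature map per stage, which is exactly the list DCCS.forward unpacks.

# Hypothetical illustration of VSSM's multi-scale outputs; requires mamba_ssm on CUDA.
import torch
from model.auxiliary import VSSM

backbone = VSSM().cuda().eval()   # patch_size=1, dims=[16, 32, 64, 128]
x = torch.randn(1, 3, 256, 256, device="cuda")

with torch.no_grad():
    feats = backbone(x)           # list of 5 NHWC tensors
for f in feats:
    print(tuple(f.shape))
# (1, 256, 256, 16)   patch embedding
# (1, 128, 128, 32)   after stage 0 + patch merging
# (1, 64, 64, 64)     after stage 1 + patch merging
# (1, 32, 32, 128)    after stage 2 + patch merging
# (1, 32, 32, 128)    after stage 3 (no merging)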
model/loss.py ADDED
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn


def SoftIoULoss(pred, target):
    """Soft IoU loss on sigmoid probabilities: per-sample IoU, then 1 - mean."""
    pred = torch.sigmoid(pred)

    smooth = 1

    intersection = pred * target
    intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
    pred_sum = torch.sum(pred, dim=(1, 2, 3))
    target_sum = torch.sum(target, dim=(1, 2, 3))

    loss = (intersection_sum + smooth) / \
           (pred_sum + target_sum - intersection_sum + smooth)

    loss = 1 - loss.mean()

    return loss


def Dice(pred, target, warm_epoch=1, epoch=1, layer=0):
    """Dice-style loss; note the denominator adds the intersection term,
    which differs from the standard Dice formulation."""
    pred = torch.sigmoid(pred)

    smooth = 1

    intersection = pred * target
    intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
    pred_sum = torch.sum(pred, dim=(1, 2, 3))
    target_sum = torch.sum(target, dim=(1, 2, 3))

    loss = (2 * intersection_sum + smooth) / \
           (pred_sum + target_sum + intersection_sum + smooth)

    loss = 1 - loss.mean()

    return loss


class SLSIoULoss(nn.Module):
    """Scale- and location-sensitive IoU loss: plain soft IoU during warm-up,
    then an alpha-weighted IoU plus the location term LLoss."""

    def __init__(self):
        super(SLSIoULoss, self).__init__()

    def forward(self, pred_log, target, warm_epoch, epoch, with_shape=True):
        pred = torch.sigmoid(pred_log)
        smooth = 0.0

        intersection = pred * target

        intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
        pred_sum = torch.sum(pred, dim=(1, 2, 3))
        target_sum = torch.sum(target, dim=(1, 2, 3))

        dis = torch.pow((pred_sum - target_sum) / 2, 2)

        # alpha shrinks toward the area ratio when predicted and target areas disagree
        alpha = (torch.min(pred_sum, target_sum) + dis + smooth) / (torch.max(pred_sum, target_sum) + dis + smooth)

        loss = (intersection_sum + smooth) / \
               (pred_sum + target_sum - intersection_sum + smooth)
        lloss = LLoss(pred, target)

        if epoch > warm_epoch:
            siou_loss = alpha * loss
            if with_shape:
                loss = 1 - siou_loss.mean() + lloss
            else:
                loss = 1 - siou_loss.mean()
        else:
            loss = 1 - loss.mean()
        return loss


def LLoss(pred, target):
    """Location loss: compares the centroid angle and centroid distance of prediction and target."""
    loss = torch.tensor(0.0, requires_grad=True).to(pred)

    patch_size = pred.shape[0]
    h = pred.shape[2]
    w = pred.shape[3]
    x_index = torch.arange(0, w, 1).view(1, 1, w).repeat((1, h, 1)).to(pred) / w
    y_index = torch.arange(0, h, 1).view(1, h, 1).repeat((1, 1, w)).to(pred) / h
    smooth = 1e-8
    for i in range(patch_size):
        pred_centerx = (x_index * pred[i]).mean()
        pred_centery = (y_index * pred[i]).mean()

        target_centerx = (x_index * target[i]).mean()
        target_centery = (y_index * target[i]).mean()

        angle_loss = (4 / (torch.pi ** 2)) * torch.square(
            torch.arctan(pred_centery / (pred_centerx + smooth))
            - torch.arctan(target_centery / (target_centerx + smooth)))

        pred_length = torch.sqrt(pred_centerx * pred_centerx + pred_centery * pred_centery + smooth)
        target_length = torch.sqrt(target_centerx * target_centerx + target_centery * target_centery + smooth)

        length_loss = torch.min(pred_length, target_length) / (torch.max(pred_length, target_length) + smooth)

        loss = loss + (1 - length_loss + angle_loss) / patch_size

    return loss


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
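To illustrate the warm-up switch (a hypothetical example with random tensors, not from the commit): before warm_epoch the loss reduces to plain soft IoU; afterwards the scale-sensitive alpha weighting and the location term LLoss kick in.

# Hypothetical example of the scheduled loss; logits and labels are random.
import torch
from model.loss import SLSIoULoss

criterion = SLSIoULoss()
pred_log = torch.randn(4, 1, 64, 64)                        # raw logits
target = (torch.rand(4, 1, 64, 64) > 0.95).float()          # sparse binary mask

warm = criterion(pred_log, target, warm_epoch=5, epoch=3)   # soft IoU only
full = criterion(pred_log, target, warm_epoch=5, epoch=10)  # + alpha scaling + LLoss
print(warm.item(), full.item())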