Upload 4 files
- model/DCCS.py +158 -0
- model/LaSEA.py +243 -0
- model/auxiliary.py +701 -0
- model/loss.py +123 -0
model/DCCS.py
ADDED
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from thop import profile
from model.auxiliary import VSSM
from model.LaSEA import *


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # use the `ratio` argument instead of the hard-coded 16 from the upload
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class ResNet(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or out_channels != in_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = None

        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out
        out = self.sa(out) * out
        out += residual
        out = self.relu(out)
        return out


class DCCS(nn.Module):
    def __init__(self, input_channels, block=ResNet):
        super().__init__()
        param_channels = [16, 32, 64, 128, 256]
        param_blocks = [2, 2, 2, 2]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        # the upload had scale_factor=4 here, which contradicts the name; unused in forward()
        self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
        self.conv_init = nn.Conv2d(input_channels, param_channels[0], 1, 1)
        self.encoder_0 = self._make_layer(param_channels[0], param_channels[0], block)
        self.encoder_1 = self._make_layer(param_channels[0], param_channels[1], block, param_blocks[0])
        self.encoder_2 = self._make_layer(param_channels[1], param_channels[2], block, param_blocks[1])
        self.encoder_3 = self._make_layer(param_channels[2], param_channels[3], block, param_blocks[2])

        self.middle_layer = self._make_layer(param_channels[3], param_channels[4], block, param_blocks[3])

        self.decoder_3 = self._make_layer(param_channels[3] + param_channels[4], param_channels[3], block,
                                          param_blocks[2])
        self.decoder_2 = self._make_layer(param_channels[2] + param_channels[3], param_channels[2], block,
                                          param_blocks[1])
        self.decoder_1 = self._make_layer(param_channels[1] + param_channels[2], param_channels[1], block,
                                          param_blocks[0])
        self.decoder_0 = self._make_layer(param_channels[0] + param_channels[1], param_channels[0], block)

        self.output_0 = nn.Conv2d(param_channels[0], 1, 1)
        self.output_1 = nn.Conv2d(param_channels[1], 1, 1)
        self.output_2 = nn.Conv2d(param_channels[2], 1, 1)
        self.output_3 = nn.Conv2d(param_channels[3], 1, 1)
        self.final = nn.Conv2d(4, 1, 3, 1, 1)
        self.VSSM = VSSM()
        self.post_fuse3 = nn.Conv2d(param_channels[3] * 2, param_channels[3], kernel_size=1)
        self.post_fuse2 = nn.Conv2d(param_channels[2] * 2, param_channels[2], kernel_size=1)
        self.post_fuse1 = nn.Conv2d(param_channels[1] * 2, param_channels[1], kernel_size=1)
        self.post_fuse0 = nn.Conv2d(param_channels[0] * 2, param_channels[0], kernel_size=1)
        self.GLFA = GLFA(in_channels=256)

    def _make_layer(self, in_channels, out_channels, block, block_num=1):
        layer = []
        layer.append(block(in_channels, out_channels))
        for _ in range(block_num - 1):
            layer.append(block(out_channels, out_channels))
        return nn.Sequential(*layer)

    def forward(self, x, warm_flag):
        # Mamba branch: VSSM returns channels-last feature maps, one per stage
        outputs = self.VSSM(x)
        x_e0f = outputs[0].permute(0, 3, 1, 2).contiguous()
        x_e1f = outputs[1].permute(0, 3, 1, 2).contiguous()
        x_e2f = outputs[2].permute(0, 3, 1, 2).contiguous()
        x_e3f = outputs[3].permute(0, 3, 1, 2).contiguous()
        # CNN branch, fused stage-by-stage with the Mamba features via 1x1 convs
        x_e0z = self.encoder_0(self.conv_init(x))
        x_e0 = torch.cat([x_e0z, x_e0f], dim=1)
        x_e0z = self.post_fuse0(x_e0)
        x_e1z = self.encoder_1(self.pool(x_e0z))
        x_e1_fused = torch.cat([x_e1z, x_e1f], dim=1)
        x_e1z = self.post_fuse1(x_e1_fused)
        x_e2z = self.encoder_2(self.pool(x_e1z))
        x_e2_fused = torch.cat([x_e2z, x_e2f], dim=1)
        x_e2z = self.post_fuse2(x_e2_fused)
        x_e3z = self.encoder_3(self.pool(x_e2z))
        x_e3_fused = torch.cat([x_e3z, x_e3f], dim=1)
        x_e3z = self.post_fuse3(x_e3_fused)
        x_m = self.middle_layer(self.pool(x_e3z))
        x_m = self.GLFA(x_m)
        x_d3 = self.decoder_3(torch.cat([x_e3z, self.up(x_m)], 1))
        x_d2 = self.decoder_2(torch.cat([x_e2z, self.up(x_d3)], 1))
        x_d1 = self.decoder_1(torch.cat([x_e1z, self.up(x_d2)], 1))
        x_d0 = self.decoder_0(torch.cat([x_e0z, self.up(x_d1)], 1))

        if warm_flag:
            # deep supervision: one mask per decoder level, then fuse at full resolution
            mask0 = self.output_0(x_d0)
            mask1 = self.output_1(x_d1)
            mask2 = self.output_2(x_d2)
            mask3 = self.output_3(x_d3)
            output = self.final(torch.cat([mask0, self.up(mask1), self.up_4(mask2), self.up_8(mask3)], dim=1))
            return [mask0, mask1, mask2, mask3], output
        else:
            output = self.output_0(x_d0)
            return [], output
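Not part of the upload: a minimal shape check for DCCS. It assumes a 3-channel 256x256 input and that mamba_ssm's selective_scan_fn is importable, since the VSSM branch inside DCCS depends on it; without that package the forward pass raises a NameError.

    import torch
    from model.DCCS import DCCS

    model = DCCS(input_channels=3).eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        masks, output = model(x, warm_flag=True)
    print(output.shape)                     # torch.Size([1, 1, 256, 256])
    print([tuple(m.shape) for m in masks])  # supervision masks at 1x, 1/2, 1/4, 1/8 scale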
model/LaSEA.py
ADDED
@@ -0,0 +1,243 @@
import math
from typing import Optional, Callable, Union, Tuple, Any

import numpy as np
import torch
from torch import nn, Tensor


def makeDivisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def callMethod(self, ElementName):
    return getattr(self, ElementName)


def setMethod(self, ElementName, ElementValue):
    return setattr(self, ElementName, ElementValue)


def shuffleTensor(Feature: Tensor, Mode: int = 1) -> list:
    if isinstance(Feature, Tensor):
        Feature = [Feature]
    Indexs = None
    Output = []
    for f in Feature:
        B, C, H, W = f.shape
        if Mode == 1:
            # shuffle all spatial positions with one shared permutation
            f = f.flatten(2)
            if Indexs is None:
                Indexs = torch.randperm(f.shape[-1], device=f.device)
            f = f[:, :, Indexs.to(f.device)]
            f = f.reshape(B, C, H, W)
        else:
            # shuffle rows and columns independently
            if Indexs is None:
                Indexs = [torch.randperm(H, device=f.device),
                          torch.randperm(W, device=f.device)]
            f = f[:, :, Indexs[0].to(f.device)]
            f = f[:, :, :, Indexs[1].to(f.device)]
        Output.append(f)
    return Output


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    def __init__(self, output_size: Union[int, Tuple[int, int]] = 1):
        super(AdaptiveAvgPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        Output = self.forward(Input)
        return Output, 0.0, 0.0


class AdaptiveMaxPool2d(nn.AdaptiveMaxPool2d):
    def __init__(self, output_size: Union[int, Tuple[int, int]] = 1):
        super(AdaptiveMaxPool2d, self).__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        Output = self.forward(Input)
        return Output, 0.0, 0.0


class BaseConv2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: Optional[int] = 1,
            padding: Optional[int] = None,
            groups: Optional[int] = 1,
            bias: Optional[bool] = None,
            BNorm: bool = False,
            ActLayer: Optional[Callable[..., nn.Module]] = None,
            dilation: int = 1,
            Momentum: Optional[float] = 0.1,
            **kwargs: Any
    ) -> None:
        super(BaseConv2d, self).__init__()
        if padding is None:
            padding = int((kernel_size - 1) // 2 * dilation)

        if bias is None:
            bias = not BNorm

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias

        self.Conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, dilation, groups, bias, **kwargs)

        self.Bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=Momentum) if BNorm else nn.Identity()

        if ActLayer is not None:
            # nn.Sigmoid takes no `inplace` argument, so it is constructed without one
            if isinstance(list(ActLayer().named_modules())[0][1], nn.Sigmoid):
                self.Act = ActLayer()
            else:
                self.Act = ActLayer(inplace=True)
        else:
            self.Act = ActLayer

        self.apply(initWeight)

    def forward(self, x: Tensor) -> Tensor:
        x = self.Conv(x)
        x = self.Bn(x)
        if self.Act is not None:
            x = self.Act(x)
        return x


NormLayerTuple = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.SyncBatchNorm,
    nn.LayerNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.GroupNorm,
    nn.BatchNorm3d,
)


def initWeight(Module):
    if Module is None:
        return
    elif isinstance(Module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, NormLayerTuple):
        if Module.weight is not None:
            nn.init.ones_(Module.weight)
        if Module.bias is not None:
            nn.init.zeros_(Module.bias)
    elif isinstance(Module, nn.Linear):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(Module.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(Module.bias, -bound, bound)
    elif isinstance(Module, (nn.Sequential, nn.ModuleList)):
        for m in Module:
            initWeight(m)
    elif list(Module.children()):
        for m in Module.children():
            initWeight(m)


class Attention(nn.Module):
    def __init__(
            self,
            InChannels: int,
            HidChannels: int = None,
            SqueezeFactor: int = 4,
            PoolRes: list = [1, 2, 3],
            Act: Callable[..., nn.Module] = nn.ReLU,
            ScaleAct: Callable[..., nn.Module] = nn.Sigmoid,
            MoCOrder: bool = True,
            **kwargs: Any,
    ) -> None:
        super().__init__()
        if HidChannels is None:
            HidChannels = max(makeDivisible(InChannels // SqueezeFactor, 8), 32)

        AllPoolRes = PoolRes + [1] if 1 not in PoolRes else PoolRes
        for k in AllPoolRes:
            Pooling = AdaptiveAvgPool2d(k)
            setMethod(self, 'Pool%d' % k, Pooling)

        self.SELayer = nn.Sequential(
            BaseConv2d(InChannels, HidChannels, 1, ActLayer=Act),
            BaseConv2d(HidChannels, InChannels, 1, ActLayer=ScaleAct),
        )

        self.PoolRes = PoolRes
        self.MoCOrder = MoCOrder

    def RandomSample(self, x: Tensor) -> Tensor:
        if self.training:
            # training: pick a random pooling resolution, then keep one random position
            PoolKeep = np.random.choice(self.PoolRes)
            x1 = shuffleTensor(x)[0] if self.MoCOrder else x
            AttnMap: Tensor = callMethod(self, 'Pool%d' % PoolKeep)(x1)
            if AttnMap.shape[-1] > 1:
                AttnMap = AttnMap.flatten(2)
                AttnMap = AttnMap[:, :, torch.randperm(AttnMap.shape[-1])[0]]
                AttnMap = AttnMap[:, :, None, None]  # restore the two spatial dims
        else:
            # inference: deterministic global average pooling
            AttnMap: Tensor = callMethod(self, 'Pool%d' % 1)(x)

        return AttnMap

    def forward(self, x: Tensor) -> Tensor:
        AttnMap = self.RandomSample(x)
        return x * self.SELayer(AttnMap)


def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class GLFA(nn.Module):
    def __init__(self, in_channels):
        super(GLFA, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        # four parallel 3x3 branches with dilation 1-4 for multi-scale context
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=1, kernel_size=3, dilation=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=2, kernel_size=3, dilation=2),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=3, kernel_size=3, dilation=3),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.conv_4 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, padding=4, kernel_size=3, dilation=4),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        self.mca = Attention(InChannels=in_channels, HidChannels=16)

    def forward(self, x):
        d = x
        c1 = self.conv_1(x)
        c2 = self.conv_2(x)
        c3 = self.conv_3(x)
        c4 = self.conv_4(x)
        cat = torch.cat([c1, c2, c3, c4], dim=1)
        cat = channel_shuffle(cat, groups=4)
        M = self.fuse(cat)
        O = self.mca(M)
        return O + d
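A self-contained usage sketch for GLFA (not from the upload): the module fuses four dilated branches with a multi-scale channel attention and preserves the input shape.

    import torch
    from model.LaSEA import GLFA

    glfa = GLFA(in_channels=256).eval()  # eval() avoids Attention's random pooling path
    feat = torch.randn(2, 256, 16, 16)
    with torch.no_grad():
        out = glfa(feat)
    assert out.shape == feat.shape       # residual design: output matches input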
model/auxiliary.py
ADDED
@@ -0,0 +1,701 @@
import time
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except ImportError:
    pass
try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
    from selective_scan import selective_scan_ref as selective_scan_ref_v1
except ImportError:
    pass

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0  # below code flops = 0
    # the reference scan code in the strings below never runs; it only documents
    # which operations each FLOP term corresponds to
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """

    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2)  # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
    return flops


class PatchEmbed2D(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)  # to channels-last: [B, H, W, C]
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):  # x: [B, H, W, C]
        B, H, W, C = x.shape
        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning: x.shape {x.shape} is not even.", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)
        x = self.norm(x)
        x = self.reduction(x)
        return x


class PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim * 2
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // self.dim_scale)
        x = self.norm(x)
        return x


class Final_PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale * self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)
        x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // self.dim_scale)
        x = self.norm(x)
        return x


class SS2D(nn.Module):
    def __init__(
            self,
            d_model,
            d_state=16,
            d_conv=3,
            expand=2,
            dt_rank="auto",
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            dropout=0.,
            conv_bias=True,
            bias=False,
            device=None,
            dtype=None,
            **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4, D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4, D, N)

        # self.selective_scan = selective_scan_fn
        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None
        self.ChannelAttentionModule = ChannelAttentionModule(in_channels=self.d_inner)

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # inverse of softplus, so that softplus(bias) recovers dt
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_corev0(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn

        B, C, H, W = x.shape
        L = H * W
        K = 4

        # four scan orders: row-major, column-major, and their reverses
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward_corev1(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn_v1

        B, C, H, W = x.shape
        L = H * W
        K = 4
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        xs = xs.float().view(B, -1, L)
        dts = dts.contiguous().float().view(B, -1, L)
        Bs = Bs.float().view(B, K, -1, L)
        Cs = Cs.float().view(B, K, -1, L)
        Ds = self.Ds.float().view(-1)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, a: torch.Tensor, **kwargs):
        B, H, W, C = a.shape

        xz = self.in_proj(a)
        x, z = xz.chunk(2, dim=-1)
        # gate branch: channel attention on z before the SiLU gating
        z = z.permute(0, 3, 1, 2)
        z = self.ChannelAttentionModule(z) * z
        z = z.permute(0, 2, 3, 1).contiguous()
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * torch.nn.functional.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out + a


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    # channels-last variant: x is [B, H, W, C]
    batch_size, height, width, num_channels = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batch_size, height, width, groups, channels_per_group)
    x = torch.transpose(x, 3, 4).contiguous()
    x = x.view(batch_size, height, width, -1)
    return x


class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels, reduction=4):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SS_Conv_SSM(nn.Module):
    def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,
            **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim // 2)
        self.self_attention = SS2D(d_model=hidden_dim // 2, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

        self.conv33conv33conv11 = nn.Sequential(
            nn.BatchNorm2d(hidden_dim // 2),
            nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=1, stride=1),
            nn.ReLU()
        )
        self.ChannelAttentionModule = ChannelAttentionModule(in_channels=hidden_dim // 2)

    def forward(self, input: torch.Tensor):
        # split channels: the right half goes through SS2D, the left through convs
        input_left, input_right = input.chunk(2, dim=-1)
        input_right = self.ln_1(input_right)
        input_left = self.ln_1(input_left)
        x = self.drop_path(self.self_attention(input_right))
        b0 = input_left.permute(0, 3, 1, 2).contiguous()
        b1 = self.conv33conv33conv11(b0)
        b2 = self.ChannelAttentionModule(b0)
        b1 = b1.permute(0, 2, 3, 1).contiguous()
        b2 = b2.permute(0, 2, 3, 1).contiguous()
        input_left = b1 * b2
        output1 = torch.cat((input_left, x), dim=-1)
        output = channel_shuffle(output1, groups=2)
        return output + input


class VSSLayer(nn.Module):
    """ A basic VSS (Visual State Space) layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            dim,
            depth,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
            d_state=16,
            **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            SS_Conv_SSM(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but it is overridden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight-881-1KESHIHUA QUANZHONG"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))

            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x


class VSSLayer_up(nn.Module):
    """ A basic VSS (Visual State Space) layer for one decoder stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (nn.Module | None, optional): Upsample layer at the start of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            dim,
            depth,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=nn.LayerNorm,
            upsample=None,
            use_checkpoint=False,
            d_state=16,
            **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            SS_Conv_SSM(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        if True:
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight-881-1KESHIHUA QUANZHONG"]:
                        p = p.clone().detach_()
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))

            self.apply(_init_weights)

        if upsample is not None:
            self.upsample = upsample(dim=dim, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        if self.upsample is not None:
            x = self.upsample(x)
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x


class VSSM(nn.Module):
    def __init__(self, patch_size=1, in_chans=3, num_classes=1, depths=[2, 2, 2, 2],
                 dims=[16, 32, 64, 128], d_state=16, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims
        self.layer_outputs = []
        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_backbone(self, x):
        # collect the patch embedding plus every stage output (channels-last)
        self.layer_outputs = []
        x = self.patch_embed(x)
        self.layer_outputs.append(x)

        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
            self.layer_outputs.append(x)
        return self.layer_outputs

    def forward(self, x, i=None):
        outputs = self.forward_backbone(x)
        if i is not None:
            x = outputs[i]
            x = x.permute(0, 3, 1, 2).contiguous()
            return x
        return outputs
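A hedged shape check for the VSSM backbone (not part of the upload). It only runs where mamba_ssm's selective_scan_fn is installed, because the guarded imports above swallow the ImportError and SS2D would otherwise fail with a NameError:

    import torch
    from model.auxiliary import VSSM

    backbone = VSSM(patch_size=1, in_chans=3, dims=[16, 32, 64, 128]).eval()
    with torch.no_grad():
        outs = backbone(torch.randn(1, 3, 256, 256))
    # channels-last outputs: the patch embedding plus one per stage, e.g.
    # (1, 256, 256, 16), (1, 128, 128, 32), (1, 64, 64, 64), (1, 32, 32, 128), (1, 32, 32, 128)
    for o in outs:
        print(tuple(o.shape))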
model/loss.py
ADDED
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage import measure


def SoftIoULoss(pred, target):
    pred = torch.sigmoid(pred)

    smooth = 1

    intersection = pred * target
    intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
    pred_sum = torch.sum(pred, dim=(1, 2, 3))
    target_sum = torch.sum(target, dim=(1, 2, 3))

    loss = (intersection_sum + smooth) / \
           (pred_sum + target_sum - intersection_sum + smooth)

    loss = 1 - loss.mean()

    return loss


def Dice(pred, target, warm_epoch=1, epoch=1, layer=0):
    pred = torch.sigmoid(pred)

    smooth = 1

    intersection = pred * target
    intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
    pred_sum = torch.sum(pred, dim=(1, 2, 3))
    target_sum = torch.sum(target, dim=(1, 2, 3))

    loss = (2 * intersection_sum + smooth) / \
           (pred_sum + target_sum + intersection_sum + smooth)

    loss = 1 - loss.mean()

    return loss


class SLSIoULoss(nn.Module):
    def __init__(self):
        super(SLSIoULoss, self).__init__()

    def forward(self, pred_log, target, warm_epoch, epoch, with_shape=True):
        pred = torch.sigmoid(pred_log)
        smooth = 0.0

        intersection = pred * target

        intersection_sum = torch.sum(intersection, dim=(1, 2, 3))
        pred_sum = torch.sum(pred, dim=(1, 2, 3))
        target_sum = torch.sum(target, dim=(1, 2, 3))

        dis = torch.pow((pred_sum - target_sum) / 2, 2)

        # scale factor that penalizes a mismatch between predicted and target area
        alpha = (torch.min(pred_sum, target_sum) + dis + smooth) / (torch.max(pred_sum, target_sum) + dis + smooth)

        loss = (intersection_sum + smooth) / \
               (pred_sum + target_sum - intersection_sum + smooth)
        lloss = LLoss(pred, target)

        if epoch > warm_epoch:
            siou_loss = alpha * loss
            if with_shape:
                loss = 1 - siou_loss.mean() + lloss
            else:
                loss = 1 - siou_loss.mean()
        else:
            # warm-up phase: plain soft-IoU loss
            loss = 1 - loss.mean()
        return loss


def LLoss(pred, target):
    # location loss: compares the intensity centroids of prediction and target
    loss = torch.tensor(0.0, requires_grad=True).to(pred)

    patch_size = pred.shape[0]
    h = pred.shape[2]
    w = pred.shape[3]
    x_index = torch.arange(0, w, 1).view(1, 1, w).repeat((1, h, 1)).to(pred) / w
    y_index = torch.arange(0, h, 1).view(1, h, 1).repeat((1, 1, w)).to(pred) / h
    smooth = 1e-8
    for i in range(patch_size):
        pred_centerx = (x_index * pred[i]).mean()
        pred_centery = (y_index * pred[i]).mean()

        target_centerx = (x_index * target[i]).mean()
        target_centery = (y_index * target[i]).mean()

        angle_loss = (4 / (torch.pi ** 2)) * (torch.square(torch.arctan((pred_centery) / (pred_centerx + smooth))
                                                           - torch.arctan(
            (target_centery) / (target_centerx + smooth))))

        pred_length = torch.sqrt(pred_centerx * pred_centerx + pred_centery * pred_centery + smooth)
        target_length = torch.sqrt(target_centerx * target_centerx + target_centery * target_centery + smooth)

        length_loss = (torch.min(pred_length, target_length)) / (torch.max(pred_length, target_length) + smooth)

        loss = loss + (1 - length_loss + angle_loss) / patch_size

    return loss


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
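A small usage sketch for SLSIoULoss (not from the upload; the shapes are illustrative). Up to warm_epoch it reduces to a soft-IoU loss; afterwards the area-mismatch scale alpha and the centroid-based LLoss term kick in:

    import torch
    from model.loss import SLSIoULoss

    criterion = SLSIoULoss()
    pred_logits = torch.randn(4, 1, 64, 64, requires_grad=True)
    target = (torch.rand(4, 1, 64, 64) > 0.9).float()

    loss = criterion(pred_logits, target, warm_epoch=5, epoch=10)  # past warm-up: full loss
    loss.backward()
    print(loss.item())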