Upload 6 files
- models/Blocks.py +476 -0
- models/STNR.py +327 -0
- models/__init__.py +13 -0
- models/loss.py +155 -0
- models/mamba_customer.py +569 -0
- models/resnet.py +358 -0
models/Blocks.py
ADDED
@@ -0,0 +1,476 @@
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.layers.utils import get_act_layer
import warnings
warnings.filterwarnings("ignore")
import math
from functools import partial
from typing import Callable
from timm.models.layers import DropPath, to_2tuple
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange, repeat


class CAB(nn.Module):
    """Channel attention block: returns a per-channel sigmoid gate in (0, 1)."""

    def __init__(self, in_channels, out_channels=None, ratio=16, activation='relu'):
        super(CAB, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        if self.in_channels < ratio:
            ratio = self.in_channels
        self.reduced_channels = self.in_channels // ratio
        if self.out_channels is None:
            self.out_channels = in_channels

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.activation = get_act_layer(activation)
        self.fc1 = nn.Conv2d(self.in_channels, self.reduced_channels, 1, bias=False)
        self.fc2 = nn.Conv2d(self.reduced_channels, self.out_channels, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        avg_out = self.fc2(self.activation(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.activation(self.fc1(self.max_pool(x))))
        attention = self.sigmoid(avg_out + max_out)

        return attention


class SAB(nn.Module):
    """Spatial attention block: returns a one-channel sigmoid gate over H x W."""

    def __init__(self, kernel_size=7):
        super(SAB, self).__init__()
        assert kernel_size in (3, 7, 11), "kernel_size must be 3, 7 or 11"
        padding = kernel_size // 2

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)

        x_cat = torch.cat([avg_out, max_out], dim=1)  # shape: [B, 2, H, W]
        attention = self.sigmoid(self.conv(x_cat))

        return attention
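
Both blocks return attention maps rather than gated features, so the caller multiplies them back onto its own tensor (STNR.decode does exactly this). A minimal sketch for SAB; the CAB name above is shadowed by a second CAB class further down this file, see the note there:

sab = SAB(kernel_size=7)
x = torch.randn(2, 32, 64, 64)
gate = sab(x)   # [2, 1, 64, 64], sigmoid values in (0, 1)
y = x * gate    # the spatial gate broadcasts across all 32 channels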

#--------------------------------


class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.
    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        squeeze_channels = max(num_feat // squeeze_factor, 4)  # keep at least 4 squeeze channels
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, squeeze_channels, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeeze_channels, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


# Note: this second CAB shadows the channel-attention CAB defined above; every
# later reference to CAB (including the import in models/STNR.py) resolves to
# this RCAN-style conv block, which returns features, not a gate.
class CAB(nn.Module):
    def __init__(self, num_feat, is_light_sr=False, compress_ratio=3, squeeze_factor=30):
        super(CAB, self).__init__()
        mid_channels = max(num_feat // compress_ratio, 4)  # keep at least 4 channels
        if is_light_sr:  # use a depth-wise conv for light SR, which is more efficient
            self.cab = nn.Sequential(
                nn.Conv2d(num_feat, num_feat, 3, 1, 1, groups=num_feat),
                ChannelAttention(num_feat, squeeze_factor)
            )
        else:  # for classic SR
            self.cab = nn.Sequential(
                nn.Conv2d(num_feat, mid_channels, 3, 1, 1),
                nn.GELU(),
                nn.Conv2d(mid_channels, num_feat, 3, 1, 1),
                ChannelAttention(num_feat, squeeze_factor)
            )

    def forward(self, x):
        return self.cab(x)


class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=3,
        expand=2.,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4, D, N) merged to (K*D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)

        self.selective_scan = selective_scan_fn

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero; mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_core(self, x: torch.Tensor):
        B, C, H, W = x.shape
        L = H * W
        K = 4
        # Four scan directions: row-major, column-major, and both reversed.
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (B, 4, C, L)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        xs = xs.float().view(B, -1, L)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)
        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out
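
SS2D is the VMamba-style 2-D selective scan: it flattens the H x W grid into four 1-D traversals (row-major, column-major, and both reversed), runs one selective scan per direction, re-aligns the outputs, and sums them. A shape sketch, assuming the CUDA selective_scan_fn kernel from mamba_ssm is available:

ss2d = SS2D(d_model=96).cuda()           # d_inner = 192, K = 4 scan directions
x = torch.randn(1, 14, 14, 96).cuda()    # channels-last [B, H, W, C] input
y = ss2d(x)                              # [1, 14, 14, 96]; internally L = 196 tokens per direction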


class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        expand: float = 2.,
        is_light_sr: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(d_model=hidden_dim, d_state=d_state, expand=expand, dropout=attn_drop_rate, **kwargs)
        self.drop_path = DropPath(drop_path)
        self.skip_scale = nn.Parameter(torch.ones(hidden_dim))
        self.conv_blk = CAB(hidden_dim, is_light_sr)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, input, x_size):
        # input: [B, H*W, C] token sequence; x_size = (H, W)
        B, L, C = input.shape
        input = input.view(B, *x_size, C).contiguous()  # [B, H, W, C]
        x = self.ln_1(input)
        x = input * self.skip_scale + self.drop_path(self.self_attention(x))
        x = x * self.skip_scale2 + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
        x = x.view(B, -1, C).contiguous()
        return x
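
VSSBlock consumes and returns a flat token sequence, reshaping to the grid internally, so callers pass x_size alongside the tokens. A minimal sketch (same CUDA-kernel caveat as SS2D):

blk = VSSBlock(hidden_dim=64).cuda()
tokens = torch.randn(2, 32 * 32, 64).cuda()   # [B, H*W, C]
out = blk(tokens, (32, 32))                   # [2, 1024, 64], same layout as the input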

class RoPE(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads

    def forward(self, x_size):
        H, W = x_size
        # Note: picks the default CUDA device when available, regardless of
        # where the calling module's tensors actually live.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pos_h = torch.arange(H, dtype=torch.float32, device=device)
        pos_w = torch.arange(W, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim))
        sin_h = torch.sin(torch.einsum("i,j->ij", pos_h, inv_freq))
        cos_h = torch.cos(torch.einsum("i,j->ij", pos_h, inv_freq))
        sin_w = torch.sin(torch.einsum("i,j->ij", pos_w, inv_freq))
        cos_w = torch.cos(torch.einsum("i,j->ij", pos_w, inv_freq))
        sin = torch.einsum("i,j->ij", sin_h[:, 0], sin_w[:, 0]).unsqueeze(0).unsqueeze(0)
        cos = torch.einsum("i,j->ij", cos_h[:, 0], cos_w[:, 0]).unsqueeze(0).unsqueeze(0)
        sin = sin.expand(self.num_heads, -1, -1, -1).contiguous()
        cos = cos.expand(self.num_heads, -1, -1, -1).contiguous()
        return sin, cos


def rotate_every_two(x):
    # Pad the last dimension to even length so entries can be paired.
    if x.shape[-1] % 2 != 0:
        x = F.pad(x, (0, 1), mode='constant', value=0)
        pad = True
    else:
        pad = False

    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    out = torch.stack((-x2, x1), -1).reshape(*x.shape[:-1], -1)

    return out[..., :x.shape[-1] - 1] if pad else out


def theta_shift(x, sin, cos):
    # Align the rotary tables with x's trailing dimension before applying them.
    if sin.shape[-1] < x.shape[-1]:
        pad = x.shape[-1] - sin.shape[-1]
        sin = F.pad(sin, (0, pad), mode='constant', value=0)
        cos = F.pad(cos, (0, pad), mode='constant', value=1)
    elif sin.shape[-1] > x.shape[-1]:
        sin = sin[..., :x.shape[-1]]
        cos = cos[..., :x.shape[-1]]

    return (x * cos) + (rotate_every_two(x) * sin)
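
rotate_every_two is the pairwise rotation used by rotary embeddings: each pair (x1, x2) becomes (-x2, x1), a 90-degree rotation in its 2-D plane, which theta_shift then mixes with the cos/sin tables as x*cos + rotate(x)*sin. A tiny numeric check:

v = torch.tensor([1., 2., 3., 4.])
rotate_every_two(v)   # tensor([-2., 1., -4., 3.])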

class OverlapWindowAttention(nn.Module):
    def __init__(self, dim, num_heads=4, window_size=7, shift_size=3):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.window_size = window_size
        self.shift_size = shift_size
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x, sin, cos):
        B, C, H, W = x.shape
        ws = self.window_size
        # Pad so that H and W are divisible by the window size.
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        H_pad, W_pad = x.shape[2], x.shape[3]

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))

        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b (m c) h w -> m b c h w', m=3)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q.view(B, self.num_heads, self.head_dim, H_pad, W_pad)
        k = k.view(B, self.num_heads, self.head_dim, H_pad, W_pad)
        v = v.view(B, self.num_heads, self.head_dim, H_pad, W_pad)
        q = theta_shift(q, sin, cos) * self.scale
        k = theta_shift(k, sin, cos)

        q = q.view(B, C, H_pad, W_pad)
        k = k.view(B, C, H_pad, W_pad)
        v = v.view(B, C, H_pad, W_pad)

        # Partition into ws x ws windows.
        q = rearrange(q, 'b c (h ws1) (w ws2) -> b (h w) (ws1 ws2) c', ws1=ws, ws2=ws)
        k = rearrange(k, 'b c (h ws1) (w ws2) -> b (h w) (ws1 ws2) c', ws1=ws, ws2=ws)
        v = rearrange(v, 'b c (h ws1) (w ws2) -> b (h w) (ws1 ws2) c', ws1=ws, ws2=ws)

        B, num_windows, window_len, C_new = q.shape
        assert C_new % self.num_heads == 0, f"C_new={C_new} is not divisible by num_heads={self.num_heads}"
        head_dim_new = C_new // self.num_heads

        q = q.view(B, num_windows, window_len, self.num_heads, head_dim_new).transpose(2, 3)
        k = k.view(B, num_windows, window_len, self.num_heads, head_dim_new).transpose(2, 3)
        v = v.view(B, num_windows, window_len, self.num_heads, head_dim_new).transpose(2, 3)

        attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
        out = (attn @ v).transpose(2, 3).reshape(B, num_windows, window_len, self.num_heads * head_dim_new)
        out = rearrange(out, 'b (h w) (ws1 ws2) c -> b c (h ws1) (w ws2)', h=H_pad // ws, w=W_pad // ws, ws1=ws, ws2=ws)

        if self.shift_size > 0:
            out = torch.roll(out, shifts=(self.shift_size, self.shift_size), dims=(2, 3))

        out = out[:, :, :H, :W]
        out = self.proj(out)
        return out
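
The window logic mirrors Swin-style shifted windows: pad H and W to multiples of window_size, roll by -shift_size, attend inside each ws x ws window, roll back, and crop. A shape sketch (this block runs on CPU too, since RoPE falls back to CPU without CUDA):

attn = OverlapWindowAttention(dim=64, num_heads=4, window_size=7, shift_size=3)
sin, cos = RoPE(embed_dim=64, num_heads=4)((28, 28))   # tables for the padded grid (28 is already divisible by 7)
out = attn(torch.randn(2, 64, 28, 28), sin, cos)       # 16 windows of 49 tokens each -> [2, 64, 28, 28]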

class ShallowFusionAttnBlock(nn.Module):
    def __init__(self, dim, num_heads=4, window_size=7, shift_size=3):
        super().__init__()
        self.dim = dim
        self.attn = OverlapWindowAttention(dim, num_heads=num_heads, window_size=window_size, shift_size=shift_size)
        self.rope = RoPE(embed_dim=dim, num_heads=num_heads)
        self.conv1 = nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(dim * 2, dim, kernel_size=3, padding=1)
        self.vss = VSSBlock(dim)

    def patch_unembed(self, x, h, w):
        return x.transpose(1, 2).reshape(x.size(0), -1, h, w)

    def patch_embed(self, x):
        return x.flatten(2).transpose(1, 2)

    def forward(self, I1, I2, h, w):
        B, C, H, W = I1.shape

        # Difference-guided window attention, shared by both temporal branches.
        diff = torch.abs(I1 - I2)
        H_pad = (self.attn.window_size - h % self.attn.window_size) % self.attn.window_size + h
        W_pad = (self.attn.window_size - w % self.attn.window_size) % self.attn.window_size + w
        sin, cos = self.rope((H_pad, W_pad))
        diff_attn = self.attn(diff, sin, cos)
        token_attn = self.patch_embed(diff_attn)  # [B, N, C]
        I1_token = self.patch_embed(I1)
        I2_token = self.patch_embed(I2)
        I1 = I1_token + token_attn
        I2 = I2_token + token_attn

        I1_un = self.patch_unembed(I1, h, w)
        I2_un = self.patch_unembed(I2, h, w)

        # Cross-branch local mixing with residual connections.
        I1_local = self.conv1(torch.cat([I1_un, I2_un], dim=1)) + I1_un
        I2_local = self.conv2(torch.cat([I2_un, I1_un], dim=1)) + I2_un

        # Shared VSS refinement applied to each branch.
        I1_token = self.patch_embed(I1_local)
        I2_token = self.patch_embed(I2_local)
        vss_feat_1 = self.vss(I1_token, (h, w)).transpose(1, 2).view(B, C, h, w)
        vss_feat_2 = self.vss(I2_token, (h, w)).transpose(1, 2).view(B, C, h, w)

        I1_fuse = I1_local + vss_feat_1
        I2_fuse = I2_local + vss_feat_2
        return I1_fuse, I2_fuse
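
End to end, ShallowFusionAttnBlock takes both temporal feature maps and returns a refined version of each. A smoke-test sketch (dim must be divisible by num_heads; the VSS step needs the mamba_ssm CUDA kernel):

fuse = ShallowFusionAttnBlock(dim=32).cuda()
I1 = torch.randn(1, 32, 64, 64).cuda()
I2 = torch.randn(1, 32, 64, 64).cuda()
F1, F2 = fuse(I1, I2, 64, 64)   # each [1, 32, 64, 64]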
models/STNR.py
ADDED
@@ -0,0 +1,327 @@
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.segresnet_block import get_conv_layer, get_upsample_layer
from monai.networks.layers.factories import Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode
from einops import rearrange
from models.mamba_customer import ConvMamba, M3, PatchEmbed, PatchUnEmbed
from models.Blocks import CAB, SAB, VSSBlock, ShallowFusionAttnBlock
import warnings
warnings.filterwarnings("ignore")

def get_dwconv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
    bias: bool = False
):
    # Depthwise-separable convolution: per-channel spatial conv, then a 1x1 pointwise conv.
    depth_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels,
                             strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True, groups=in_channels)
    point_conv = Convolution(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels,
                             strides=stride, kernel_size=1, bias=bias, conv_only=True, groups=1)
    return torch.nn.Sequential(depth_conv, point_conv)


class SRCMLayer(nn.Module):
    def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2, conv_mode='deepwise'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # The same LayerNorm instance is reused before and after the mamba block.
        self.norm = nn.LayerNorm(input_dim)
        self.convmamba = ConvMamba(
            d_model=input_dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            bimamba_type="v2",
            conv_mode=conv_mode
        )
        self.proj = nn.Linear(input_dim, output_dim)
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.input_dim
        # Flatten all spatial dims into a token sequence: [B, C, *spatial] -> [B, L, C].
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)
        x_mamba = self.convmamba(x_norm) + self.skip_scale * x_flat
        x_mamba = self.norm(x_mamba)
        x_mamba = self.proj(x_mamba)
        out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
        return out


def get_srcm_layer(
    spatial_dims: int, in_channels: int, out_channels: int, stride: int = 1, conv_mode: str = "deepwise"
):
    srcm_layer = SRCMLayer(input_dim=in_channels, output_dim=out_channels, conv_mode=conv_mode)
    if stride != 1:
        if spatial_dims == 2:
            return nn.Sequential(srcm_layer, nn.MaxPool2d(kernel_size=stride, stride=stride))
    return srcm_layer
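
SRCMLayer is agnostic to the spatial layout: it flattens all trailing dims into tokens, applies ConvMamba with a learned residual scale, and restores the grid; downsampling, when requested, is just a trailing max-pool. A shape sketch (requires the mamba_ssm CUDA kernels):

layer = get_srcm_layer(spatial_dims=2, in_channels=32, out_channels=64, stride=2).cuda()
y = layer(torch.randn(1, 32, 56, 56).cuda())   # SRCMLayer -> [1, 64, 56, 56]; MaxPool2d -> [1, 64, 28, 28]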

class SRCMBlock(nn.Module):

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: int = 3,
        conv_mode: str = "deepwise",
        act: tuple | str = ("RELU", {"inplace": True}),
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input channels.
            norm: feature normalization type and arguments.
            kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.
            act: activation type and arguments. Defaults to ``RELU``.
        """

        super().__init__()

        if kernel_size % 2 != 1:
            raise AssertionError("kernel_size should be an odd number.")
        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.act = get_act_layer(act)
        self.conv1 = get_srcm_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
        )
        self.conv2 = get_srcm_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, conv_mode=conv_mode
        )

    def forward(self, x):
        identity = x

        x = self.norm1(x)
        x = self.act(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        x += identity

        return x


class CSI(nn.Module):
    def __init__(self, dim):
        super(CSI, self).__init__()
        self.shallow_fusion_attn = ShallowFusionAttnBlock(dim)
        self.m3 = M3(dim)
        self.vss = VSSBlock(hidden_dim=dim)
        self.patch_embed = PatchEmbed(in_chans=dim, embed_dim=dim)
        self.patch_unembed = PatchUnEmbed(in_chans=dim, embed_dim=dim)

    def forward(self, I1, I2, h, w):
        I1_fuse, I2_fuse = self.shallow_fusion_attn(I1, I2, h, w)
        fusion = torch.abs(I1_fuse - I2_fuse)
        I1_token = self.patch_embed(I1_fuse)
        I2_token = self.patch_embed(I2_fuse)
        fusion_token = self.patch_embed(fusion)
        test_h, test_w = fusion.shape[2], fusion.shape[3]
        fusion_token, _ = self.m3(I1_token, I2_token, fusion_token, test_h, test_w)
        fusion_out = self.patch_unembed(fusion_token, (h, w))
        return fusion_out


class STNR(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 2,
        init_filters: int = 16,
        in_channels: int = 1,
        out_channels: int = 2,
        conv_mode: str = "deepwise",
        local_query_model="orignal_dinner",
        dropout_prob: float | None = None,
        act: tuple | str = ("RELU", {"inplace": True}),
        norm: tuple | str = ("GROUP", {"num_groups": 8}),
        norm_name: str = "",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        mode: str = "",
        up_mode="ResMamba",
        up_conv_mode="deepwise",
        resdiual=False,
        stage=4,
        diff_abs="later",
        mamba_act="silu",
        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,
    ):
        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError("`spatial_dims` can only be 2 or 3.")
        self.mode = mode
        self.stage = stage
        self.up_conv_mode = up_conv_mode
        self.mamba_act = mamba_act
        self.resdiual = resdiual  # spelling kept: "resdiual" matches the config key used elsewhere
        self.up_mode = up_mode
        self.diff_abs = diff_abs
        self.conv_mode = conv_mode
        self.local_query_model = local_query_model
        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        self.channels_list = [self.init_filters, self.init_filters * 2, self.init_filters * 4, self.init_filters * 8]
        self.in_channels = in_channels
        self.blocks_down = blocks_down
        self.blocks_up = blocks_up
        print(self.blocks_up)
        self.dropout_prob = dropout_prob
        self.act = act  # input options
        self.act_mod = get_act_layer(act)
        if norm_name:
            if norm_name.lower() != "group":
                raise ValueError(f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.")
            norm = ("group", {"num_groups": num_groups})
        self.norm = norm
        print(self.norm)
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.srcm_encoder_layers = self._make_srcm_encoder_layers()
        self.srcm_decoder_layers, self.up_samples = self._make_srcm_decoder_layers(up_mode=self.up_mode)
        self.conv_final = self._make_final_conv(out_channels)
        self.fusion_blocks = nn.ModuleList(
            [CSI(self.channels_list[i]) for i in range(self.stage)]
        )
        self.cab_layers = nn.ModuleList([
            CAB(ch) for ch in self.channels_list[::-1][1:]
        ])
        self.sab_layers = nn.ModuleList([
            SAB(kernel_size=7) for _ in range(len(self.blocks_up))
        ])
        self.conv_down_layers = nn.ModuleList([
            nn.Conv2d(ch * 2, ch, kernel_size=1, stride=1, padding=0) for ch in self.channels_list[::-1][1:]
        ])
        if dropout_prob is not None:
            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)

    def _make_srcm_encoder_layers(self):
        srcm_encoder_layers = nn.ModuleList()
        blocks_down, spatial_dims, filters, norm, conv_mode = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm, self.conv_mode)
        for i, item in enumerate(blocks_down):
            layer_in_channels = filters * 2 ** i
            downsample_mamba = (
                get_srcm_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2, conv_mode=conv_mode)
                if i > 0
                else nn.Identity()
            )
            down_layer = nn.Sequential(
                downsample_mamba,
                *[SRCMBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act, conv_mode=conv_mode) for _ in range(item)]
            )
            srcm_encoder_layers.append(down_layer)
        return srcm_encoder_layers

    def _make_srcm_decoder_layers(self, up_mode):
        srcm_decoder_layers, up_samples = nn.ModuleList(), nn.ModuleList()
        upsample_mode, blocks_up, spatial_dims, filters, norm = (
            self.upsample_mode,
            self.blocks_up,
            self.spatial_dims,
            self.init_filters,
            self.norm,
        )
        if up_mode == 'SRCM':
            Block_up = SRCMBlock
        n_up = len(blocks_up)
        for i in range(n_up):
            sample_in_channels = filters * 2 ** (n_up - i)
            srcm_decoder_layers.append(
                nn.Sequential(
                    *[
                        Block_up(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act, conv_mode=self.up_conv_mode)
                        for _ in range(blocks_up[i])
                    ]
                )
            )
            up_samples.append(
                nn.Sequential(
                    *[
                        get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1),
                        get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode),
                    ]
                )
            )
        return srcm_decoder_layers, up_samples

    def _make_final_conv(self, out_channels: int):
        return nn.Sequential(
            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),
            self.act_mod,
            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),
        )

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        x = self.convInit(x)
        if self.dropout_prob is not None:
            x = self.dropout(x)
        down_x = []

        for down in self.srcm_encoder_layers:
            x = down(x)
            down_x.append(x)

        return x, down_x

    def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor:
        for i, (up, upl) in enumerate(zip(self.up_samples, self.srcm_decoder_layers)):
            skip = down_x[i + 1]
            x_up = up(x) + skip
            # CAB and SAB outputs are multiplied back onto the features,
            # then fused with the SRCM branch via a 1x1 conv.
            x_cab = self.cab_layers[i](x_up) * x_up
            x_sab = self.sab_layers[i](x_cab) * x_cab
            x_srcm = upl(x_up)
            combined_out = torch.cat([x_sab, x_srcm], dim=1)
            final_out = self.conv_down_layers[i](combined_out)
            x = final_out
        if self.use_conv_final:
            x = self.conv_final(x)
        return x

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x1.shape
        x1, down_x1 = self.encode(x1)
        x2, down_x2 = self.encode(x2)
        down_x = []
        for i in range(len(down_x1)):
            x1_level, x2_level = down_x1[i], down_x2[i]
            H_i, W_i = x1_level.shape[2], x1_level.shape[3]
            if self.diff_abs == "later":  # only this policy is implemented; other values leave `fusion` unset
                if self.mode == "FUSION":
                    if i < self.stage:
                        fusion = self.fusion_blocks[i](x1_level, x2_level, H_i, W_i)
                    else:
                        fusion = torch.abs(x1_level - x2_level)
                else:
                    fusion = torch.abs(x1_level - x2_level)
            down_x.append(fusion)
        down_x.reverse()
        x = self.decode(down_x[0], down_x)
        return x


if __name__ == "__main__":
    device = "cuda:0"
    CDMamba = STNR(spatial_dims=2, in_channels=3, out_channels=2, init_filters=16, norm=("GROUP", {"num_groups": 8}),
                   mode="FUSION", conv_mode='orignal', local_query_model="orignal_dinner",
                   stage=4, mamba_act="silu", up_mode="SRCM", up_conv_mode='deepwise', blocks_down=(1, 2, 2, 4), blocks_up=(1, 1, 1),
                   resdiual=False, diff_abs="later").to(device)
    x = torch.randn(1, 3, 256, 256).to(device)
    y = CDMamba(x, x)
    print(y.shape)
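
For orientation, the channel/resolution flow with the defaults above (a derived sketch, not program output):

# encoder skips for a 256x256 input: 16@256, 32@128, 64@64, 128@32  (channels @ spatial size)
# after reverse, decode starts from the deepest fused map and walks
# 128@32 -> 64@64 -> 32@128 -> 16@256; conv_final maps 16 -> out_channels,
# so the __main__ smoke test prints torch.Size([1, 2, 256, 256]).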
models/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .resnet import *
import logging
logger = logging.getLogger('base')

def create_CD_model(opt):
    # Our CDMamba model
    from models.STNR import STNR as stnr

    if opt['model']['name'] == 'STNR':
        cd_model = stnr(spatial_dims=opt['model']['spatial_dims'], in_channels=opt['model']['in_channels'], init_filters=opt['model']['init_filters'], out_channels=opt['model']['n_classes'],
                        mode=opt['model']['mode'], conv_mode=opt['model']['conv_mode'], up_mode=opt['model']['up_mode'], up_conv_mode=opt['model']['up_conv_mode'], norm=opt['model']['norm'],
                        blocks_down=opt['model']['blocks_down'], blocks_up=opt['model']['blocks_up'], resdiual=opt['model']['resdiual'], diff_abs=opt['model']['diff_abs'], stage=opt['model']['stage'],
                        mamba_act=opt['model']['mamba_act'], local_query_model=opt['model']['local_query_model'])
    return cd_model
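
create_CD_model expects a nested opt dict. A hypothetical minimal config for illustration — the values mirror the __main__ smoke test in models/STNR.py, and only the keys read above are shown:

opt = {'model': {
    'name': 'STNR', 'spatial_dims': 2, 'in_channels': 3, 'init_filters': 16, 'n_classes': 2,
    'mode': 'FUSION', 'conv_mode': 'orignal', 'up_mode': 'SRCM', 'up_conv_mode': 'deepwise',
    'norm': ('GROUP', {'num_groups': 8}), 'blocks_down': (1, 2, 2, 4), 'blocks_up': (1, 1, 1),
    'resdiual': False, 'diff_abs': 'later', 'stage': 4, 'mamba_act': 'silu',
    'local_query_model': 'orignal_dinner',
}}
model = create_CD_model(opt)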
models/loss.py
ADDED
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
from torch import Tensor, einsum
import torch.nn.functional as F
from misc.torchutils import class2one_hot, simplex
from models.darnet_help.loss_help import FocalLoss, dernet_dice_loss

def cross_entropy(input, target, weight=None, reduction='mean', ignore_index=255):
    """
    logSoftmax_with_loss
    :param input: torch.Tensor, N*C*H*W
    :param target: torch.Tensor, N*1*H*W / N*H*W
    :param weight: torch.Tensor, C
    :return: torch.Tensor [0]
    """
    target = target.long()
    if target.dim() == 4:
        target = torch.squeeze(target, dim=1)
    if input.shape[-1] != target.shape[-1]:
        input = F.interpolate(input, size=target.shape[1:], mode='bilinear', align_corners=True)

    return F.cross_entropy(input=input, target=target, weight=weight,
                           ignore_index=ignore_index, reduction=reduction)


def dice_loss(predicts, target, weight=None):
    idc = [0, 1]
    probs = torch.softmax(predicts, dim=1)
    # target = target.unsqueeze(1)
    target = class2one_hot(target, 7)  # note: one-hot depth is hard-coded to 7 classes
    assert simplex(probs) and simplex(target)

    pc = probs[:, idc, ...].type(torch.float32)
    tc = target[:, idc, ...].type(torch.float32)
    intersection: Tensor = einsum("bcwh,bcwh->bc", pc, tc)
    union: Tensor = (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc))

    divided: Tensor = torch.ones_like(intersection) - (2 * intersection + 1e-10) / (union + 1e-10)

    loss = divided.mean()
    return loss
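
Per selected class, dice_loss computes 1 - (2|P∩T| + eps) / (|P| + |T| + eps), averaged over batch and classes. A tiny numeric check on a perfectly matching map:

pc = torch.tensor([[[[1., 0.], [0., 1.]]]])               # [B=1, C=1, 2, 2]
tc = pc.clone()
inter = einsum("bcwh,bcwh->bc", pc, tc)                   # tensor([[2.]])
union = einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc)   # tensor([[4.]])
loss = 1 - (2 * inter + 1e-10) / (union + 1e-10)          # ~0: perfect overlap gives zero loss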
+
|
| 43 |
+
def ce_dice(input, target, weight=None):
|
| 44 |
+
ce_loss = cross_entropy(input, target)
|
| 45 |
+
dice_loss_ = dice_loss(input, target)
|
| 46 |
+
loss = 0.5 * ce_loss + 0.5 * dice_loss_
|
| 47 |
+
return loss
|
| 48 |
+
|
| 49 |
+
def dice(input, target, weight=None):
|
| 50 |
+
dice_loss_ = dice_loss(input, target)
|
| 51 |
+
return dice_loss_
|
| 52 |
+
|
| 53 |
+
def ce2_dice1(input, target, weight=None):
|
| 54 |
+
ce_loss = cross_entropy(input, target)
|
| 55 |
+
dice_loss_ = dice_loss(input, target)
|
| 56 |
+
loss = ce_loss + 0.5 * dice_loss_
|
| 57 |
+
return loss
|
| 58 |
+
|
| 59 |
+
def ce1_dice2(input, target, weight=None):
|
| 60 |
+
ce_loss = cross_entropy(input, target)
|
| 61 |
+
dice_loss_ = dice_loss(input, target)
|
| 62 |
+
loss = 0.5 * ce_loss + dice_loss_
|
| 63 |
+
return loss
|
| 64 |
+
|
| 65 |
+
def ce_scl(input, target, weight=None):
|
| 66 |
+
ce_loss = cross_entropy(input, target)
|
| 67 |
+
dice_loss_ = dice_loss(input, target)
|
| 68 |
+
loss = 0.5 * ce_loss + 0.5 * dice_loss_
|
| 69 |
+
return loss
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75):
|
| 73 |
+
logit = logit_pixel.view(-1)
|
| 74 |
+
truth = truth_pixel.view(-1)
|
| 75 |
+
assert (logit.shape == truth.shape)
|
| 76 |
+
|
| 77 |
+
loss = F.binary_cross_entropy_with_logits(logit.float(), truth.float(), reduction='none')
|
| 78 |
+
|
| 79 |
+
pos = (truth > 0.5).float()
|
| 80 |
+
neg = (truth < 0.5).float()
|
| 81 |
+
pos_num = pos.sum().item() + 1e-12
|
| 82 |
+
neg_num = neg.sum().item() + 1e-12
|
| 83 |
+
loss = (weight_pos * pos * loss / pos_num + weight_neg * neg * loss / neg_num).sum()
|
| 84 |
+
|
| 85 |
+
return loss
|
| 86 |
+
|
| 87 |
+
class ChangeSimilarity(nn.Module):
|
| 88 |
+
"""input: x1, x2 multi-class predictions, c = class_num
|
| 89 |
+
label_change: changed part
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, reduction='mean'):
|
| 93 |
+
super(ChangeSimilarity, self).__init__()
|
| 94 |
+
self.loss_f = nn.CosineEmbeddingLoss(margin=0., reduction=reduction)
|
| 95 |
+
|
| 96 |
+
def forward(self, x1, x2, label_change):
|
| 97 |
+
b, c, h, w = x1.size()
|
| 98 |
+
x1 = F.softmax(x1, dim=1)
|
| 99 |
+
x2 = F.softmax(x2, dim=1)
|
| 100 |
+
x1 = x1.permute(0, 2, 3, 1)
|
| 101 |
+
x2 = x2.permute(0, 2, 3, 1)
|
| 102 |
+
x1 = torch.reshape(x1, [b * h * w, c])
|
| 103 |
+
x2 = torch.reshape(x2, [b * h * w, c])
|
| 104 |
+
|
| 105 |
+
label_unchange = ~label_change.bool()
|
| 106 |
+
target = label_unchange.float()
|
| 107 |
+
target = target - label_change.float()
|
| 108 |
+
target = torch.reshape(target, [b * h * w])
|
| 109 |
+
|
| 110 |
+
loss = self.loss_f(x1, x2, target)
|
| 111 |
+
return loss
|
| 112 |
+
|
| 113 |
+
def hybrid_loss(predictions, target, weight=[0,2,0.2,0.2,0.2,0.2]):
|
| 114 |
+
"""Calculating the loss"""
|
| 115 |
+
loss = 0
|
| 116 |
+
|
| 117 |
+
# gamma=0, alpha=None --> CE
|
| 118 |
+
# focal = FocalLoss(gamma=0, alpha=None)
|
| 119 |
+
# ssim = SSIM()
|
| 120 |
+
|
| 121 |
+
for i,prediction in enumerate(predictions):
|
| 122 |
+
|
| 123 |
+
bce = cross_entropy(prediction, target)
|
| 124 |
+
dice = dice_loss(prediction, target)
|
| 125 |
+
# ssimloss = ssim(prediction, target)
|
| 126 |
+
loss += weight[i]*(bce + dice) #- ssimloss
|
| 127 |
+
|
| 128 |
+
return loss
|
| 129 |
+
|
| 130 |
+
class BCL(nn.Module):
|
| 131 |
+
"""
|
| 132 |
+
batch-balanced contrastive loss
|
| 133 |
+
no-change,1
|
| 134 |
+
change,-1
|
| 135 |
+
"""
|
| 136 |
+
def __init__(self, margin=2.0):
|
| 137 |
+
super(BCL, self).__init__()
|
| 138 |
+
self.margin = margin
|
| 139 |
+
|
| 140 |
+
def forward(self, distance, label):
|
| 141 |
+
label[label == 1] = -1
|
| 142 |
+
label[label == 0] = 1
|
| 143 |
+
|
| 144 |
+
mask = (label != 255).float()
|
| 145 |
+
distance = distance * mask
|
| 146 |
+
|
| 147 |
+
pos_num = torch.sum((label==1).float())+0.0001
|
| 148 |
+
neg_num = torch.sum((label==-1).float())+0.0001
|
| 149 |
+
|
| 150 |
+
loss_1 = torch.sum((1+label) / 2 * torch.pow(distance, 2)) /pos_num
|
| 151 |
+
loss_2 = torch.sum((1-label) / 2 *
|
| 152 |
+
torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
|
| 153 |
+
) / neg_num
|
| 154 |
+
loss = loss_1 + loss_2
|
| 155 |
+
return loss
|
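
BCL is a margin contrastive loss over a distance map: unchanged pixels (remapped to +1) are pulled toward zero distance, changed pixels (-1) are pushed past the margin, and 255-valued pixels are masked out. A numeric sketch with margin 2:

bcl = BCL(margin=2.0)
distance = torch.tensor([[0.5, 1.0]])
label = torch.tensor([[0., 1.]])   # 0 = no change, 1 = change (remapped in place)
bcl(distance, label)               # 0.5**2 / pos_num + (2.0 - 1.0)**2 / neg_num ≈ 1.25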
models/mamba_customer.py
ADDED
@@ -0,0 +1,569 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.
import numbers
from mamba_ssm.modules.mamba_simple import Mamba
import warnings
warnings.filterwarnings("ignore")

from timm.models.layers import DropPath, to_2tuple

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
except ImportError:
    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

class LightweightModel(nn.Module):
    """Depthwise-separable 3x3 conv block: per-channel conv, then 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels):
        super(LightweightModel, self).__init__()
        self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels)
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x

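LightweightModel is the standard depthwise-separable factorization: versus a dense k x k conv, parameters drop by roughly a factor of k²·C_out / (k² + C_out). Concretely:

dense = nn.Conv2d(64, 128, 3, padding=1)    # 64*128*9 + 128 = 73,856 parameters
dw = LightweightModel(64, 128)              # (64*9 + 64) + (64*128 + 128) = 8,960 parameters
sum(p.numel() for p in dw.parameters())     # 8960, about 8x fewer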
| 50 |
+
|
| 51 |
+
class ConvMamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
        bimamba_type="none",
        conv_mode="deepwise",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.conv_mode = conv_mode
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.bimamba_type = bimamba_type

        # Local (convolutional) branch. The mode strings below (including the
        # "orignal"/"deepwise" spellings) are kept as-is for config compatibility.
        if self.conv_mode == "orignal":
            self.local_relation = nn.Sequential(
                nn.Conv2d(in_channels=self.d_model, out_channels=self.d_model, kernel_size=3, stride=1, padding=1),
                nn.SiLU(),
                nn.Conv2d(in_channels=self.d_model, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
            )
        elif self.conv_mode == "orignal_1_5_dmodel":
            self.local_relation = nn.Sequential(
                nn.Conv2d(in_channels=self.d_model, out_channels=int(1.5 * self.d_model), kernel_size=3, stride=1, padding=1),
                nn.SiLU(),
                nn.Conv2d(in_channels=int(1.5 * self.d_model), out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
            )
        elif self.conv_mode == "orignal_dinner":
            self.local_relation = nn.Sequential(
                nn.Conv2d(in_channels=self.d_model, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
                nn.SiLU(),
                nn.Conv2d(in_channels=self.d_inner, out_channels=self.d_inner, kernel_size=3, stride=1, padding=1),
            )
        elif self.conv_mode == "deepwise":
            self.local_relation = nn.Sequential(
                LightweightModel(in_channels=self.d_model, out_channels=self.d_model),
                nn.SiLU(),
                LightweightModel(in_channels=self.d_model, out_channels=self.d_inner),
            )
        elif self.conv_mode == "deepwise_dinner":
            self.local_relation = nn.Sequential(
                LightweightModel(in_channels=self.d_model, out_channels=self.d_inner),
                nn.SiLU(),
                LightweightModel(in_channels=self.d_inner, out_channels=self.d_inner),
            )

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        # Bidirectional branch: this custom module only supports bimamba v2.
        assert bimamba_type == "v2"

        A_b = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_b_log = torch.log(A_b)  # Keep A_b_log in fp32
        self.A_b_log = nn.Parameter(A_b_log)
        self.A_b_log._no_weight_decay = True

        self.conv1d_b = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.x_proj_b = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D_b._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

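    # Sanity check for the inverse-softplus initialization above (illustrative,
    # not part of the original file): softplus(inv_dt) recovers dt, so the
    # effective step sizes start out log-uniform in [dt_min, dt_max].
    #
    #   >>> dt = torch.tensor([1e-3, 1e-2, 1e-1])
    #   >>> inv_dt = dt + torch.log(-torch.expm1(-dt))  # softplus^{-1}(dt)
    #   >>> torch.allclose(F.softplus(inv_dt), dt)
    #   True
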
    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape
        # Assumes the sequence is a flattened square feature map (L == h * w with h == w).
        h = int(math.sqrt(seqlen))

        local_relation = self.local_relation(rearrange(hidden_states, "b (h w) d -> b d h w", h=h))
        local_relation = rearrange(local_relation, "b d h w -> b d (h w)")

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            if self.bimamba_type == "v2":
                A_b = -torch.exp(self.A_b_log.float())
                out = mamba_inner_fn_no_out_proj(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
                # Backward direction: run the scan on the time-reversed sequence.
                out_b = mamba_inner_fn_no_out_proj(
                    xz.flip([-1]),
                    self.conv1d_b.weight,
                    self.conv1d_b.bias,
                    self.x_proj_b.weight,
                    self.dt_proj_b.weight,
                    A_b,
                    None,
                    None,
                    self.D_b.float(),
                    delta_bias=self.dt_proj_b.bias.float(),
                    delta_softplus=True,
                )
                # Fuse the forward scan, the re-flipped backward scan, and the local
                # convolutional branch, then apply the shared output projection.
                out = F.linear(rearrange(out + out_b.flip([-1]) + local_relation, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
            else:
                out = mamba_inner_fn(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
        else:
            # NOTE: the local convolutional branch is only fused in the
            # bimamba-v2 fast path above; this fallback is the plain Mamba path.
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

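    # Usage sketch (illustrative, not part of the original file; shapes are
    # assumptions). The fast path dispatches to fused CUDA kernels, so a GPU
    # is required:
    #
    #   >>> m = ConvMamba(d_model=96, bimamba_type="v2").cuda()
    #   >>> tokens = torch.randn(2, 16 * 16, 96, device="cuda")  # flattened 16x16 map
    #   >>> m(tokens).shape
    #   torch.Size([2, 256, 96])
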
    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B, then advance the recurrence
            # h <- exp(dt * A) * h + (dt * B) * x and read out y = <C, h>.
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

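    # For reference, the fallback above implements the standard Mamba
    # discretization of the continuous SSM (a sketch, matching the einsums in
    # `step`; not part of the original file):
    #
    #     dt  = softplus(dt + dt_bias)                 # (B, d_inner)
    #     dA  = exp(dt * A)                            # (B, d_inner, d_state)
    #     dB  = dt * B                                 # simplified Euler form
    #     h_t = dA * h_{t-1} + dB * x_t                # recurrent state update
    #     y_t = <C, h_t> + D * x_t, gated by SiLU(z)   # readout
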
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if the batch size changes between generations and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state

def to_3d(x):
    # [B, C, H, W] -> [B, HW, C]
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x, h, w):
    # [B, HW, C] -> [B, C, H, W]
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

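# Shape sketch (illustrative, not part of the original file): to_3d/to_4d
# round-trip a feature map exactly.
#
#   >>> x = torch.randn(1, 8, 4, 4)
#   >>> torch.equal(to_4d(to_3d(x), 4, 4), x)
#   True
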
class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        # No mean subtraction and no bias: scale-only normalization.
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class LayerNorm(nn.Module):
    def __init__(self, dim, norm_type='with_bias'):
        super(LayerNorm, self).__init__()
        if norm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        if len(x.shape) == 4:
            h, w = x.shape[-2:]
            return to_4d(self.body(to_3d(x)), h, w)
        else:
            return self.body(x)

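# Quick equivalence check (illustrative, not part of the original file):
# WithBias_LayerNorm matches nn.LayerNorm for matching eps.
#
#   >>> x = torch.randn(2, 16, 8, 8)
#   >>> ln = LayerNorm(16, 'with_bias')
#   >>> ref = nn.LayerNorm(16, eps=1e-5)
#   >>> torch.allclose(ln(x), to_4d(ref(to_3d(x)), 8, 8), atol=1e-5)
#   True
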
class M3(nn.Module):
    def __init__(self, dim):
        super(M3, self).__init__()
        self.multi_modal_mamba_block = Mamba(dim, bimamba_type="m3")
        self.norm1 = LayerNorm(dim, 'with_bias')  # fusion
        self.norm2 = LayerNorm(dim, 'with_bias')  # I2
        self.norm3 = LayerNorm(dim, 'with_bias')  # I1
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, I1, I2, fusion, test_h, test_w):
        fusion = self.norm1(fusion)
        I2 = self.norm2(I2)
        I1 = self.norm3(I1)
        global_f = self.multi_modal_mamba_block(fusion, extra_emb1=I2, extra_emb2=I1)  # [B, HW, C]
        B, HW, C = global_f.shape
        fusion = global_f.transpose(1, 2).view(B, C, test_h, test_w)
        # Depthwise-conv residual refinement, then back to [B, HW, C].
        fusion = (self.dwconv(fusion) + fusion).flatten(2).transpose(1, 2)
        return fusion, None

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super(PatchEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        # x: [B, C, H, W] -> [B, N, C]
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x

class PatchUnEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super(PatchUnEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        # [B, HW, C] -> [B, C, H, W]
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])
        return x

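# Round-trip sketch (illustrative, not part of the original file):
# PatchUnEmbed inverts PatchEmbed for an already-embedded feature map.
#
#   >>> emb, unemb = PatchEmbed(embed_dim=96), PatchUnEmbed(embed_dim=96)
#   >>> x = torch.randn(1, 96, 14, 14)
#   >>> torch.equal(unemb(emb(x), (14, 14)), x)
#   True
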
class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
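# Stacking sketch (illustrative; pairing Block with ConvMamba and the use of
# functools.partial are assumptions, not from the original file). The fast
# path needs a CUDA device.
#
#   >>> from functools import partial
#   >>> blk = Block(dim=96, mixer_cls=partial(ConvMamba, bimamba_type="v2")).cuda()
#   >>> hs, res = blk(torch.randn(2, 256, 96, device="cuda"))
#   >>> hs.shape, res.shape
#   (torch.Size([2, 256, 96]), torch.Size([2, 256, 96]))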
models/resnet.py
ADDED
@@ -0,0 +1,358 @@
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

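# Note (illustrative, not part of the original file): because padding equals
# dilation, conv3x3 preserves spatial size for any dilation at stride 1.
#
#   >>> conv3x3(8, 8, dilation=2)(torch.randn(1, 8, 16, 16)).shape
#   torch.Size([1, 8, 16, 16])
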
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            # Unlike torchvision (which raises NotImplementedError here),
            # silently fall back to dilation=1 in BasicBlock.
            dilation = 1
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

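# Shape sketch (illustrative, not part of the original file): a strided
# Bottleneck needs a matching downsample branch on the identity path.
#
#   >>> ds = nn.Sequential(conv1x1(64, 256, stride=2), nn.BatchNorm2d(256))
#   >>> blk = Bottleneck(64, 64, stride=2, downsample=ds)
#   >>> blk(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 256, 28, 28])
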
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, strides=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Per-stage strides: [stem conv, maxpool, layer2, layer3, layer4].
        # Defaults match the standard ResNet; pass a custom list to change the
        # overall downsampling factor.
        self.strides = strides
        if self.strides is None:
            self.strides = [2, 2, 2, 2, 2]

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2],
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3],
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4],
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

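# Usage sketch (illustrative, not part of the original file): the extra
# `strides` argument added to this ResNet makes the total downsampling
# configurable, e.g. keeping stem resolution for small inputs. AdaptiveAvgPool
# absorbs the resulting spatial-size change.
#
#   >>> backbone = resnet18(pretrained=False, strides=[1, 1, 2, 2, 2])
#   >>> backbone(torch.randn(1, 3, 64, 64)).shape  # logits
#   torch.Size([1, 1000])
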
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)