HaoFeng2019 committed on
Commit
ae3b630
1 Parent(s): f4ce0ac

Upload 7 files

Files changed (7)
  1. GeoTr.py +233 -0
  2. IllTr.py +284 -0
  3. demo.py +178 -0
  4. extractor.py +115 -0
  5. position_encoding.py +111 -0
  6. requirements.txt +7 -0
  7. seg.py +567 -0
GeoTr.py ADDED
@@ -0,0 +1,233 @@
+ from extractor import BasicEncoder
+ from position_encoding import build_position_encoding
+
+ import argparse
+ import numpy as np
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ import copy
+ from typing import Optional
+
+
+ class attnLayer(nn.Module):
+     def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1,
+                  activation="relu", normalize_before=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn_list = nn.ModuleList([copy.deepcopy(nn.MultiheadAttention(d_model, nhead, dropout=dropout)) for i in range(2)])
+         # Implementation of the feed-forward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2_list = nn.ModuleList([copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)])
+
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2_list = nn.ModuleList([copy.deepcopy(nn.Dropout(dropout)) for i in range(2)])
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(self, tgt, memory_list, tgt_mask=None, memory_mask=None,
+                      tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                      pos=None, memory_pos=None):
+         q = k = self.with_pos_embed(tgt, pos)
+         tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+                               key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         for memory, multihead_attn, norm2, dropout2, m_pos in zip(memory_list, self.multihead_attn_list, self.norm2_list, self.dropout2_list, memory_pos):
+             tgt2 = multihead_attn(query=self.with_pos_embed(tgt, pos),
+                                   key=self.with_pos_embed(memory, m_pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+             tgt = tgt + dropout2(tgt2)
+             tgt = norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward_pre(self, tgt, memory_list, tgt_mask=None, memory_mask=None,
+                     tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                     pos=None, memory_pos=None):
+         # pre-norm variant; iterates over memory_list like forward_post
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, pos)
+         tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                               key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         for memory, multihead_attn, norm2, dropout2, m_pos in zip(memory_list, self.multihead_attn_list, self.norm2_list, self.dropout2_list, memory_pos):
+             tgt2 = norm2(tgt)
+             tgt2 = multihead_attn(query=self.with_pos_embed(tgt2, pos),
+                                   key=self.with_pos_embed(memory, m_pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+             tgt = tgt + dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt
+
+     def forward(self, tgt, memory_list, tgt_mask=None, memory_mask=None,
+                 tgt_key_padding_mask=None, memory_key_padding_mask=None,
+                 pos=None, memory_pos=None):
+         if self.normalize_before:
+             return self.forward_pre(tgt, memory_list, tgt_mask, memory_mask,
+                                     tgt_key_padding_mask, memory_key_padding_mask, pos, memory_pos)
+         return self.forward_post(tgt, memory_list, tgt_mask, memory_mask,
+                                  tgt_key_padding_mask, memory_key_padding_mask, pos, memory_pos)
+
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+ class TransDecoder(nn.Module):
+     def __init__(self, num_attn_layers, hidden_dim=128):
+         super(TransDecoder, self).__init__()
+         attn_layer = attnLayer(hidden_dim)
+         self.layers = _get_clones(attn_layer, num_attn_layers)
+         self.position_embedding = build_position_encoding(hidden_dim)
+
+     def forward(self, imgf, query_embed):
+         pos = self.position_embedding(torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda())  # (bs, hidden_dim, h, w)
+
+         bs, c, h, w = imgf.shape
+         imgf = imgf.flatten(2).permute(2, 0, 1)
+         query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+         pos = pos.flatten(2).permute(2, 0, 1)
+
+         for layer in self.layers:
+             query_embed = layer(query_embed, [imgf], pos=pos, memory_pos=[pos, pos])
+         query_embed = query_embed.permute(1, 2, 0).reshape(bs, c, h, w)
+
+         return query_embed
+
+
+ class TransEncoder(nn.Module):
+     def __init__(self, num_attn_layers, hidden_dim=128):
+         super(TransEncoder, self).__init__()
+         attn_layer = attnLayer(hidden_dim)
+         self.layers = _get_clones(attn_layer, num_attn_layers)
+         self.position_embedding = build_position_encoding(hidden_dim)
+
+     def forward(self, imgf):
+         pos = self.position_embedding(torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda())  # (bs, hidden_dim, h, w)
+         bs, c, h, w = imgf.shape
+         imgf = imgf.flatten(2).permute(2, 0, 1)
+         pos = pos.flatten(2).permute(2, 0, 1)
+
+         for layer in self.layers:
+             imgf = layer(imgf, [imgf], pos=pos, memory_pos=[pos, pos])
+         imgf = imgf.permute(1, 2, 0).reshape(bs, c, h, w)
+
+         return imgf
+
+
+ class FlowHead(nn.Module):
+     def __init__(self, input_dim=128, hidden_dim=256):
+         super(FlowHead, self).__init__()
+         self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+         self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         return self.conv2(self.relu(self.conv1(x)))
+
+
+ class UpdateBlock(nn.Module):
+     def __init__(self, hidden_dim=128):
+         super(UpdateBlock, self).__init__()
+         self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+         self.mask = nn.Sequential(
+             nn.Conv2d(hidden_dim, 256, 3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(256, 64 * 9, 1, padding=0))
+
+     def forward(self, imgf, coords1):
+         mask = .25 * self.mask(imgf)  # scale mask to balance gradients
+         dflow = self.flow_head(imgf)
+         coords1 = coords1 + dflow
+
+         return mask, coords1
+
+
+ def coords_grid(batch, ht, wd):
+     coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+     coords = torch.stack(coords[::-1], dim=0).float()
+     return coords[None].repeat(batch, 1, 1, 1)
+
+
+ def upflow8(flow, mode='bilinear'):
+     new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+     return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+
+ class GeoTr(nn.Module):
+     def __init__(self, num_attn_layers):
+         super(GeoTr, self).__init__()
+         self.num_attn_layers = num_attn_layers
+
+         self.hidden_dim = hdim = 256
+
+         self.fnet = BasicEncoder(output_dim=hdim, norm_fn='instance')
+
+         self.TransEncoder = TransEncoder(self.num_attn_layers, hidden_dim=hdim)
+         self.TransDecoder = TransDecoder(self.num_attn_layers, hidden_dim=hdim)
+         self.query_embed = nn.Embedding(1296, self.hidden_dim)  # 1296 = 36 * 36 feature positions
+
+         self.update_block = UpdateBlock(self.hidden_dim)
+
+     def initialize_flow(self, img):
+         N, C, H, W = img.shape
+         coodslar = coords_grid(N, H, W).to(img.device)
+         coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
+         coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
+
+         return coodslar, coords0, coords1
+
+     def upsample_flow(self, flow, mask):
+         N, _, H, W = flow.shape
+         mask = mask.view(N, 1, 9, 8, 8, H, W)
+         mask = torch.softmax(mask, dim=2)
+
+         up_flow = F.unfold(8 * flow, [3, 3], padding=1)
+         up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+         up_flow = torch.sum(mask * up_flow, dim=2)
+         up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+
+         return up_flow.reshape(N, 2, 8 * H, 8 * W)
+
+     def forward(self, image1):
+         fmap = self.fnet(image1)
+         fmap = torch.relu(fmap)
+
+         fmap = self.TransEncoder(fmap)
+         fmap = self.TransDecoder(fmap, self.query_embed.weight)
+
+         # convex upsampling based on fmap
+         coodslar, coords0, coords1 = self.initialize_flow(image1)
+         coords1 = coords1.detach()
+         mask, coords1 = self.update_block(fmap, coords1)
+         flow_up = self.upsample_flow(coords1 - coords0, mask)
+         bm_up = coodslar + flow_up
+
+         return bm_up
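A minimal usage sketch for GeoTr, assuming a CUDA device (the position encodings above are built with .cuda()) and a 288x288 input, which the encoder downsamples by 8 to the 36x36 grid that matches the 1296-entry query embedding:

import torch
from GeoTr import GeoTr

model = GeoTr(num_attn_layers=6).cuda().eval()
with torch.no_grad():
    img = torch.rand(1, 3, 288, 288).cuda()  # distorted document image
    bm = model(img)                          # backward map of absolute pixel coords
print(bm.shape)  # torch.Size([1, 2, 288, 288])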
IllTr.py ADDED
@@ -0,0 +1,284 @@
+ import torch
+ import torch.nn as nn
+ from torch.functional import Tensor
+ from torch.nn.modules.activation import Tanhshrink
+ from timm.models.layers import trunc_normal_
+ from functools import partial
+
+
+ class Ffn(nn.Module):
+     # feed-forward network layer applied after attention
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, task_embed=None, level=0):
+         N, L, D = x.shape
+         qkv = self.qkv(x).reshape(N, L, 3, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         # for the decoder's task embedding at different levels of attention layers
+         if task_embed is not None:
+             _N, _H, _L, _D = q.shape
+             task_embed = task_embed.reshape(1, _H, _L, _D)
+             if level == 1:
+                 q += task_embed
+                 k += task_embed
+             if level == 2:
+                 q += task_embed
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(N, L, D)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, dim, num_heads, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         self.norm2 = norm_layer(dim)
+         ffn_hidden_dim = int(dim * ffn_ratio)
+         self.ffn = Ffn(in_features=dim, hidden_features=ffn_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.ffn(self.norm2(x))
+         return x
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, dim, num_heads, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn1 = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         self.norm2 = norm_layer(dim)
+         self.attn2 = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         self.norm3 = norm_layer(dim)
+         ffn_hidden_dim = int(dim * ffn_ratio)
+         self.ffn = Ffn(in_features=dim, hidden_features=ffn_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x, task_embed):
+         x = x + self.attn1(self.norm1(x), task_embed=task_embed, level=1)
+         x = x + self.attn2(self.norm2(x), task_embed=task_embed, level=2)
+         x = x + self.ffn(self.norm3(x))
+         return x
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, channels):
+         super(ResBlock, self).__init__()
+         self.conv1 = nn.Conv2d(channels, channels, kernel_size=5, stride=1,
+                                padding=2, bias=False)
+         self.bn1 = nn.InstanceNorm2d(channels)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(channels, channels, kernel_size=5, stride=1,
+                                padding=2, bias=False)
+         self.bn2 = nn.InstanceNorm2d(channels)
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Head(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Head, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
+                                padding=1, bias=False)
+         self.bn1 = nn.InstanceNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+         self.resblock = ResBlock(out_channels)
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.resblock(out)
+
+         return out
+
+
+ class PatchEmbed(nn.Module):
+     """ Feature to Patch Embedding
+     input : N C H W
+     output: N num_patch P^2*C
+     """
+     def __init__(self, patch_size=1, in_channels=64):
+         super().__init__()
+         self.patch_size = patch_size
+         self.dim = self.patch_size ** 2 * in_channels
+
+     def forward(self, x):
+         N, C, H, W = ori_shape = x.shape
+
+         p = self.patch_size
+         num_patches = (H // p) * (W // p)
+         out = torch.zeros((N, num_patches, self.dim)).to(x.device)
+         i, j = 0, 0
+         for k in range(num_patches):
+             if i + p > W:
+                 i = 0
+                 j += p
+             out[:, k, :] = x[:, :, i:i + p, j:j + p].flatten(1)
+             i += p
+         return out, ori_shape
+
+
+ class DePatchEmbed(nn.Module):
+     """ Patch Embedding to Feature
+     input : N num_patch P^2*C
+     output: N C H W
+     """
+     def __init__(self, patch_size=1, in_channels=64):
+         super().__init__()
+         self.patch_size = patch_size
+         self.num_patches = None
+         self.dim = self.patch_size ** 2 * in_channels
+
+     def forward(self, x, ori_shape):
+         N, num_patches, dim = x.shape
+         _, C, H, W = ori_shape
+         p = self.patch_size
+         out = torch.zeros(ori_shape).to(x.device)
+         i, j = 0, 0
+         for k in range(num_patches):
+             if i + p > W:
+                 i = 0
+                 j += p
+             out[:, :, i:i + p, j:j + p] = x[:, k, :].reshape(N, C, p, p)
+             i += p
+         return out
+
+
+ class Tail(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Tail, self).__init__()
+         self.output = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+     def forward(self, x):
+         out = self.output(x)
+         return out
+
+
+ class IllTr_Net(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+
+     def __init__(self, patch_size=1, in_channels=3, mid_channels=16, num_classes=1000, depth=12,
+                  num_heads=8, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  norm_layer=nn.LayerNorm):
+         super(IllTr_Net, self).__init__()
+
+         self.num_classes = num_classes
+         self.embed_dim = patch_size * patch_size * mid_channels
+         self.head = Head(in_channels, mid_channels)
+         self.patch_embedding = PatchEmbed(patch_size=patch_size, in_channels=mid_channels)
+         self.embed_dim = self.patch_embedding.dim
+         if self.embed_dim % num_heads != 0:
+             raise RuntimeError("Embedding dim must be divisible by the number of heads")
+
+         self.pos_embed = nn.Parameter(torch.zeros(1, (128 // patch_size) ** 2, self.embed_dim))
+         self.task_embed = nn.Parameter(torch.zeros(6, 1, (128 // patch_size) ** 2, self.embed_dim))
+
+         self.encoder = nn.ModuleList([
+             EncoderLayer(
+                 dim=self.embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
+             for _ in range(depth)])
+         self.decoder = nn.ModuleList([
+             DecoderLayer(
+                 dim=self.embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
+             for _ in range(depth)])
+
+         self.de_patch_embedding = DePatchEmbed(patch_size=patch_size, in_channels=mid_channels)
+         # tail
+         self.tail = Tail(int(mid_channels), in_channels)
+
+         self.acf = nn.Hardtanh(0, 1)
+
+         trunc_normal_(self.pos_embed, std=.02)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def forward(self, x):
+         x = self.head(x)
+         x, ori_shape = self.patch_embedding(x)
+         x = x + self.pos_embed[:, :x.shape[1]]
+
+         for blk in self.encoder:
+             x = blk(x)
+
+         for blk in self.decoder:
+             x = blk(x, self.task_embed[0, :, :x.shape[1]])
+
+         x = self.de_patch_embedding(x, ori_shape)
+         x = self.tail(x)
+
+         x = self.acf(x)
+         return x
+
+
+ def IllTr(**kwargs):
+     model = IllTr_Net(
+         patch_size=4, depth=6, num_heads=8, ffn_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs)
+
+     return model
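A usage sketch for IllTr; the expected input size is an inference from the code, since pos_embed is sized for (128 // patch_size)**2 patches and the default patch_size is 4, implying 128x128 RGB crops in [0, 1]:

import torch
from IllTr import IllTr

model = IllTr().cuda().eval()
with torch.no_grad():
    crop = torch.rand(1, 3, 128, 128).cuda()  # illumination-degraded crop
    out = model(crop)                         # rectified crop, clamped to [0, 1] by Hardtanh
print(out.shape)  # torch.Size([1, 3, 128, 128])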
demo.py ADDED
@@ -0,0 +1,178 @@
+ # origin
+
+ from seg import U2NETP
+ from GeoTr import GeoTr
+ from IllTr import IllTr
+ from inference_ill import rec_ill
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import skimage.io as io
+ import numpy as np
+ import cv2
+ import glob
+ import os
+ from PIL import Image
+ import argparse
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ import gradio as gr
+
+
+ class GeoTr_Seg(nn.Module):
+     def __init__(self):
+         super(GeoTr_Seg, self).__init__()
+         self.msk = U2NETP(3, 1)
+         self.GeoTr = GeoTr(num_attn_layers=6)
+
+     def forward(self, x):
+         msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
+         msk = (msk > 0.5).float()
+         x = msk * x
+
+         bm = self.GeoTr(x)
+         # map absolute 288x288 coordinates into grid_sample's [-1, 1] range
+         bm = (2 * (bm / 286.8) - 1) * 0.99
+
+         return bm
+
+
+ def reload_model(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location='cuda:0')
+         print(len(pretrained_dict.keys()))
+         pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
+         print(len(pretrained_dict.keys()))
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+
+         return model
+
+
+ def reload_segmodel(model, path=""):
+     if not bool(path):
+         return model
+     else:
+         model_dict = model.state_dict()
+         pretrained_dict = torch.load(path, map_location='cuda:0')
+         print(len(pretrained_dict.keys()))
+         pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
+         print(len(pretrained_dict.keys()))
+         model_dict.update(pretrained_dict)
+         model.load_state_dict(model_dict)
+
+         return model
+
+
+ def rec(opt):
+     # print(torch.__version__)  # 1.5.1
+     img_list = os.listdir(opt.distorrted_path)  # distorted images list
+
+     if not os.path.exists(opt.gsave_path):  # create save path
+         os.mkdir(opt.gsave_path)
+     if not os.path.exists(opt.isave_path):  # create save path
+         os.mkdir(opt.isave_path)
+
+     GeoTr_Seg_model = GeoTr_Seg().cuda()
+     # reload segmentation model
+     reload_segmodel(GeoTr_Seg_model.msk, opt.Seg_path)
+     # reload geometric unwarping model
+     reload_model(GeoTr_Seg_model.GeoTr, opt.GeoTr_path)
+
+     IllTr_model = IllTr().cuda()
+     # reload illumination rectification model
+     reload_model(IllTr_model, opt.IllTr_path)
+
+     # to eval mode
+     GeoTr_Seg_model.eval()
+     IllTr_model.eval()
+
+     for img_path in img_list:
+         name = img_path.split('.')[-2]  # image name
+
+         img_path = opt.distorrted_path + img_path  # read image and convert to tensor
+         im_ori = np.array(Image.open(img_path))[:, :, :3] / 255.
+         h, w, _ = im_ori.shape
+         im = cv2.resize(im_ori, (288, 288))
+         im = im.transpose(2, 0, 1)
+         im = torch.from_numpy(im).float().unsqueeze(0)
+
+         with torch.no_grad():
+             # geometric unwarping
+             bm = GeoTr_Seg_model(im.cuda())
+             bm = bm.cpu()
+             bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
+             bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
+             bm0 = cv2.blur(bm0, (3, 3))
+             bm1 = cv2.blur(bm1, (3, 3))
+             lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2
+
+             out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
+             img_geo = ((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1].astype(np.uint8)
+             cv2.imwrite(opt.gsave_path + name + '_geo' + '.png', img_geo)  # save
+
+             # illumination rectification
+             if opt.ill_rec:
+                 ill_savep = opt.isave_path + name + '_ill' + '.png'
+                 rec_ill(IllTr_model, img_geo, saveRecPath=ill_savep)
+
+         print('Done: ', img_path)
+
+
+ def process_image(input_image):
+     GeoTr_Seg_model = GeoTr_Seg().cuda()
+     reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
+     reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
+
+     IllTr_model = IllTr().cuda()
+     reload_model(IllTr_model, './model_pretrained/illtr.pth')
+
+     GeoTr_Seg_model.eval()
+     IllTr_model.eval()
+
+     im_ori = np.array(input_image)[:, :, :3] / 255.
+     h, w, _ = im_ori.shape
+     im = cv2.resize(im_ori, (288, 288))
+     im = im.transpose(2, 0, 1)
+     im = torch.from_numpy(im).float().unsqueeze(0)
+
+     with torch.no_grad():
+         bm = GeoTr_Seg_model(im.cuda())
+         bm = bm.cpu()
+         bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
+         bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
+         bm0 = cv2.blur(bm0, (3, 3))
+         bm1 = cv2.blur(bm1, (3, 3))
+         lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
+
+         out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
+         img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)
+
+     ill_rec = False
+
+     if ill_rec:
+         img_ill = rec_ill(IllTr_model, img_geo)
+         return Image.fromarray(img_ill)
+     else:
+         return Image.fromarray(img_geo)
+
+
+ # Define the Gradio interface
+ input_image = gr.inputs.Image()
+ output_image = gr.outputs.Image(type='pil')
+
+ iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="Image Correction")
+ iface.launch(server_port=1234, server_name="0.0.0.0")
+
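A sanity-check sketch (illustrative, with a hypothetical constant-valued map) for the normalization in GeoTr_Seg.forward above: GeoTr returns absolute coordinates in the 288x288 frame, while F.grid_sample expects a grid in [-1, 1], hence the (2 * (bm / 286.8) - 1) * 0.99 rescaling.

import torch
import torch.nn.functional as F

bm = torch.full((1, 288, 288, 2), 143.4)  # hypothetical map: every output pixel samples the image centre
grid = (2 * (bm / 286.8) - 1) * 0.99      # same rescaling as GeoTr_Seg.forward; 143.4 maps to 0.0
img = torch.rand(1, 3, 288, 288)
out = F.grid_sample(img, grid, align_corners=True)  # every output value ~ the centre pixel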
extractor.py ADDED
@@ -0,0 +1,115 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+         super(ResidualBlock, self).__init__()
+
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+         self.relu = nn.ReLU(inplace=True)
+
+         num_groups = planes // 8
+
+         if norm_fn == 'group':
+             self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+             self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+             if stride != 1:
+                 self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+         elif norm_fn == 'batch':
+             self.norm1 = nn.BatchNorm2d(planes)
+             self.norm2 = nn.BatchNorm2d(planes)
+             if stride != 1:
+                 self.norm3 = nn.BatchNorm2d(planes)
+
+         elif norm_fn == 'instance':
+             self.norm1 = nn.InstanceNorm2d(planes)
+             self.norm2 = nn.InstanceNorm2d(planes)
+             if stride != 1:
+                 self.norm3 = nn.InstanceNorm2d(planes)
+
+         elif norm_fn == 'none':
+             self.norm1 = nn.Sequential()
+             self.norm2 = nn.Sequential()
+             if stride != 1:
+                 self.norm3 = nn.Sequential()
+
+         if stride == 1:
+             self.downsample = None
+
+         else:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+     def forward(self, x):
+         y = x
+         y = self.relu(self.norm1(self.conv1(y)))
+         y = self.relu(self.norm2(self.conv2(y)))
+
+         if self.downsample is not None:
+             x = self.downsample(x)
+
+         return self.relu(x + y)
+
+
+ class BasicEncoder(nn.Module):
+     def __init__(self, output_dim=128, norm_fn='batch'):
+         super(BasicEncoder, self).__init__()
+         self.norm_fn = norm_fn
+
+         if self.norm_fn == 'group':
+             self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+         elif self.norm_fn == 'batch':
+             self.norm1 = nn.BatchNorm2d(64)
+
+         elif self.norm_fn == 'instance':
+             self.norm1 = nn.InstanceNorm2d(64)
+
+         elif self.norm_fn == 'none':
+             self.norm1 = nn.Sequential()
+
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+         self.relu1 = nn.ReLU(inplace=True)
+
+         self.in_planes = 64
+         self.layer1 = self._make_layer(64, stride=1)
+         self.layer2 = self._make_layer(128, stride=2)
+         self.layer3 = self._make_layer(192, stride=2)
+
+         # output convolution
+         self.conv2 = nn.Conv2d(192, output_dim, kernel_size=1)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                 if m.weight is not None:
+                     nn.init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def _make_layer(self, dim, stride=1):
+         layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+         layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+         layers = (layer1, layer2)
+
+         self.in_planes = dim
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.norm1(x)
+         x = self.relu1(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+
+         x = self.conv2(x)
+
+         return x
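A shape sketch for BasicEncoder: conv1, layer2, and layer3 each halve the resolution, giving an overall stride of 8, which is why GeoTr.initialize_flow works at H // 8 x W // 8.

import torch
from extractor import BasicEncoder

enc = BasicEncoder(output_dim=256, norm_fn='instance')
feat = enc(torch.rand(1, 3, 288, 288))  # 288 -> 144 -> 144 -> 72 -> 36
print(feat.shape)  # torch.Size([1, 256, 36, 36])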
position_encoding.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+ import torch
+ from torch import nn
+ from typing import List
+ from typing import Optional
+ from torch import Tensor
+
+
+ class NestedTensor(object):
+     def __init__(self, tensors, mask: Optional[Tensor]):
+         self.tensors = tensors
+         self.mask = mask
+
+     def to(self, device):
+         # type: (Device) -> NestedTensor  # noqa
+         cast_tensor = self.tensors.to(device)
+         mask = self.mask
+         if mask is not None:
+             cast_mask = mask.to(device)
+         else:
+             cast_mask = None
+         return NestedTensor(cast_tensor, cast_mask)
+
+     def decompose(self):
+         return self.tensors, self.mask
+
+     def __repr__(self):
+         return str(self.tensors)
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention Is All You Need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, mask):
+         assert mask is not None
+         y_embed = mask.cumsum(1, dtype=torch.float32)
+         x_embed = mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32).cuda()
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         # print(pos.shape)
+         return pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, num_pos_feats=256):
+         super().__init__()
+         self.row_embed = nn.Embedding(50, num_pos_feats)
+         self.col_embed = nn.Embedding(50, num_pos_feats)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return pos
+
+
+ def build_position_encoding(hidden_dim=512, position_embedding='sine'):
+     N_steps = hidden_dim // 2
+     if position_embedding in ('v2', 'sine'):
+         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+     elif position_embedding in ('v3', 'learned'):
+         position_embedding = PositionEmbeddingLearned(N_steps)
+     else:
+         raise ValueError(f"not supported {position_embedding}")
+
+     return position_embedding
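A minimal sketch of the sine encoding (assumes a CUDA device, since dim_t above is created with .cuda()): an all-ones boolean mask yields cumulative-sum coordinates along each axis, and with hidden_dim=256 the output stacks 128 y-channels over 128 x-channels.

import torch
from position_encoding import build_position_encoding

pe = build_position_encoding(hidden_dim=256)  # sine variant by default
mask = torch.ones(1, 36, 36).bool().cuda()    # same mask shape TransEncoder builds
pos = pe(mask)
print(pos.shape)  # torch.Size([1, 256, 36, 36])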
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ numpy
+ opencv_python
+ Pillow
+ scikit_image
+ thop
+ torch
+ gradio
seg.py ADDED
@@ -0,0 +1,567 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ class sobel_net(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv_opx = nn.Conv2d(1, 1, 3, bias=False)
+         self.conv_opy = nn.Conv2d(1, 1, 3, bias=False)
+         sobel_kernelx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype='float32').reshape((1, 1, 3, 3))
+         sobel_kernely = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype='float32').reshape((1, 1, 3, 3))
+         self.conv_opx.weight.data = torch.from_numpy(sobel_kernelx)
+         self.conv_opy.weight.data = torch.from_numpy(sobel_kernely)
+
+         for p in self.parameters():
+             p.requires_grad = False
+
+     def forward(self, im):  # input rgb
+         x = (0.299 * im[:, 0, :, :] + 0.587 * im[:, 1, :, :] + 0.114 * im[:, 2, :, :]).unsqueeze(1)  # rgb2gray
+         gradx = self.conv_opx(x)
+         grady = self.conv_opy(x)
+
+         x = (gradx ** 2 + grady ** 2) ** 0.5
+         x = (x - x.min()) / (x.max() - x.min())
+         x = F.pad(x, (1, 1, 1, 1))
+
+         x = torch.cat([im, x], dim=1)
+         return x
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ ## upsample tensor 'src' to the same spatial size as tensor 'tar'
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         return hx1d + hxin
+
+
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+         self.edge = sobel_net()
+
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6, out_ch, 1)
+
+     def forward(self, x):
+         x = self.edge(x)
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
+                 torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6))
+
+
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(64, 16, 64)
+
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6, out_ch, 1)
+
+     def forward(self, x):
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
+                 torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6))
+
+
+ def get_parameter_number(net):
+     total_num = sum(p.numel() for p in net.parameters())
+     trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
+     return {'Total': total_num, 'Trainable': trainable_num}
+
+
+ if __name__ == '__main__':
+     net = U2NET(4, 1).cuda()
+     print(get_parameter_number(net))  # 69090500; 69442032 after adding attention
+     with torch.no_grad():
+         inputs = torch.zeros(1, 3, 256, 256).cuda()
+         outs = net(inputs)
+         print(outs[0].shape)  # torch.Size([1, 1, 256, 256])
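A sketch of how demo.py uses U2NETP: only the fused side output is kept, thresholded at 0.5 to zero out the background before geometric unwarping.

import torch
from seg import U2NETP

seg = U2NETP(3, 1).cuda().eval()
with torch.no_grad():
    im = torch.rand(1, 3, 288, 288).cuda()
    msk = seg(im)[0]                      # fused side output, (1, 1, 288, 288)
    im_masked = (msk > 0.5).float() * im  # background suppressed, as in GeoTr_Seg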