junma committed on
Commit
56afa1a
1 Parent(s): f82a26e
app.py ADDED
@@ -0,0 +1,105 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Author: Jun Ma
+
+ import os
+ join = os.path.join
+ import argparse
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import tifffile as tif
+ import monai
+ from tqdm import tqdm
+ from utils.postprocess import mask_overlay
+ from monai.transforms import Activations, AddChanneld, AsChannelFirstd, AsDiscrete, Compose, EnsureTyped, EnsureType
+ from models.unicell_modules import MiT_B2_UNet_MultiHead, MiT_B3_UNet_MultiHead
+ import matplotlib.pyplot as plt
+ from skimage import io, exposure, segmentation, morphology
+ from utils.postprocess import watershed_post
+ from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference
+ import gradio as gr
+
+ def normalize_channel(img, lower=0.1, upper=99.9):
+     non_zero_vals = img[np.nonzero(img)]
+     percentiles = np.percentile(non_zero_vals, [lower, upper])
+     if percentiles[1] - percentiles[0] > 0.001:
+         img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
+     else:
+         img_norm = img
+     return img_norm
+
+ def preprocess(img_data):
+     if len(img_data.shape) == 2:
+         img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
+     elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
+         img_data = img_data[:,:, :3]
+     else:
+         pass
+     pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
+     for i in range(3):
+         img_channel_i = img_data[:,:,i]
+         if len(img_channel_i[np.nonzero(img_channel_i)])>0:
+             pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
+     return pre_img_data
+
+
+ def inference(pre_img_data):
+     test_npy = pre_img_data/np.max(pre_img_data)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256).to(device)
+     checkpoint = torch.load('./model.pth', map_location=torch.device(device))
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+     with torch.no_grad():
+         test_tensor = torch.from_numpy(np.expand_dims(test_npy, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
+
+         val_pred, val_pred_dist = multi_task_sliding_window_inference(inputs=test_tensor, roi_size=(256, 256), sw_batch_size=8, predictor=model)
+
+         # watershed postprocessing
+         val_seg_inst = watershed_post(val_pred_dist.squeeze(1).cpu().numpy(), val_pred.squeeze(1).cpu().numpy()[:,1])
+         test_pred_mask = val_seg_inst.squeeze().astype(np.uint16)
+
+         # overlay
+         boundary = segmentation.find_boundaries(test_pred_mask, connectivity=1, mode='inner')
+         boundary = morphology.binary_dilation(boundary, morphology.disk(1))
+         pre_img_data[boundary, 0] = 0
+         pre_img_data[boundary, 1] = 255
+         pre_img_data[boundary, 2] = 0
+
+     return test_pred_mask, pre_img_data
+
+
+ def predict(img):
+     print('##########', img.name)
+     img_name = img.name
+     if img_name.endswith('.tif') or img_name.endswith('.tiff'):
+         img_data = tif.imread(img_name)
+     else:
+         img_data = io.imread(img_name)
+     if len(img_data.shape)==2:
+         pre_img_data = normalize_channel(img_data, lower=0.1, upper=99.9)
+         pre_img_data = np.repeat(np.expand_dims(pre_img_data, -1), repeats=3, axis=-1)
+     else:
+         pre_img_data = np.zeros((img_data.shape[0], img_data.shape[1], 3), dtype=np.uint8)
+         for i in range(3):
+             img_channel_i = img_data[:,:,i]
+             if len(img_channel_i[np.nonzero(img_channel_i)])>0:
+                 pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=0.1, upper=99.9)
+
+     seg_labels, seg_overlay = inference(pre_img_data)
+
+     tif.imwrite(join(os.getcwd(), 'segmentation.tiff'), seg_labels, compression='zlib')
+
+     return seg_overlay, join(os.getcwd(), 'segmentation.tiff')
+
+ unicell_api = gr.Interface(
+     predict,
+     inputs = gr.File(label="Input image (png, bmp, jpg, tif, tiff)"),
+     outputs = [gr.Image(label="Segmentation overlay"), gr.File(label="Download segmentation")],
+     title = "UniCell Online Demo",
+     examples=['demo.png', 'demo.tif']
+ )
+
+ unicell_api.launch()
+
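For reference, the per-channel contrast normalization that app.py applies before inference can be exercised on its own. The sketch below is illustrative only: it mirrors normalize_channel on a synthetic 16-bit channel rather than importing app.py (which would also build the Gradio interface and call launch()), and the array sizes are arbitrary.

# minimal sketch of the percentile-based normalization used in app.py (illustrative, not the app itself)
import numpy as np
from skimage import exposure

def normalize_channel_sketch(img, lower=0.1, upper=99.9):
    non_zero_vals = img[np.nonzero(img)]
    p_low, p_high = np.percentile(non_zero_vals, [lower, upper])
    if p_high - p_low > 0.001:
        # stretch the robust intensity range of the channel to the full uint8 range
        return exposure.rescale_intensity(img, in_range=(p_low, p_high), out_range='uint8')
    return img

rng = np.random.default_rng(0)
fake_channel = rng.integers(0, 4000, size=(256, 256)).astype(np.uint16)  # synthetic 16-bit channel
norm = normalize_channel_sketch(fake_channel)
print(norm.dtype, norm.min(), norm.max())  # uint8, spanning roughly 0..255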
demo.png ADDED
demo.tif ADDED
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a6849ec1969d4abc37b8eb915d03f7b6d6eb3092fc3f1ac5060d1310ddf89f9
+ size 90440917
models/__init__.py ADDED
@@ -0,0 +1,9 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Sun Mar 20 14:23:55 2022
+
+ @author: jma
+ """
+
+ from .unicell_modules import *
models/unicell_modules.py ADDED
@@ -0,0 +1,912 @@
1
+ # ---------------------------------------------------------------
2
+ # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
3
+ #
4
+ # This work is licensed under the NVIDIA Source Code License
5
+ # ---------------------------------------------------------------
6
+ import torch
7
+ import torch.nn as nn
8
+ from functools import partial
9
+ import math
10
+ from itertools import repeat
11
+ import collections.abc
12
+ from typing import Tuple, Union
13
+ from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock, UnetrPrUpBlock
14
+ from monai.networks.blocks.dynunet_block import get_conv_layer
15
+
16
+ # From PyTorch internals
17
+ def _ntuple(n):
18
+ def parse(x):
19
+ if isinstance(x, collections.abc.Iterable):
20
+ return x
21
+ return tuple(repeat(x, n))
22
+ return parse
23
+
24
+ to_1tuple = _ntuple(1)
25
+ to_2tuple = _ntuple(2)
26
+ to_3tuple = _ntuple(3)
27
+ to_4tuple = _ntuple(4)
28
+ to_ntuple = _ntuple
29
+
30
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
31
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
32
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
33
+ def norm_cdf(x):
34
+ # Computes standard normal cumulative distribution function
35
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
36
+
37
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.")
41
+
42
+ with torch.no_grad():
43
+ # Values are generated by using a truncated uniform distribution and
44
+ # then using the inverse CDF for the normal distribution.
45
+ # Get upper and lower cdf values
46
+ l = norm_cdf((a - mean) / std)
47
+ u = norm_cdf((b - mean) / std)
48
+
49
+ # Uniformly fill tensor with values from [l, u], then translate to
50
+ # [2l-1, 2u-1].
51
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
52
+
53
+ # Use inverse cdf transform for normal distribution to get truncated
54
+ # standard normal
55
+ tensor.erfinv_()
56
+
57
+ # Transform to proper mean, std
58
+ tensor.mul_(std * math.sqrt(2.))
59
+ tensor.add_(mean)
60
+
61
+ # Clamp to ensure it's in the proper range
62
+ tensor.clamp_(min=a, max=b)
63
+ return tensor
64
+
65
+ #%%
66
+ class Mlp(nn.Module):
67
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
68
+ super().__init__()
69
+ out_features = out_features or in_features
70
+ hidden_features = hidden_features or in_features
71
+ self.fc1 = nn.Linear(in_features, hidden_features)
72
+ self.dwconv = DWConv(hidden_features)
73
+ self.act = act_layer()
74
+ self.fc2 = nn.Linear(hidden_features, out_features)
75
+ self.drop = nn.Dropout(drop)
76
+
77
+ self.apply(self._init_weights)
78
+
79
+ def _init_weights(self, m):
80
+ if isinstance(m, nn.Linear):
81
+ trunc_normal_(m.weight, std=.02)
82
+ if isinstance(m, nn.Linear) and m.bias is not None:
83
+ nn.init.constant_(m.bias, 0)
84
+ elif isinstance(m, nn.LayerNorm):
85
+ nn.init.constant_(m.bias, 0)
86
+ nn.init.constant_(m.weight, 1.0)
87
+ elif isinstance(m, nn.Conv2d):
88
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
89
+ fan_out //= m.groups
90
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
91
+ if m.bias is not None:
92
+ m.bias.data.zero_()
93
+
94
+ def forward(self, x, H, W):
95
+ x = self.fc1(x)
96
+ x = self.dwconv(x, H, W)
97
+ x = self.act(x)
98
+ x = self.drop(x)
99
+ x = self.fc2(x)
100
+ x = self.drop(x)
101
+ return x
102
+
103
+
104
+ class Attention(nn.Module):
105
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
106
+ super().__init__()
107
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
108
+
109
+ self.dim = dim
110
+ self.num_heads = num_heads
111
+ head_dim = dim // num_heads
112
+ self.scale = qk_scale or head_dim ** -0.5
113
+
114
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
115
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
116
+ self.attn_drop = nn.Dropout(attn_drop)
117
+ self.proj = nn.Linear(dim, dim)
118
+ self.proj_drop = nn.Dropout(proj_drop)
119
+
120
+ self.sr_ratio = sr_ratio
121
+ if sr_ratio > 1:
122
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
123
+ self.norm = nn.LayerNorm(dim)
124
+
125
+ self.apply(self._init_weights)
126
+
127
+ def _init_weights(self, m):
128
+ if isinstance(m, nn.Linear):
129
+ trunc_normal_(m.weight, std=.02)
130
+ if isinstance(m, nn.Linear) and m.bias is not None:
131
+ nn.init.constant_(m.bias, 0)
132
+ elif isinstance(m, nn.LayerNorm):
133
+ nn.init.constant_(m.bias, 0)
134
+ nn.init.constant_(m.weight, 1.0)
135
+ elif isinstance(m, nn.Conv2d):
136
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137
+ fan_out //= m.groups
138
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
139
+ if m.bias is not None:
140
+ m.bias.data.zero_()
141
+
142
+ def forward(self, x, H, W):
143
+ B, N, C = x.shape
144
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
145
+
146
+ if self.sr_ratio > 1:
147
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
148
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
149
+ x_ = self.norm(x_)
150
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
151
+ else:
152
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
153
+ k, v = kv[0], kv[1]
154
+
155
+ attn = (q @ k.transpose(-2, -1)) * self.scale
156
+ attn = attn.softmax(dim=-1)
157
+ attn = self.attn_drop(attn)
158
+
159
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
160
+ x = self.proj(x)
161
+ x = self.proj_drop(x)
162
+
163
+ return x
164
+
165
+
166
+ class Block(nn.Module):
167
+
168
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
169
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
170
+ super().__init__()
171
+ self.norm1 = norm_layer(dim)
172
+ self.attn = Attention(
173
+ dim,
174
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
175
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
176
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
177
+ # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
178
+ self.drop_path = nn.Identity()
179
+ self.norm2 = norm_layer(dim)
180
+ mlp_hidden_dim = int(dim * mlp_ratio)
181
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
182
+
183
+ self.apply(self._init_weights)
184
+
185
+ def _init_weights(self, m):
186
+ if isinstance(m, nn.Linear):
187
+ trunc_normal_(m.weight, std=.02)
188
+ if isinstance(m, nn.Linear) and m.bias is not None:
189
+ nn.init.constant_(m.bias, 0)
190
+ elif isinstance(m, nn.LayerNorm):
191
+ nn.init.constant_(m.bias, 0)
192
+ nn.init.constant_(m.weight, 1.0)
193
+ elif isinstance(m, nn.Conv2d):
194
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
195
+ fan_out //= m.groups
196
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
197
+ if m.bias is not None:
198
+ m.bias.data.zero_()
199
+
200
+ def forward(self, x, H, W):
201
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
202
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
203
+
204
+ return x
205
+ #%%
206
+
207
+ class OverlapPatchEmbed(nn.Module):
208
+ """ Image to Patch Embedding
209
+ """
210
+
211
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
212
+ super().__init__()
213
+ img_size = to_2tuple(img_size)
214
+ patch_size = to_2tuple(patch_size)
215
+
216
+ self.img_size = img_size
217
+ self.patch_size = patch_size
218
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
219
+ self.num_patches = self.H * self.W
220
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
221
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
222
+ self.norm = nn.LayerNorm(embed_dim)
223
+
224
+ self.apply(self._init_weights)
225
+
226
+ def _init_weights(self, m):
227
+ if isinstance(m, nn.Linear):
228
+ trunc_normal_(m.weight, std=.02)
229
+ if isinstance(m, nn.Linear) and m.bias is not None:
230
+ nn.init.constant_(m.bias, 0)
231
+ elif isinstance(m, nn.LayerNorm):
232
+ nn.init.constant_(m.bias, 0)
233
+ nn.init.constant_(m.weight, 1.0)
234
+ elif isinstance(m, nn.Conv2d):
235
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
236
+ fan_out //= m.groups
237
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
238
+ if m.bias is not None:
239
+ m.bias.data.zero_()
240
+
241
+ def forward(self, x):
242
+ x = self.proj(x) # [2, 3, 224, 224]-> [2, 64, 56, 56]
243
+ # print(f"{x.shape=}")
244
+ _, _, H, W = x.shape
245
+ x = x.flatten(2).transpose(1, 2) # [2, 64, 56, 56]-> [2, 3136, 64]
246
+ # print(f"{x.shape=}")
247
+ x = self.norm(x) # [2, 3136, 64]-> [2, 3136, 64]
248
+ # print(f"{x.shape=}")
249
+
250
+ return x, H, W
251
+
252
+ # embed_dims=[64, 128, 256, 512]
253
+ # patch_embed1 = OverlapPatchEmbed(img_size=224,patch_size=7,stride=4,in_chans=in_chans, embed_dim=64)
254
+ # x1, H, W = patch_embed1(input_img)
255
+ # x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
256
+ # patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
257
+ # embed_dim=embed_dims[1])
258
+ # x2, H, W = patch_embed2(x1)
259
+ # x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
260
+
261
+ # patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
262
+ # embed_dim=embed_dims[2])
263
+ # x3, H, W = patch_embed3(x2)
264
+ # x3 = x3.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
265
+
266
+ # patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],embed_dim=embed_dims[3])
267
+ # x4, H, W = patch_embed4(x3)
268
+ # x4 = x4.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
269
+ #%%
270
+
271
+ class MixVisionTransformer(nn.Module):
272
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dims=[64, 128, 256, 512],
273
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
274
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
275
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
276
+ super().__init__()
277
+ # self.num_classes = num_classes
278
+ self.depths = depths
279
+
280
+ # patch_embed
281
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
282
+ embed_dim=embed_dims[0])
283
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
284
+ embed_dim=embed_dims[1])
285
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
286
+ embed_dim=embed_dims[2])
287
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
288
+ embed_dim=embed_dims[3])
289
+
290
+ # transformer encoder
291
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
292
+ cur = 0
293
+ self.block1 = nn.ModuleList([Block(
294
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
295
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
296
+ sr_ratio=sr_ratios[0])
297
+ for i in range(depths[0])])
298
+ self.norm1 = norm_layer(embed_dims[0])
299
+
300
+ cur += depths[0]
301
+ self.block2 = nn.ModuleList([Block(
302
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
303
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
304
+ sr_ratio=sr_ratios[1])
305
+ for i in range(depths[1])])
306
+ self.norm2 = norm_layer(embed_dims[1])
307
+
308
+ cur += depths[1]
309
+ self.block3 = nn.ModuleList([Block(
310
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
311
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
312
+ sr_ratio=sr_ratios[2])
313
+ for i in range(depths[2])])
314
+ self.norm3 = norm_layer(embed_dims[2])
315
+
316
+ cur += depths[2]
317
+ self.block4 = nn.ModuleList([Block(
318
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
319
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
320
+ sr_ratio=sr_ratios[3])
321
+ for i in range(depths[3])])
322
+ self.norm4 = norm_layer(embed_dims[3])
323
+
324
+ # classification head
325
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
326
+
327
+ self.apply(self._init_weights)
328
+
329
+ def _init_weights(self, m):
330
+ if isinstance(m, nn.Linear):
331
+ trunc_normal_(m.weight, std=.02)
332
+ if isinstance(m, nn.Linear) and m.bias is not None:
333
+ nn.init.constant_(m.bias, 0)
334
+ elif isinstance(m, nn.LayerNorm):
335
+ nn.init.constant_(m.bias, 0)
336
+ nn.init.constant_(m.weight, 1.0)
337
+ elif isinstance(m, nn.Conv2d):
338
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
339
+ fan_out //= m.groups
340
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
341
+ if m.bias is not None:
342
+ m.bias.data.zero_()
343
+
344
+ def init_weights(self, pretrained=None):
345
+ if isinstance(pretrained, str):
346
+ # logger = get_root_logger()
347
+ # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
348
+ # load_checkpoint(self, pretrained, map_location='cpu', strict=False)
349
+ torch.load(pretrained, map_location='cpu')
350
+
351
+ def reset_drop_path(self, drop_path_rate):
352
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
353
+ cur = 0
354
+ for i in range(self.depths[0]):
355
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
356
+
357
+ cur += self.depths[0]
358
+ for i in range(self.depths[1]):
359
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
360
+
361
+ cur += self.depths[1]
362
+ for i in range(self.depths[2]):
363
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
364
+
365
+ cur += self.depths[2]
366
+ for i in range(self.depths[3]):
367
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
368
+
369
+ def freeze_patch_emb(self):
370
+ self.patch_embed1.requires_grad = False
371
+
372
+ @torch.jit.ignore
373
+ def no_weight_decay(self):
374
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
375
+
376
+ def get_classifier(self):
377
+ return self.head
378
+
379
+ # def reset_classifier(self, num_classes, global_pool=''):
380
+ # self.num_classes = num_classes
381
+ # self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
382
+
383
+ def forward_features(self, x):
384
+ B = x.shape[0]
385
+ outs = []
386
+
387
+ # stage 1
388
+ x, H, W = self.patch_embed1(x)
389
+ for i, blk in enumerate(self.block1):
390
+ x = blk(x, H, W)
391
+ x = self.norm1(x)
392
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
393
+ outs.append(x)
394
+
395
+ # stage 2
396
+ x, H, W = self.patch_embed2(x)
397
+ for i, blk in enumerate(self.block2):
398
+ x = blk(x, H, W)
399
+ x = self.norm2(x)
400
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
401
+ outs.append(x)
402
+
403
+ # stage 3
404
+ x, H, W = self.patch_embed3(x)
405
+ for i, blk in enumerate(self.block3):
406
+ x = blk(x, H, W)
407
+ x = self.norm3(x)
408
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
409
+ outs.append(x)
410
+
411
+ # stage 4
412
+ x, H, W = self.patch_embed4(x)
413
+ for i, blk in enumerate(self.block4):
414
+ x = blk(x, H, W)
415
+ x = self.norm4(x)
416
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
417
+ outs.append(x)
418
+
419
+ return outs
420
+
421
+ def forward(self, x):
422
+ x = self.forward_features(x)
423
+ # x = self.head(x)
424
+
425
+ return x
426
+
427
+
428
+ class DWConv(nn.Module):
429
+ def __init__(self, dim=768):
430
+ super(DWConv, self).__init__()
431
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
432
+
433
+ def forward(self, x, H, W):
434
+ B, N, C = x.shape
435
+ x = x.transpose(1, 2).view(B, C, H, W)
436
+ x = self.dwconv(x)
437
+ x = x.flatten(2).transpose(1, 2)
438
+ return x
439
+
440
+
441
+
442
+
443
+ class mit_b0(MixVisionTransformer):
444
+ def __init__(self, **kwargs):
445
+ super(mit_b0, self).__init__(
446
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
447
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
448
+ drop_rate=0.0, drop_path_rate=0.1)
449
+
450
+
451
+
452
+ class mit_b1(MixVisionTransformer):
453
+ def __init__(self, **kwargs):
454
+ super(mit_b1, self).__init__(
455
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
456
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
457
+ drop_rate=0.0, drop_path_rate=0.1)
458
+
459
+
460
+ class mit_b2(MixVisionTransformer):
461
+ def __init__(self, **kwargs):
462
+ super(mit_b2, self).__init__(
463
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
464
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
465
+ drop_rate=0.0, drop_path_rate=0.1)
466
+
467
+
468
+
469
+ class mit_b3(MixVisionTransformer):
470
+ def __init__(self, **kwargs):
471
+ super(mit_b3, self).__init__(
472
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
473
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
474
+ drop_rate=0.0, drop_path_rate=0.1)
475
+
476
+
477
+
478
+ class mit_b4(MixVisionTransformer):
479
+ def __init__(self, **kwargs):
480
+ super(mit_b4, self).__init__(
481
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
482
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
483
+ drop_rate=0.0, drop_path_rate=0.1)
484
+
485
+
486
+
487
+ class mit_b5(MixVisionTransformer):
488
+ def __init__(self, **kwargs):
489
+ super(mit_b5, self).__init__(
490
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
491
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
492
+ drop_rate=0.0, drop_path_rate=0.1)
493
+
494
+
495
+ #%% B2
496
+ class MiT_B2_UNet_MultiHead(nn.Module):
497
+ def __init__(self,
498
+ in_channels: int,
499
+ out_channels: int,
500
+ regress_class: int = 1,
501
+ img_size: Tuple[int, int] = (256,256),
502
+
503
+ feature_size: int = 16,
504
+ spatial_dims: int = 2,
505
+ # hidden_size: int = 768,
506
+ # mlp_dim: int = 3072,
507
+ num_heads = [1, 2, 4, 8],
508
+ # pos_embed: str = "perceptron",
509
+ norm_name: Union[Tuple, str] = "instance",
510
+ conv_block: bool = False,
511
+ res_block: bool = True,
512
+ dropout_rate: float = 0.0,
513
+ debug: bool = False
514
+ ):
515
+ super().__init__()
516
+ self.debug = debug
517
+ self.mit_b3 = MixVisionTransformer(img_size=img_size, patch_size=4, embed_dims=[feature_size*2, feature_size*4, feature_size*8, feature_size*16],
518
+ num_heads=num_heads, mlp_ratios=[4, 4, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
519
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
520
+
521
+ self.encoder1 = UnetrBasicBlock(
522
+ spatial_dims=spatial_dims,
523
+ in_channels=in_channels,
524
+ out_channels=feature_size,
525
+ kernel_size=3,
526
+ stride=1,
527
+ norm_name=norm_name,
528
+ res_block=True,
529
+ )
530
+
531
+ self.encoder2 = UnetrBasicBlock(
532
+ spatial_dims=spatial_dims,
533
+ in_channels=2 * feature_size,
534
+ out_channels=2 * feature_size,
535
+ kernel_size=3,
536
+ stride=1,
537
+ norm_name=norm_name,
538
+ res_block=True,
539
+ )
540
+
541
+ self.encoder3 = UnetrBasicBlock(
542
+ spatial_dims=spatial_dims,
543
+ in_channels=4 * feature_size,
544
+ out_channels=4 * feature_size,
545
+ kernel_size=3,
546
+ stride=1,
547
+ norm_name=norm_name,
548
+ res_block=True,
549
+ )
550
+
551
+ self.encoder4 = UnetrBasicBlock(
552
+ spatial_dims=spatial_dims,
553
+ in_channels=8 * feature_size,
554
+ out_channels=8 * feature_size,
555
+ kernel_size=3,
556
+ stride=1,
557
+ norm_name=norm_name,
558
+ res_block=True,
559
+ )
560
+
561
+ self.encoder5 = UnetrBasicBlock(
562
+ spatial_dims=spatial_dims,
563
+ in_channels=16 * feature_size,
564
+ out_channels=16 * feature_size,
565
+ kernel_size=3,
566
+ stride=1,
567
+ norm_name=norm_name,
568
+ res_block=True,
569
+ )
570
+
571
+ self.decoder4 = UnetrUpBlock(
572
+ spatial_dims=2,
573
+ in_channels=feature_size * 16,
574
+ out_channels=feature_size * 8,
575
+ kernel_size=3,
576
+ upsample_kernel_size=2,
577
+ norm_name=norm_name,
578
+ res_block=res_block,
579
+ )
580
+ self.decoder3 = UnetrUpBlock(
581
+ spatial_dims=2,
582
+ in_channels=feature_size * 8,
583
+ out_channels=feature_size * 4,
584
+ kernel_size=3,
585
+ upsample_kernel_size=2,
586
+ norm_name=norm_name,
587
+ res_block=res_block,
588
+ )
589
+ self.decoder2 = UnetrUpBlock(
590
+ spatial_dims=2,
591
+ in_channels=feature_size * 4,
592
+ out_channels=feature_size * 2,
593
+ kernel_size=3,
594
+ upsample_kernel_size=2,
595
+ norm_name=norm_name,
596
+ res_block=res_block,
597
+ )
598
+
599
+ self.transp_conv = get_conv_layer(
600
+ spatial_dims=2,
601
+ in_channels=feature_size*2,
602
+ out_channels=feature_size*2,
603
+ kernel_size=3,
604
+ stride=2,
605
+ conv_only=True,
606
+ is_transposed=True,
607
+ )
608
+ self.decoder1 = UnetrUpBlock(
609
+ spatial_dims=2,
610
+ in_channels=feature_size * 2,
611
+ out_channels=feature_size,
612
+ kernel_size=3,
613
+ upsample_kernel_size=2,
614
+ norm_name=norm_name,
615
+ res_block=res_block,
616
+ )
617
+
618
+ self.out_interior = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels) # type: ignore
619
+ self.out_dist = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=1) # type: ignore
620
+
621
+ def forward(self, x_in):
622
+ hidden_states_out = self.mit_b3(x_in) # x: (B, 256,768), hidden_states_out: list, 12 elements, (B,256,768)
623
+ enc1 = self.encoder1(x_in) # (B, 16, 256, 256)
624
+ x1 = hidden_states_out[0] # (B, 32, 64, 64)
625
+ enc2 = self.encoder2(x1) # (B, 64, 32, 32)
626
+ x2 = hidden_states_out[1] # (B, 64, 32, 32)
627
+ enc3 = self.encoder3(x2) # (B, 128, 16, 16)
628
+ x3 = hidden_states_out[2] # (B, 128, 16,16)
629
+ enc4 = self.encoder4(x3) # (B, 256, 8, 8)
630
+ x4 = hidden_states_out[3] # (B, 256, 8, 8)
631
+ enc5 = self.encoder5(x4) # (B, 256, 8, 8)
632
+ # print(f"{enc1.shape=}, {enc2.shape=}, {enc3.shape=}, {enc4.shape=}, {enc5.shape=}")
633
+
634
+ dec4 = self.decoder4(enc5, enc4) # (B, 128, 16, 16); up -> cat -> ResConv; (B, 128, 16, 16)
635
+ dec3 = self.decoder3(dec4, enc3) # (B, 64, 32, 32)
636
+ dec2 = self.decoder2(dec3, enc2) # (B, 32, 64, 64)
637
+ dec2_up = self.transp_conv(dec2) # [B, 32, 128, 128]
638
+ dec1 = self.decoder1(dec2_up, enc1) # (B, 16, 256, 256)
639
+ logits = self.out_interior(dec1)
640
+ dist = self.out_dist(dec1)
641
+
642
+ if self.debug:
643
+ return hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
644
+ else:
645
+ return logits, dist
646
+
647
+ # print(f"{dec1.shape=}, {dec2.shape=}, {dec3.shape=}, {dec4.shape=}, {logits.shape=}")
648
+
649
+ img_size = 256
650
+ in_chans = 3
651
+ B = 2
652
+ input_img = torch.randn((B,in_chans,img_size,img_size))
653
+
654
+ b2 = MiT_B2_UNet_MultiHead(3, 3, img_size=img_size)
655
+ logits, dist = b2(input_img)
656
+
657
+
658
+ #%% B3
659
+ class MiT_B3_UNet_MultiHead(nn.Module):
660
+ def __init__(self,
661
+ in_channels: int,
662
+ out_channels: int,
663
+ regress_class: int = 1,
664
+ img_size: Tuple[int, int] = (256,256),
665
+
666
+ feature_size: int = 16,
667
+ spatial_dims: int = 2,
668
+ # hidden_size: int = 768,
669
+ # mlp_dim: int = 3072,
670
+ num_heads = [1, 2, 4, 8],
671
+ # pos_embed: str = "perceptron",
672
+ norm_name: Union[Tuple, str] = "instance",
673
+ conv_block: bool = False,
674
+ res_block: bool = True,
675
+ dropout_rate: float = 0.0,
676
+ debug: bool = False
677
+ ):
678
+ super().__init__()
679
+ self.debug = debug
680
+ self.mit_b3 = MixVisionTransformer(img_size=img_size, patch_size=4, embed_dims=[feature_size*2, feature_size*4, feature_size*8, feature_size*16],
681
+ num_heads=num_heads, mlp_ratios=[4, 4, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
682
+ drop_rate=0.0, drop_path_rate=0.1)
683
+
684
+ self.encoder1 = UnetrBasicBlock(
685
+ spatial_dims=spatial_dims,
686
+ in_channels=in_channels,
687
+ out_channels=feature_size,
688
+ kernel_size=3,
689
+ stride=1,
690
+ norm_name=norm_name,
691
+ res_block=True,
692
+ )
693
+
694
+ self.encoder2 = UnetrBasicBlock(
695
+ spatial_dims=spatial_dims,
696
+ in_channels=2 * feature_size,
697
+ out_channels=2 * feature_size,
698
+ kernel_size=3,
699
+ stride=1,
700
+ norm_name=norm_name,
701
+ res_block=True,
702
+ )
703
+
704
+ self.encoder3 = UnetrBasicBlock(
705
+ spatial_dims=spatial_dims,
706
+ in_channels=4 * feature_size,
707
+ out_channels=4 * feature_size,
708
+ kernel_size=3,
709
+ stride=1,
710
+ norm_name=norm_name,
711
+ res_block=True,
712
+ )
713
+
714
+ self.encoder4 = UnetrBasicBlock(
715
+ spatial_dims=spatial_dims,
716
+ in_channels=8 * feature_size,
717
+ out_channels=8 * feature_size,
718
+ kernel_size=3,
719
+ stride=1,
720
+ norm_name=norm_name,
721
+ res_block=True,
722
+ )
723
+
724
+ self.encoder5 = UnetrBasicBlock(
725
+ spatial_dims=spatial_dims,
726
+ in_channels=16 * feature_size,
727
+ out_channels=16 * feature_size,
728
+ kernel_size=3,
729
+ stride=1,
730
+ norm_name=norm_name,
731
+ res_block=True,
732
+ )
733
+
734
+ self.decoder4 = UnetrUpBlock(
735
+ spatial_dims=2,
736
+ in_channels=feature_size * 16,
737
+ out_channels=feature_size * 8,
738
+ kernel_size=3,
739
+ upsample_kernel_size=2,
740
+ norm_name=norm_name,
741
+ res_block=res_block,
742
+ )
743
+ self.decoder3 = UnetrUpBlock(
744
+ spatial_dims=2,
745
+ in_channels=feature_size * 8,
746
+ out_channels=feature_size * 4,
747
+ kernel_size=3,
748
+ upsample_kernel_size=2,
749
+ norm_name=norm_name,
750
+ res_block=res_block,
751
+ )
752
+ self.decoder2 = UnetrUpBlock(
753
+ spatial_dims=2,
754
+ in_channels=feature_size * 4,
755
+ out_channels=feature_size * 2,
756
+ kernel_size=3,
757
+ upsample_kernel_size=2,
758
+ norm_name=norm_name,
759
+ res_block=res_block,
760
+ )
761
+
762
+ self.transp_conv = get_conv_layer(
763
+ spatial_dims=2,
764
+ in_channels=feature_size*2,
765
+ out_channels=feature_size*2,
766
+ kernel_size=3,
767
+ stride=2,
768
+ conv_only=True,
769
+ is_transposed=True,
770
+ )
771
+ self.decoder1 = UnetrUpBlock(
772
+ spatial_dims=2,
773
+ in_channels=feature_size * 2,
774
+ out_channels=feature_size,
775
+ kernel_size=3,
776
+ upsample_kernel_size=2,
777
+ norm_name=norm_name,
778
+ res_block=res_block,
779
+ )
780
+
781
+ self.out_interior = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels) # type: ignore
782
+ self.out_dist = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=1) # type: ignore
783
+
784
+ def forward(self, x_in):
785
+ hidden_states_out = self.mit_b3(x_in) # x: (B, 256,768), hidden_states_out: list, 12 elements, (B,256,768)
786
+ enc1 = self.encoder1(x_in) # (B, 16, 256, 256)
787
+ x1 = hidden_states_out[0] # (B, 32, 64, 64)
788
+ enc2 = self.encoder2(x1) # (B, 64, 32, 32)
789
+ x2 = hidden_states_out[1] # (B, 64, 32, 32)
790
+ enc3 = self.encoder3(x2) # (B, 128, 16, 16)
791
+ x3 = hidden_states_out[2] # (B, 128, 16,16)
792
+ enc4 = self.encoder4(x3) # (B, 256, 8, 8)
793
+ x4 = hidden_states_out[3] # (B, 256, 8, 8)
794
+ enc5 = self.encoder5(x4) # (B, 256, 8, 8)
795
+ # print(f"{enc1.shape=}, {enc2.shape=}, {enc3.shape=}, {enc4.shape=}, {enc5.shape=}")
796
+
797
+ dec4 = self.decoder4(enc5, enc4) # (B, 128, 16, 16); up -> cat -> ResConv; (B, 128, 16, 16)
798
+ dec3 = self.decoder3(dec4, enc3) # (B, 64, 32, 32)
799
+ dec2 = self.decoder2(dec3, enc2) # (B, 32, 64, 64)
800
+ dec2_up = self.transp_conv(dec2) # [B, 32, 128, 128]
801
+ dec1 = self.decoder1(dec2_up, enc1) # (B, 16, 256, 256)
802
+ logits = self.out_interior(dec1)
803
+ dist = self.out_dist(dec1)
804
+
805
+ if self.debug:
806
+ return hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
807
+ else:
808
+ return logits, dist
809
+
810
+ # print(f"{dec1.shape=}, {dec2.shape=}, {dec3.shape=}, {dec4.shape=}, {logits.shape=}")
811
+
812
+
813
+
814
+ #%% head
815
+ class MLPEmbedding(nn.Module):
816
+ """
817
+ Linear Embedding
818
+ used in head
819
+ """
820
+ def __init__(self, input_dim=2048, embed_dim=768):
821
+ super().__init__()
822
+ self.proj = nn.Linear(input_dim, embed_dim)
823
+
824
+ def forward(self, x):
825
+ x = x.flatten(2).transpose(1, 2)
826
+ x = self.proj(x)
827
+ return x
828
+
829
+ class All_MLP_Head(nn.Module):
830
+ """
831
+ All MLP head in segformer
832
+ Simple and Efficient Design for Semantic Segmentation with Transformers
833
+ """
834
+ def __init__(self, in_channels=[64,128,320,512], # channel number of multi-scale features
835
+ in_index=[0,1,2,3],
836
+ feature_strides=[4,8,16,32],
837
+ dropout_ratio=0.1,
838
+ num_classes=3,
839
+ embedding_dim=768,
840
+ output_hidden_states=False):
841
+ super().__init__()
842
+ self.in_channels = in_channels
843
+ assert len(feature_strides) == len(self.in_channels)
844
+ assert min(feature_strides) == feature_strides[0]
845
+ self.in_index = in_index
846
+ self.feature_strides = feature_strides
847
+ self.dropout_ratio = dropout_ratio
848
+ self.num_classes = num_classes
849
+ self.output_hidden_states = output_hidden_states
850
+
851
+ c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
852
+
853
+ # unify channel number to 768
854
+ self.linear_c4 = MLPEmbedding(input_dim=c4_in_channels, embed_dim=embedding_dim)
855
+ self.linear_c3 = MLPEmbedding(input_dim=c3_in_channels, embed_dim=embedding_dim)
856
+ self.linear_c2 = MLPEmbedding(input_dim=c2_in_channels, embed_dim=embedding_dim)
857
+ self.linear_c1 = MLPEmbedding(input_dim=c1_in_channels, embed_dim=embedding_dim)
858
+
859
+ self.linear_fuse = nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1,bias=False)
860
+ self.batch_norm = nn.BatchNorm2d(embedding_dim) # 4: number of blocks
861
+ self.activation = nn.ReLU()
862
+ if dropout_ratio>0:
863
+ self.dropout = nn.Dropout2d(self.dropout_ratio)
864
+ self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
865
+
866
+ def forward(self, inputs):
867
+ # x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
868
+ c1, c2, c3, c4 = inputs
869
+
870
+ ############## MLP decoder on C1-C4 ###########
871
+ n, _, h, w = c4.shape
872
+ # normalize channel number and resample to 1/4 HxW
873
+ _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
874
+ _c4 = nn.functional.interpolate(_c4, size=c1.size()[2:], mode='bilinear',align_corners=False)
875
+
876
+ _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
877
+ _c3 = nn.functional.interpolate(_c3, size=c1.size()[2:], mode='bilinear',align_corners=False)
878
+
879
+ _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
880
+ _c2 = nn.functional.interpolate(_c2, size=c1.size()[2:], mode='bilinear',align_corners=False)
881
+
882
+ _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
883
+
884
+ # concatenate features
885
+ hidden_states = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
886
+ hidden_states = self.batch_norm(hidden_states)
887
+ hidden_states = self.activation(hidden_states)
888
+ hidden_states = self.dropout(hidden_states)
889
+ # predict results
890
+ x = self.linear_pred(hidden_states)
891
+ if self.output_hidden_states:
892
+ return x, hidden_states
893
+ else:
894
+ return x
895
+
896
+
897
+
898
+ #%% test different networks
899
+ # img_size = 256
900
+ # in_chans = 3
901
+ # B = 2
902
+ # input_img = torch.randn((B,in_chans,img_size,img_size))
903
+
904
+ # b3 = mit_b3_demo(img_size=img_size)
905
+ # b3_out = b3(input_img)
906
+ # for feature in b3_out:
907
+ # print(f"{feature.shape=}")
908
+ # head = All_MLP_Head()
909
+ # outputs = head(b3_out)
910
+ # print(f"{outputs.shape = }")
911
+
912
+
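The commented-out snippets above indicate how the networks are meant to be exercised; a minimal sketch along the same lines, assuming this repository is on the import path and running on CPU without the trained checkpoint:

# minimal sketch: the dual-head output of MiT_B2_UNet_MultiHead on a random input
import torch
from models.unicell_modules import MiT_B2_UNet_MultiHead

model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256)
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)   # N, C, H, W
    logits, dist = model(x)           # interior/boundary logits and distance-map head
print(logits.shape, dist.shape)       # expected: (1, 3, 256, 256) and (1, 1, 256, 256)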
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ numpy
+ scipy
+ numba
+ einops
+ imagecodecs
+ matplotlib
+ monai
+ pandas
+ pillow
+ scikit-image
+ torch
+ torchvision
utils/__init__.py ADDED
@@ -0,0 +1,7 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Apr 7 10:53:23 2022
+
+ @author: jma
+ """
utils/multi_task_sliding_window_inference.py ADDED
@@ -0,0 +1,187 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Fri Apr 1 19:18:58 2022
+
+ @author: jma
+ """
+
+ from typing import Any, Callable, List, Sequence, Tuple, Union
+ import torch
+ import torch.nn.functional as F
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
+ from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option
+
+
+ __all__ = ["multi_task_sliding_window_inference"]
+
+ def multi_task_sliding_window_inference(
+     inputs: torch.Tensor,
+     roi_size: Union[Sequence[int], int],
+     sw_batch_size: int,
+     predictor: Callable[..., torch.Tensor],
+     overlap = 0.25,
+     mode = "constant",
+     sigma_scale = 0.125,
+     padding_mode = "constant",
+     cval = 0.0,
+     sw_device = None,
+     device = None,
+     *args: Any,
+     **kwargs: Any,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Sliding window inference on `inputs` with `predictor`.
+
+     When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
+     To maintain the same spatial sizes, the output image will be cropped to the original input size.
+
+     Args:
+         inputs: input image to be processed (assuming NCHW[D])
+         roi_size: the spatial window size for inferences.
+             When its components have None or non-positives, the corresponding inputs dimension will be used.
+             If the components of the `roi_size` are non-positive values, the transform will use the
+             corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
+             to `(32, 64)` if the second spatial dimension size of img is `64`.
+         sw_batch_size: the batch size to run window slices.
+         predictor: given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)`
+             should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D];
+             where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`.
+             In this multi-task variant the predictor must return a pair (class logits, distance map).
+         overlap: Amount of overlap between scans.
+         mode: {``"constant"``, ``"gaussian"``}
+             How to blend output of overlapping windows. Defaults to ``"constant"``.
+
+             - ``"constant"``: gives equal weight to all predictions.
+             - ``"gaussian"``: gives less weight to predictions on edges of windows.
+
+         sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
+             Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
+             When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
+             spatial dimensions.
+         padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
+             Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
+             See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+         cval: fill value for 'constant' padding mode. Default: 0
+         sw_device: device for the window data.
+             By default the device (and accordingly the memory) of the `inputs` is used.
+             Normally `sw_device` should be consistent with the device where `predictor` is defined.
+         device: device for the stitched output prediction.
+             By default the device (and accordingly the memory) of the `inputs` is used. If for example
+             set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
+             `inputs` and `roi_size`. Output is on the `device`.
+         args: optional args to be passed to ``predictor``.
+         kwargs: optional keyword args to be passed to ``predictor``.
+
+     Note:
+         - input must be channel-first and have a batch dim, supports N-D sliding window.
+
+     """
+     num_spatial_dims = len(inputs.shape) - 2
+     if overlap < 0 or overlap >= 1:
+         raise AssertionError("overlap must be >= 0 and < 1.")
+
+     # determine image spatial size and batch size
+     # Note: all input images must have the same image size and batch size
+     image_size_ = list(inputs.shape[2:])
+     batch_size = inputs.shape[0]
+
+     if device is None:
+         device = inputs.device
+     if sw_device is None:
+         sw_device = inputs.device
+
+     roi_size = fall_back_tuple(roi_size, image_size_)
+     # in case that image size is smaller than roi size
+     image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
+     pad_size = []
+     for k in range(len(inputs.shape) - 1, 1, -1):
+         diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+         half = diff // 2
+         pad_size.extend([half, diff - half])
+     inputs = F.pad(inputs, pad=pad_size, mode=mode, value=cval)
+
+     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
+
+     # Store all slices in list
+     slices = dense_patch_slices(image_size, roi_size, scan_interval)
+     num_win = len(slices)  # number of windows per image
+     total_slices = num_win * batch_size  # total number of windows
+
+     # Create window-level importance map
+     importance_map = compute_importance_map(
+         get_valid_patch_size(image_size, roi_size), mode="gaussian", sigma_scale=sigma_scale, device=device
+     )
+
+     # Perform predictions
+     output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
+     output_dist = torch.tensor(0.0, device=device)
+     _initialized = False
+     for slice_g in range(0, total_slices, sw_batch_size):
+         slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
+         unravel_slice = [
+             [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
+             for idx in slice_range
+         ]
+         window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
+         seg_logit, seg_dist = predictor(window_data)  # .to(device) # batched patch segmentation
+         seg_logit = torch.nn.functional.interpolate(seg_logit, size=roi_size, mode="bilinear", align_corners=False)
+         seg_logit = torch.softmax(seg_logit, dim=1)
+         seg_dist = torch.nn.functional.interpolate(seg_dist, size=roi_size, mode="bilinear", align_corners=False)
+         seg_dist = torch.sigmoid(seg_dist)
+
+         if not _initialized:  # init. buffer at the first iteration
+             output_classes = seg_logit.shape[1]
+             dist_class = seg_dist.shape[1]
+             output_shape = [batch_size, output_classes] + list(image_size)
+             output_dist_shape = [batch_size, dist_class] + list(image_size)
+             # allocate memory to store the full output and the count for overlapping parts
+             output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
+             output_dist = torch.zeros(output_dist_shape, dtype=torch.float32, device=device)
+             count_map = torch.zeros(output_shape, dtype=torch.float32, device=device)
+             count_dist_map = torch.zeros(output_dist_shape, dtype=torch.float32, device=device)
+             _initialized = True
+
+         # store the result in the proper location of the full output. Apply weights from importance map.
+         for idx, original_idx in zip(slice_range, unravel_slice):
+             output_image[original_idx] += importance_map * seg_logit[idx - slice_g]
+             output_dist[original_idx] += importance_map * seg_dist[idx - slice_g]
+             count_map[original_idx] += importance_map
+             count_dist_map[original_idx] += importance_map
+
+     # account for any overlapping sections
+     output_image = output_image / count_map
+     output_dist = output_dist / count_dist_map
+
+     final_slicing: List[slice] = []
+     for sp in range(num_spatial_dims):
+         slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
+         final_slicing.insert(0, slice_dim)
+     while len(final_slicing) < len(output_image.shape):
+         final_slicing.insert(0, slice(None))
+     return output_image[final_slicing], output_dist[final_slicing]
+
+
+
+ def _get_scan_interval(
+     image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
+ ) -> Tuple[int, ...]:
+     """
+     Compute scan interval according to the image size, roi size and overlap.
+     Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
+     use 1 instead to make sure sliding window works.
+
+     """
+     if len(image_size) != num_spatial_dims:
+         raise ValueError("image coord different from spatial dims.")
+     if len(roi_size) != num_spatial_dims:
+         raise ValueError("roi coord different from spatial dims.")
+
+     scan_interval = []
+     for i in range(num_spatial_dims):
+         if roi_size[i] == image_size[i]:
+             scan_interval.append(int(roi_size[i]))
+         else:
+             interval = int(roi_size[i] * (1 - overlap))
+             scan_interval.append(interval if interval > 0 else 1)
+     return tuple(scan_interval)
+
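A minimal sketch of the two-head sliding-window call, using a dummy predictor in place of the UniCell network (torch and MONAI are assumed installed; shapes are illustrative):

# minimal sketch of the multi-task sliding-window API with a dummy two-head predictor
import torch
import torch.nn as nn
from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference

class DummyTwoHead(nn.Module):
    def forward(self, x):
        n, _, h, w = x.shape
        # per-window class logits (3 channels) and distance map (1 channel)
        return torch.randn(n, 3, h, w), torch.randn(n, 1, h, w)

img = torch.randn(1, 3, 600, 520)  # larger than the 256x256 window, so several tiles are blended
seg, dist = multi_task_sliding_window_inference(inputs=img, roi_size=(256, 256),
                                                sw_batch_size=4, predictor=DummyTwoHead())
print(seg.shape, dist.shape)  # (1, 3, 600, 520) and (1, 1, 600, 520)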
utils/postprocess.py ADDED
@@ -0,0 +1,125 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Apr 7 10:51:48 2022
+
+ @author: jma
+ """
+
+ import numpy as np
+ from skimage import segmentation, measure, exposure, morphology
+ import scipy.ndimage as nd
+ from tqdm import tqdm
+ import skimage
+ import colorsys
+
+ def fill_holes(label_img, size=10, connectivity=1):
+     output_image = np.copy(label_img)
+     props = measure.regionprops(np.squeeze(label_img.astype('int')), cache=False)
+     for prop in props:
+         if prop.euler_number < 1:
+
+             patch = output_image[prop.slice]
+
+             filled = morphology.remove_small_holes(
+                 ar=(patch == prop.label),
+                 area_threshold=size,
+                 connectivity=connectivity)
+
+             output_image[prop.slice] = np.where(filled, prop.label, patch)
+
+     return output_image
+
+ def watershed_post(distmaps, interiors, dist_thre=0.1, interior_thre=0.2):
+     """
+     Parameters
+     ----------
+     distmaps : float (N, H, W), N is the number of images
+         distance transform map of the cells/nuclei, in [0, 1].
+     interiors : float (N, H, W)
+         interior map of the cells/nuclei, in [0, 1].
+
+     Returns
+     -------
+     label_images : uint (N, H, W)
+         cell/nucleus instance segmentation.
+
+     """
+
+     label_images = []
+     for maxima, interior in zip(distmaps, interiors):  # in interiors[0:num]:
+         interior = nd.gaussian_filter(interior.astype(np.float32), 2)
+         # find markers based on the distance map
+         if skimage.__version__ > '0.18.2':
+             markers = measure.label(morphology.h_maxima(image=maxima, h=dist_thre, footprint=morphology.disk(2)))
+         else:
+             markers = measure.label(morphology.h_maxima(image=maxima, h=dist_thre, selem=morphology.disk(2)))
+         # print('distmap marker num:', np.max(markers), 'interior marker num:', np.max(makers_interior))
+
+         label_image = segmentation.watershed(-1 * interior, markers,
+                                              mask=interior > interior_thre,  # 0.2/0.3
+                                              watershed_line=0)
+
+         label_image = morphology.remove_small_objects(label_image, min_size=15)
+         # fill in holes that lie completely within a segmentation label
+         label_image = fill_holes(label_image, size=15)
+
+         # Relabel the label image
+         label_image, _, _ = segmentation.relabel_sequential(label_image)
+         label_images.append(label_image)
+     label_images = np.stack(label_images, axis=0).astype(np.uint)
+     return label_images
+
+
+
+ def hsv_to_rgb(arr):
+     hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb)
+     h, s, v = np.rollaxis(arr, axis=-1)
+     r, g, b = hsv_to_rgb_channels(h, s, v)
+     rgb = np.stack((r,g,b), axis=-1)
+     return rgb
+
+ def mask_overlay(img, masks):
+     """ overlay masks on image (set image to grayscale)
+     Adapted from https://github.com/MouseLand/cellpose/blob/06df602fbe074be02db3d716e280f0990816c726/cellpose/plot.py#L172
+     Parameters
+     ----------------
+
+     img: int or float, 2D or 3D array
+         img is of size [Ly x Lx (x nchan)]
+
+     masks: int, 2D array
+         masks where 0=NO masks; 1,2,...=mask labels
+
+     Returns
+     ----------------
+
+     RGB: uint8, 3D array
+         array of masks overlaid on grayscale image
+
+     """
+
+     if img.ndim>2:
+         img = img.astype(np.float32).mean(axis=-1)
+     else:
+         img = img.astype(np.float32)
+
+     HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
+     HSV[:,:,2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1)
+     hues = np.linspace(0, 1, masks.max()+1)[np.random.permutation(masks.max())]
+     for n in range(int(masks.max())):
+         ipix = (masks==n+1).nonzero()
+         HSV[ipix[0],ipix[1],0] = hues[n]
+
+         HSV[ipix[0],ipix[1],1] = 1.0
+     RGB = (hsv_to_rgb(HSV) * 255).astype(np.uint8)
+     return RGB
+
+
+
+
+
+
+
+
+
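A minimal sketch of the watershed post-processing on synthetic distance/interior maps (illustrative inputs only; in the app these maps come from the network's two output heads):

# minimal sketch: watershed_post and mask_overlay on a synthetic two-cell image
import numpy as np
import scipy.ndimage as nd
from utils.postprocess import watershed_post, mask_overlay

interior = np.zeros((128, 128), dtype=np.float32)
interior[20:60, 20:60] = 1.0     # fake cell 1
interior[70:110, 70:110] = 1.0   # fake cell 2
dist = nd.distance_transform_edt(interior)
dist = dist / dist.max()         # normalized distance map in [0, 1]

labels = watershed_post(dist[None], interior[None])    # both inputs are (N, H, W)
print(labels.shape, labels.max())                       # (1, 128, 128), 2 instances expected
overlay = mask_overlay((interior * 255).astype(np.uint8), labels[0])
print(overlay.shape, overlay.dtype)                     # (128, 128, 3) uint8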