Lewislou committed
Commit f065752
1 Parent(s): 0764ef3

Upload 4 files

models/__init__.py ADDED
@@ -0,0 +1,10 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Sun Mar 20 14:23:55 2022
+
+ @author: jma
+ """
+
+ #from .unetr2d import UNETR2D
+ #from .swin_unetr import SwinUNETR
models/convnext.py ADDED
@@ -0,0 +1,220 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from functools import partial
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from timm.models.layers import trunc_normal_, DropPath
+ from timm.models.registry import register_model
+ from monai.networks.layers.factories import Act, Conv, Pad, Pool
+ from monai.networks.layers.utils import get_norm_layer
+ from monai.utils.module import look_up_option
+ from typing import List, NamedTuple, Optional, Tuple, Type, Union
+ class Block(nn.Module):
+     r""" ConvNeXt Block. There are two equivalent implementations:
+     (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+     (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+     We use (2) as we find it slightly faster in PyTorch
+
+     Args:
+         dim (int): Number of input channels.
+         drop_path (float): Stochastic depth rate. Default: 0.0
+         layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+     """
+     def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+         super().__init__()
+         self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+         self.norm = LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(4 * dim, dim)
+         self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+                                   requires_grad=True) if layer_scale_init_value > 0 else None
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+     def forward(self, x):
+         input = x
+         x = self.dwconv(x)
+         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+         x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+         x = input + self.drop_path(x)
+         return x
+
+ class ConvNeXt(nn.Module):
+     r""" ConvNeXt
+     A PyTorch impl of : `A ConvNet for the 2020s` -
+     https://arxiv.org/pdf/2201.03545.pdf
+
+     Args:
+         in_chans (int): Number of input image channels. Default: 3
+         num_classes (int): Number of classes for classification head. Default: 1000
+         depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+         dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+         drop_path_rate (float): Stochastic depth rate. Default: 0.
+         layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+         head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+     """
+     def __init__(self, in_chans=3, num_classes=21841,
+                  depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+                  layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3],
+                  ):
+         super().__init__()
+         # conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2]
+         # self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False)
+         # self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)
+
+         self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
+         stem = nn.Sequential(
+             nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+             LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+         )
+         self.downsample_layers.append(stem)
+         for i in range(3):
+             downsample_layer = nn.Sequential(
+                 LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                 nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+             )
+             self.downsample_layers.append(downsample_layer)
+
+         self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
+         dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+         cur = 0
+         for i in range(4):
+             stage = nn.Sequential(
+                 *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+                         layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+             )
+             self.stages.append(stage)
+             cur += depths[i]
+
+
+         self.out_indices = out_indices
+
+         norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
+         for i_layer in range(4):
+             layer = norm_layer(dims[i_layer])
+             layer_name = f'norm{i_layer}'
+             self.add_module(layer_name, layer)
+         self.apply(self._init_weights)
+
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv2d, nn.Linear)):
+             trunc_normal_(m.weight, std=.02)
+             nn.init.constant_(m.bias, 0)
+
+     def forward_features(self, x):
+         outs = []
+
+         for i in range(4):
+             x = self.downsample_layers[i](x)
+             x = self.stages[i](x)
+             if i in self.out_indices:
+                 norm_layer = getattr(self, f'norm{i}')
+                 x_out = norm_layer(x)
+
+                 outs.append(x_out)
+
+         return tuple(outs)
+
+     def forward(self, x):
+         x = self.forward_features(x)
+
+         return x
+
+ class LayerNorm(nn.Module):
+     r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+     The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+     shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+     with shape (batch_size, channels, height, width).
+     """
+     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape, )
+
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             return x
+
+
+ model_urls = {
+     "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+     "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+     "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+     "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+     "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+     "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+     "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+     "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+     "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+ }
+
+ @register_model
+ def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
+     model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+     if pretrained:
+         url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
+         checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+         model.load_state_dict(checkpoint["model"])
+     return model
+
+ @register_model
+ def convnext_small(pretrained=False, in_22k=False, **kwargs):
+     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+     if pretrained:
+         url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
+         checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+         model.load_state_dict(checkpoint["model"], strict=False)
+     return model
+
+ @register_model
+ def convnext_base(pretrained=False, in_22k=False, **kwargs):
+     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+     if pretrained:
+         url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
+         checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+         model.load_state_dict(checkpoint["model"], strict=False)
+     return model
+
+ @register_model
+ def convnext_large(pretrained=False, in_22k=False, **kwargs):
+     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+     if pretrained:
+         url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
+         checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+         model.load_state_dict(checkpoint["model"])
+     return model
+
+ @register_model
+ def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
+     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+     if pretrained:
+         assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
+         url = model_urls['convnext_xlarge_22k']
+         checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+         model.load_state_dict(checkpoint["model"])
+     return model
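Note: the ConvNeXt variants in this file are used purely as multi-scale feature extractors; `forward` returns the tuple built by `forward_features`, one normalized map per stage at strides 4, 8, 16 and 32. A minimal usage sketch (not part of the commit), assuming the repository root is on PYTHONPATH and that `torch`, `timm` and `monai` are installed; the 256x256 input size is an arbitrary choice:

import torch
from models.convnext import convnext_small

# build the backbone without downloading ImageNet weights
encoder = convnext_small(pretrained=False)

x = torch.randn(1, 3, 256, 256)   # (N, C, H, W)
feats = encoder(x)                # tuple of 4 feature maps, dims (96, 192, 384, 768)
for f in feats:
    print(f.shape)                # expected: 64x64, 32x32, 16x16, 8x8 spatial sizes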
models/flexible_unet.py ADDED
@@ -0,0 +1,312 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, Sequence, Tuple, Union
+
+ import torch
+ from torch import nn
+
+ from monai.networks.blocks import UpSample
+ from monai.networks.layers.factories import Conv
+ from monai.networks.layers.utils import get_act_layer
+ from monai.networks.nets import EfficientNetBNFeatures
+ from monai.networks.nets.basic_unet import UpCat
+ from monai.utils import InterpolateMode
+
+ __all__ = ["FlexibleUNet"]
+
+ encoder_feature_channel = {
+     "efficientnet-b0": (16, 24, 40, 112, 320),
+     "efficientnet-b1": (16, 24, 40, 112, 320),
+     "efficientnet-b2": (16, 24, 48, 120, 352),
+     "efficientnet-b3": (24, 32, 48, 136, 384),
+     "efficientnet-b4": (24, 32, 56, 160, 448),
+     "efficientnet-b5": (24, 40, 64, 176, 512),
+     "efficientnet-b6": (32, 40, 72, 200, 576),
+     "efficientnet-b7": (32, 48, 80, 224, 640),
+     "efficientnet-b8": (32, 56, 88, 248, 704),
+     "efficientnet-l2": (72, 104, 176, 480, 1376),
+ }
+
+
+ def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
+     """
+     Get the encoder output channels for a given backbone name.
+
+     Args:
+         backbone: name of the backbone used to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
+         in_channels: number of channels of the input tensor, default to 3.
+
+     Returns:
+         A tuple of the input channel count followed by the output channels of each encoder feature map.
+     """
+     encoder_channel_tuple = encoder_feature_channel[backbone]
+     encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
+     encoder_channel = tuple(encoder_channel_list)
+     return encoder_channel
+
+
+ class UNetDecoder(nn.Module):
+     """
+     UNet Decoder.
+     This class refers to `segmentation_models.pytorch
+     <https://github.com/qubvel/segmentation_models.pytorch>`_.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         encoder_channels: number of output channels for all feature maps in encoder.
+             `len(encoder_channels)` should be no less than 2.
+         decoder_channels: number of output channels for all feature maps in decoder.
+             `len(decoder_channels)` should equal `len(encoder_channels) - 1`.
+         act: activation type and arguments.
+         norm: feature normalization type and arguments.
+         dropout: dropout ratio.
+         bias: whether to have a bias term in convolution blocks in this decoder.
+         upsample: upsampling mode, available options are
+             ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
+         pre_conv: a conv block applied before upsampling.
+             Only used in the "nontrainable" or "pixelshuffle" mode.
+         interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+             Only used in the "nontrainable" mode.
+         align_corners: set the align_corners parameter for upsample. Defaults to True.
+             Only used in the "nontrainable" mode.
+         is_pad: whether to pad upsampling features to fit the encoder spatial dims.
+
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         encoder_channels: Sequence[int],
+         decoder_channels: Sequence[int],
+         act: Union[str, tuple],
+         norm: Union[str, tuple],
+         dropout: Union[float, tuple],
+         bias: bool,
+         upsample: str,
+         pre_conv: Optional[str],
+         interp_mode: str,
+         align_corners: Optional[bool],
+         is_pad: bool,
+     ):
+
+         super().__init__()
+         if len(encoder_channels) < 2:
+             raise ValueError("the length of `encoder_channels` should be no less than 2.")
+         if len(decoder_channels) != len(encoder_channels) - 1:
+             raise ValueError("`len(decoder_channels)` should equal `len(encoder_channels) - 1`.")
+
+         in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
+         skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
+         halves = [True] * (len(skip_channels) - 1)
+         halves.append(False)
+         blocks = []
+         for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
+             blocks.append(
+                 UpCat(
+                     spatial_dims=spatial_dims,
+                     in_chns=in_chn,
+                     cat_chns=skip_chn,
+                     out_chns=out_chn,
+                     act=act,
+                     norm=norm,
+                     dropout=dropout,
+                     bias=bias,
+                     upsample=upsample,
+                     pre_conv=pre_conv,
+                     interp_mode=interp_mode,
+                     align_corners=align_corners,
+                     halves=halve,
+                     is_pad=is_pad,
+                 )
+             )
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, features: List[torch.Tensor], skip_connect: int = 4):
+         skips = features[:-1][::-1]
+         features = features[1:][::-1]
+
+         x = features[0]
+         for i, block in enumerate(self.blocks):
+             if i < skip_connect:
+                 skip = skips[i]
+             else:
+                 skip = None
+             x = block(x, skip)
+
+         return x
+
+
+ class SegmentationHead(nn.Sequential):
+     """
+     Segmentation head.
+     This class refers to `segmentation_models.pytorch
+     <https://github.com/qubvel/segmentation_models.pytorch>`_.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of input channels for the block.
+         out_channels: number of output channels for the block.
+         kernel_size: kernel size for the conv layer.
+         act: activation type and arguments.
+         scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
+
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         act: Optional[Union[Tuple, str]] = None,
+         scale_factor: float = 1.0,
+     ):
+
+         conv_layer = Conv[Conv.CONV, spatial_dims](
+             in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+         )
+         up_layer: nn.Module = nn.Identity()
+         if scale_factor > 1.0:
+             up_layer = UpSample(
+                 spatial_dims=spatial_dims,
+                 scale_factor=scale_factor,
+                 mode="nontrainable",
+                 pre_conv=None,
+                 interp_mode=InterpolateMode.LINEAR,
+             )
+         if act is not None:
+             act_layer = get_act_layer(act)
+         else:
+             act_layer = nn.Identity()
+         super().__init__(conv_layer, up_layer, act_layer)
+
+
+ class FlexibleUNet(nn.Module):
+     """
+     A flexible implementation of a UNet-like encoder-decoder architecture.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         backbone: str,
+         pretrained: bool = False,
+         decoder_channels: Tuple = (256, 128, 64, 32, 16),
+         spatial_dims: int = 2,
+         norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
+         act: Union[str, tuple] = ("relu", {"inplace": True}),
+         dropout: Union[float, tuple] = 0.0,
+         decoder_bias: bool = False,
+         upsample: str = "nontrainable",
+         interp_mode: str = "nearest",
+         is_pad: bool = True,
+     ) -> None:
+         """
+         A flexible implementation of UNet, in which the backbone/encoder can be replaced with
+         any efficient network. Currently the input must have 2 or 3 spatial dimensions,
+         and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
+         parameter is False.
+
+         Args:
+             in_channels: number of input channels.
+             out_channels: number of output channels.
+             backbone: name of the backbone to initialize, only efficientnet is supported right now,
+                 can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
+             pretrained: whether to initialize pretrained ImageNet weights, only available
+                 when spatial_dims=2 and batch norm is used, default to False.
+             decoder_channels: number of output channels for all feature maps in decoder.
+                 `len(decoder_channels)` should equal `len(encoder_channels) - 1`, default
+                 to (256, 128, 64, 32, 16).
+             spatial_dims: number of spatial dimensions, default to 2.
+             norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
+                 "momentum": 0.1}).
+             act: activation type and arguments, default to ("relu", {"inplace": True}).
+             dropout: dropout ratio, default to 0.0.
+             decoder_bias: whether to have a bias term in decoder's convolution blocks.
+             upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
+                 ``"nontrainable"``.
+             interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+                 Only used in the "nontrainable" mode.
+             is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
+                 If this parameter is set to True, the spatial dims of the network input can be of
+                 arbitrary size, which is not supported by TensorRT. Otherwise, they must be multiples of 32.
+         """
+         super().__init__()
+
+         if backbone not in encoder_feature_channel:
+             raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
+
+         if spatial_dims not in (2, 3):
+             raise ValueError("spatial_dims can only be 2 or 3.")
+
+         adv_prop = "ap" in backbone
+
+         self.backbone = backbone
+         self.spatial_dims = spatial_dims
+         model_name = backbone
+         encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
+         self.encoder = EfficientNetBNFeatures(
+             model_name=model_name,
+             pretrained=pretrained,
+             in_channels=in_channels,
+             spatial_dims=spatial_dims,
+             norm=norm,
+             adv_prop=adv_prop,
+         )
+         self.decoder = UNetDecoder(
+             spatial_dims=spatial_dims,
+             encoder_channels=encoder_channels,
+             decoder_channels=decoder_channels,
+             act=act,
+             norm=norm,
+             dropout=dropout,
+             bias=decoder_bias,
+             upsample=upsample,
+             interp_mode=interp_mode,
+             pre_conv=None,
+             align_corners=None,
+             is_pad=is_pad,
+         )
+         self.dist_head = SegmentationHead(
+             spatial_dims=spatial_dims,
+             in_channels=decoder_channels[-1],
+             out_channels=32,
+             kernel_size=1,
+             act='relu',
+         )
+         self.prob_head = SegmentationHead(
+             spatial_dims=spatial_dims,
+             in_channels=decoder_channels[-1],
+             out_channels=1,
+             kernel_size=1,
+             act='sigmoid',
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         """
+         Do a typical encoder-decoder-head inference.
+
+         Args:
+             inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
+                 N is defined by `dimensions`.
+
+         Returns:
+             A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
+
+         """
+         x = inputs
+         enc_out = self.encoder(x)
+         decoder_out = self.decoder(enc_out)
+         dist = self.dist_head(decoder_out)
+         prob = self.prob_head(decoder_out)
+         return dist, prob
+
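Note: `FlexibleUNet` wires the EfficientNet encoder, the UNet decoder and the two `SegmentationHead`s together, so a single forward pass yields a 32-channel distance map and a 1-channel probability map at the input resolution. A small usage sketch (not part of the commit), assuming `monai` is installed and a 2-D input whose side lengths are multiples of 32 (512 here is an arbitrary choice):

import torch
from models.flexible_unet import FlexibleUNet

model = FlexibleUNet(
    in_channels=3,
    out_channels=3,               # not consumed by the heads; they fix 32/1 output channels
    backbone="efficientnet-b0",
    pretrained=False,
)

x = torch.randn(1, 3, 512, 512)
dist, prob = model(x)
print(dist.shape, prob.shape)     # expected: [1, 32, 512, 512] and [1, 1, 512, 512]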
models/flexible_unet_convnext.py ADDED
@@ -0,0 +1,447 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, Sequence, Tuple, Union
+
+ import torch
+ from torch import nn
+ from . import convnext
+ from monai.networks.blocks import UpSample
+ from monai.networks.layers.factories import Conv
+ from monai.networks.layers.utils import get_act_layer
+ from monai.networks.nets import EfficientNetBNFeatures
+ from monai.networks.nets.basic_unet import UpCat
+ from monai.utils import InterpolateMode
+
+ __all__ = ["FlexibleUNet_star", "FlexibleUNet_hv"]
+
+ encoder_feature_channel = {
+     "efficientnet-b0": (16, 24, 40, 112, 320),
+     "efficientnet-b1": (16, 24, 40, 112, 320),
+     "efficientnet-b2": (16, 24, 48, 120, 352),
+     "efficientnet-b3": (24, 32, 48, 136, 384),
+     "efficientnet-b4": (24, 32, 56, 160, 448),
+     "efficientnet-b5": (24, 40, 64, 176, 512),
+     "efficientnet-b6": (32, 40, 72, 200, 576),
+     "efficientnet-b7": (32, 48, 80, 224, 640),
+     "efficientnet-b8": (32, 56, 88, 248, 704),
+     "efficientnet-l2": (72, 104, 176, 480, 1376),
+     "convnext_small": (96, 192, 384, 768),
+     "convnext_base": (128, 256, 512, 1024),
+     "van_b2": (64, 128, 320, 512),
+     "van_b1": (64, 128, 320, 512),
+ }
+
+
+ def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
+     """
+     Get the encoder output channels for a given backbone name.
+
+     Args:
+         backbone: name of the backbone used to generate features, can be from the keys of `encoder_feature_channel`.
+         in_channels: number of channels of the input tensor, default to 3.
+
+     Returns:
+         A tuple of the input channel count followed by the output channels of each encoder feature map.
+     """
+     encoder_channel_tuple = encoder_feature_channel[backbone]
+     encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
+     encoder_channel = tuple(encoder_channel_list)
+     return encoder_channel
+
+
+ class UNetDecoder(nn.Module):
+     """
+     UNet Decoder.
+     This class refers to `segmentation_models.pytorch
+     <https://github.com/qubvel/segmentation_models.pytorch>`_.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         encoder_channels: number of output channels for all feature maps in encoder.
+             `len(encoder_channels)` should be no less than 2.
+         decoder_channels: number of output channels for all feature maps in decoder.
+             `len(decoder_channels)` should equal `len(encoder_channels) - 1`.
+         act: activation type and arguments.
+         norm: feature normalization type and arguments.
+         dropout: dropout ratio.
+         bias: whether to have a bias term in convolution blocks in this decoder.
+         upsample: upsampling mode, available options are
+             ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
+         pre_conv: a conv block applied before upsampling.
+             Only used in the "nontrainable" or "pixelshuffle" mode.
+         interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+             Only used in the "nontrainable" mode.
+         align_corners: set the align_corners parameter for upsample. Defaults to True.
+             Only used in the "nontrainable" mode.
+         is_pad: whether to pad upsampling features to fit the encoder spatial dims.
+
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         encoder_channels: Sequence[int],
+         decoder_channels: Sequence[int],
+         act: Union[str, tuple],
+         norm: Union[str, tuple],
+         dropout: Union[float, tuple],
+         bias: bool,
+         upsample: str,
+         pre_conv: Optional[str],
+         interp_mode: str,
+         align_corners: Optional[bool],
+         is_pad: bool,
+     ):
+
+         super().__init__()
+         if len(encoder_channels) < 2:
+             raise ValueError("the length of `encoder_channels` should be no less than 2.")
+         if len(decoder_channels) != len(encoder_channels) - 1:
+             raise ValueError("`len(decoder_channels)` should equal `len(encoder_channels) - 1`.")
+
+         in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
+         skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
+         halves = [True] * (len(skip_channels) - 1)
+         halves.append(False)
+         blocks = []
+         for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
+             blocks.append(
+                 UpCat(
+                     spatial_dims=spatial_dims,
+                     in_chns=in_chn,
+                     cat_chns=skip_chn,
+                     out_chns=out_chn,
+                     act=act,
+                     norm=norm,
+                     dropout=dropout,
+                     bias=bias,
+                     upsample=upsample,
+                     pre_conv=pre_conv,
+                     interp_mode=interp_mode,
+                     align_corners=align_corners,
+                     halves=halve,
+                     is_pad=is_pad,
+                 )
+             )
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, features: List[torch.Tensor], skip_connect: int = 3):
+         skips = features[:-1][::-1]
+         features = features[1:][::-1]
+
+         x = features[0]
+         for i, block in enumerate(self.blocks):
+             if i < skip_connect:
+                 skip = skips[i]
+             else:
+                 skip = None
+             x = block(x, skip)
+
+         return x
+
+
+ class SegmentationHead(nn.Sequential):
+     """
+     Segmentation head.
+     This class refers to `segmentation_models.pytorch
+     <https://github.com/qubvel/segmentation_models.pytorch>`_.
+
+     Args:
+         spatial_dims: number of spatial dimensions.
+         in_channels: number of input channels for the block.
+         out_channels: number of output channels for the block.
+         kernel_size: kernel size for the conv layer.
+         act: activation type and arguments.
+         scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
+
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         act: Optional[Union[Tuple, str]] = None,
+         scale_factor: float = 1.0,
+     ):
+
+         conv_layer = Conv[Conv.CONV, spatial_dims](
+             in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+         )
+         up_layer: nn.Module = nn.Identity()
+         # if scale_factor > 1.0:
+         #     up_layer = UpSample(
+         #         in_channels=out_channels,
+         #         spatial_dims=spatial_dims,
+         #         scale_factor=scale_factor,
+         #         mode="deconv",
+         #         pre_conv=None,
+         #         interp_mode=InterpolateMode.LINEAR,
+         #     )
+         if scale_factor > 1.0:
+             up_layer = UpSample(
+                 spatial_dims=spatial_dims,
+                 scale_factor=scale_factor,
+                 mode="nontrainable",
+                 pre_conv=None,
+                 interp_mode=InterpolateMode.LINEAR,
+             )
+         if act is not None:
+             act_layer = get_act_layer(act)
+         else:
+             act_layer = nn.Identity()
+         super().__init__(conv_layer, up_layer, act_layer)
+
+
+ class FlexibleUNet_star(nn.Module):
+     """
+     A flexible implementation of a UNet-like encoder-decoder architecture.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         backbone: str,
+         pretrained: bool = False,
+         decoder_channels: Tuple = (256, 128, 64, 32),
+         #decoder_channels: Tuple = (1024, 512, 256, 128),
+         spatial_dims: int = 2,
+         norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
+         act: Union[str, tuple] = ("relu", {"inplace": True}),
+         dropout: Union[float, tuple] = 0.0,
+         decoder_bias: bool = False,
+         upsample: str = "nontrainable",
+         interp_mode: str = "nearest",
+         is_pad: bool = True,
+         n_rays: int = 32,
+         prob_out_channels: int = 1,
+     ) -> None:
+         """
+         A flexible implementation of UNet, in which the backbone/encoder can be replaced with
+         any efficient network. Currently the input must have 2 or 3 spatial dimensions,
+         and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
+         parameter is False.
+
+         Args:
+             in_channels: number of input channels.
+             out_channels: number of output channels.
+             backbone: name of the backbone used to look up encoder channels, can be from the keys
+                 of `encoder_feature_channel`; the encoder itself is hard-coded to `convnext.convnext_small`.
+             pretrained: whether to initialize pretrained ImageNet weights, only available
+                 when spatial_dims=2 and batch norm is used, default to False.
+             decoder_channels: number of output channels for all feature maps in decoder.
+                 `len(decoder_channels)` should equal `len(encoder_channels) - 1`, default
+                 to (256, 128, 64, 32).
+             spatial_dims: number of spatial dimensions, default to 2.
+             norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
+                 "momentum": 0.1}).
+             act: activation type and arguments, default to ("relu", {"inplace": True}).
+             dropout: dropout ratio, default to 0.0.
+             decoder_bias: whether to have a bias term in decoder's convolution blocks.
+             upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
+                 ``"nontrainable"``.
+             interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+                 Only used in the "nontrainable" mode.
+             is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
+                 If this parameter is set to True, the spatial dims of the network input can be of
+                 arbitrary size, which is not supported by TensorRT. Otherwise, they must be multiples of 32.
+             n_rays: number of output channels of the distance head, default to 32.
+             prob_out_channels: number of output channels of the probability head, default to 1.
+         """
+         super().__init__()
+
+         if backbone not in encoder_feature_channel:
+             raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
+
+         if spatial_dims not in (2, 3):
+             raise ValueError("spatial_dims can only be 2 or 3.")
+
+         adv_prop = "ap" in backbone
+
+         self.backbone = backbone
+         self.spatial_dims = spatial_dims
+         model_name = backbone
+         encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
+
+         self.encoder = convnext.convnext_small(pretrained=False, in_22k=True)
+
+         self.decoder = UNetDecoder(
+             spatial_dims=spatial_dims,
+             encoder_channels=encoder_channels,
+             decoder_channels=decoder_channels,
+             act=act,
+             norm=norm,
+             dropout=dropout,
+             bias=decoder_bias,
+             upsample=upsample,
+             interp_mode=interp_mode,
+             pre_conv=None,
+             align_corners=None,
+             is_pad=is_pad,
+         )
+         self.dist_head = SegmentationHead(
+             spatial_dims=spatial_dims,
+             in_channels=decoder_channels[-1],
+             out_channels=n_rays,
+             kernel_size=1,
+             act='relu',
+             scale_factor=2,
+         )
+         self.prob_head = SegmentationHead(
+             spatial_dims=spatial_dims,
+             in_channels=decoder_channels[-1],
+             out_channels=prob_out_channels,
+             kernel_size=1,
+             act='sigmoid',
+             scale_factor=2,
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         """
+         Do a typical encoder-decoder-head inference.
+
+         Args:
+             inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
+                 N is defined by `dimensions`.
+
+         Returns:
+             A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
+
+         """
+         x = inputs
+         enc_out = self.encoder(x)
+         decoder_out = self.decoder(enc_out)
+
+         dist = self.dist_head(decoder_out)
+         prob = self.prob_head(decoder_out)
+
+         return dist, prob
+
+
+
+ class FlexibleUNet_hv(nn.Module):
+     """
+     A flexible implementation of a UNet-like encoder-decoder architecture.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         backbone: str,
+         pretrained: bool = False,
+         decoder_channels: Tuple = (1024, 512, 256, 128),
+         spatial_dims: int = 2,
+         norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
+         act: Union[str, tuple] = ("relu", {"inplace": True}),
+         dropout: Union[float, tuple] = 0.0,
+         decoder_bias: bool = False,
+         upsample: str = "nontrainable",
+         interp_mode: str = "nearest",
+         is_pad: bool = True,
+         n_rays: int = 32,
+         prob_out_channels: int = 1,
+     ) -> None:
+         """
+         A flexible implementation of UNet, in which the backbone/encoder can be replaced with
+         any efficient network. Currently the input must have 2 or 3 spatial dimensions,
+         and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
+         parameter is False.
+
+         Args:
+             in_channels: number of input channels.
+             out_channels: number of output channels.
+             backbone: name of the backbone used to look up encoder channels, can be from the keys
+                 of `encoder_feature_channel`; the encoder itself is hard-coded to `convnext.convnext_small`.
+             pretrained: whether to initialize pretrained ImageNet weights, only available
+                 when spatial_dims=2 and batch norm is used, default to False.
+             decoder_channels: number of output channels for all feature maps in decoder.
+                 `len(decoder_channels)` should equal `len(encoder_channels) - 1`, default
+                 to (1024, 512, 256, 128).
+             spatial_dims: number of spatial dimensions, default to 2.
+             norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
+                 "momentum": 0.1}).
+             act: activation type and arguments, default to ("relu", {"inplace": True}).
+             dropout: dropout ratio, default to 0.0.
+             decoder_bias: whether to have a bias term in decoder's convolution blocks.
+             upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
+                 ``"nontrainable"``.
+             interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+                 Only used in the "nontrainable" mode.
+             is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
+                 If this parameter is set to True, the spatial dims of the network input can be of
+                 arbitrary size, which is not supported by TensorRT. Otherwise, they must be multiples of 32.
+             n_rays: number of output channels of the distance head, default to 32.
+             prob_out_channels: number of output channels of the probability head, default to 1.
+         """
+         super().__init__()
+
+         if backbone not in encoder_feature_channel:
+             raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
+
+         if spatial_dims not in (2, 3):
+             raise ValueError("spatial_dims can only be 2 or 3.")
+
+         adv_prop = "ap" in backbone
+
+         self.backbone = backbone
+         self.spatial_dims = spatial_dims
+         model_name = backbone
+         encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
+         self.encoder = convnext.convnext_small(pretrained=False, in_22k=True)
+         self.decoder = UNetDecoder(
+             spatial_dims=spatial_dims,
+             encoder_channels=encoder_channels,
+             decoder_channels=decoder_channels,
+             act=act,
+             norm=norm,
+             dropout=dropout,
+             bias=decoder_bias,
+             upsample=upsample,
+             interp_mode=interp_mode,
+             pre_conv=None,
+             align_corners=None,
+             is_pad=is_pad,
+         )
+         self.dist_head = SegmentationHead(
+             spatial_dims=spatial_dims,
+             in_channels=decoder_channels[-1],
+             out_channels=n_rays,
+             kernel_size=1,
+             act=None,
+             scale_factor=2,
+         )
+         self.prob_head = SegmentationHead(
+             spatial_dims=spatial_dims,
+             in_channels=decoder_channels[-1],
+             out_channels=prob_out_channels,
+             kernel_size=1,
+             act='sigmoid',
+             scale_factor=2,
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         """
+         Do a typical encoder-decoder-head inference.
+
+         Args:
+             inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
+                 N is defined by `dimensions`.
+
+         Returns:
+             A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
+
+         """
+         x = inputs
+         enc_out = self.encoder(x)
+         decoder_out = self.decoder(enc_out)
+         dist = self.dist_head(decoder_out)
+         prob = self.prob_head(decoder_out)
+         return dist, prob
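Note: in these ConvNeXt-backed variants the encoder only produces features at strides 4 to 32, so the decoder has four stages and stops at stride 2; both heads therefore use `scale_factor=2` to return to the input resolution, with `FlexibleUNet_star` producing an `n_rays`-channel distance map and a `prob_out_channels`-channel probability map. A usage sketch (not part of the commit), assuming `torch`, `timm` and `monai` are installed and the repository root is on PYTHONPATH; the 256x256 input size is an arbitrary choice:

import torch
from models.flexible_unet_convnext import FlexibleUNet_star

model = FlexibleUNet_star(
    in_channels=3,
    out_channels=3,            # not consumed by the heads; n_rays/prob_out_channels set the outputs
    backbone="convnext_small", # used only to look up encoder channels
    n_rays=32,
)

x = torch.randn(1, 3, 256, 256)
dist, prob = model(x)
print(dist.shape, prob.shape)  # expected: [1, 32, 256, 256] and [1, 1, 256, 256]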