|
from fastai.layers import * |
|
from .layers import * |
|
from fastai.torch_core import * |
|
from fastai.callbacks.hooks import * |
|
from fastai.vision import * |
|
|
|
|
|
|
|
|
|
# Names exported by `from <this module> import *`: only the two U-Net builders.
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
|
|
|
|
|
def _get_sfs_idxs(sizes: Sizes) -> List[int]: |
|
"Get the indexes of the layers where the size of the activation changes." |
|
feature_szs = [size[-1] for size in sizes] |
|
sfs_idxs = list( |
|
np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0] |
|
) |
|
if feature_szs[0] != feature_szs[1]: |
|
sfs_idxs = [0] + sfs_idxs |
|
return sfs_idxs |
|
|
|
|
|
class CustomPixelShuffle_ICNR(nn.Module):
    """Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle` and `icnr` init.

    The forward pass is: 1x1 conv -> activation -> pixel shuffle -> optional blur.

    Args:
        ni: Number of input channels.
        nf: Number of output channels after the shuffle; defaults to `ni`.
        scale: Upsampling factor given to `nn.PixelShuffle`; the conv produces
            ``nf * scale ** 2`` channels to feed it.
        blur: When true, the shuffled output is padded and average-pooled.
            NOTE: this flag used to be ignored (the blur was always applied
            because the pooling module itself was used as the condition).
        leaky: Forwarded to `relu` to build the activation.
        **kwargs: Forwarded to `custom_conv_layer`.
    """

    def __init__(
        self,
        ni: int,
        nf: int = None,
        scale: int = 2,
        blur: bool = False,
        leaky: float = None,
        **kwargs
    ):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = custom_conv_layer(
            ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
        )
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Pad one pixel on the left/top so the 2x2, stride-1 average pool
        # returns a map of the same spatial size as its input.
        # Both modules are always created so the module layout does not depend
        # on `blur`; whether they are applied is decided by `do_blur`.
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)
        # An nn.Module is always truthy, so the requested flag has to be kept
        # separately from the `blur` module.
        self.do_blur = bool(blur)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.do_blur else x
|
|
|
|
|
class UnetBlockDeep(nn.Module):
    """A quasi-UNet block that upsamples with `CustomPixelShuffle_ICNR` and then applies two conv layers.

    The upsampled input is concatenated (channel-wise) with the batch-normed
    activation stored on `hook`, passed through the activation, then through
    `conv1` and `conv2`.
    """

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        nf_factor: float = 1.0,
        **kwargs
    ):
        super().__init__()
        self.hook = hook

        # The upsampling path halves the incoming channel count.
        half_up_c = up_in_c // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, half_up_c, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)

        # Channels after concatenating the upsampled map with the skip map.
        merged_c = half_up_c + x_in_c
        base_c = merged_c if final_div else merged_c // 2
        out_c = int(base_c * nf_factor)

        self.conv1 = custom_conv_layer(merged_c, out_c, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            out_c, out_c, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        # Resize when the upsampled map does not match the skip connection.
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        activated = self.relu(merged)
        return self.conv2(self.conv1(activated))
|
|
|
|
|
class DynamicUnetDeep(SequentialEx):
    """Create a U-Net from a given architecture, using `UnetBlockDeep` decoder blocks.

    The encoder is probed through `model_sizes` / `dummy_eval` with
    ``imsize = (256, 256)`` to find the layers after which the activation size
    changes. A hook is placed on each such layer and one `UnetBlockDeep` is
    stacked per hook, starting from the deepest one.

    Args:
        encoder: Indexable module (``encoder[i]``) used as the downsampling path.
        n_classes: Number of output channels of the final 1x1 conv layer.
        blur: Enable blur in the upsampling of the decoder blocks.
        blur_final: When false, blur is disabled for the last decoder block
            even if `blur` is true. (This flag used to be computed but never
            forwarded; it is now honoured.)
        self_attention: Enable self-attention in the block at position
            ``len(sfs_idxs) - 3``.
        y_range: If given, a `SigmoidRange(*y_range)` is appended.
        last_cross: If true, append `MergeLayer(dense=True)` followed by a
            `res_block` over the merged channels.
        bottle: Forwarded to `res_block`.
        norm_type: Forwarded to the conv layers; `NormType.Spectral` also turns
            on `extra_bn`.
        nf_factor: Forwarded to each `UnetBlockDeep`.
        **kwargs: Forwarded to the conv layers and decoder blocks.
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: float = 1.0,
        **kwargs
    ):
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        # Deepest size change first, so decoder blocks are built bottom-up.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        last_block = len(sfs_idxs) - 1
        for i, idx in enumerate(sfs_idxs):
            not_final = i != last_block
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # Blur every block, except the final one when `blur_final` is off.
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs
            ).eval()
            layers.append(unet_block)
            # Run the dummy activation through the block so the next block can
            # read its input channel count from `x`.
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` may be missing if construction failed before the hooks were set.
        if hasattr(self, "sfs"):
            self.sfs.remove()
|
|
|
|
|
|
|
class UnetBlockWide(nn.Module):
    """A quasi-UNet block that upsamples with `CustomPixelShuffle_ICNR` and then applies a single conv layer.

    Both the upsampling path and the block output use ``n_out // 2`` channels.
    `final_div` is accepted but not used by this block.
    """

    def __init__(
        self,
        up_in_c: int,
        x_in_c: int,
        n_out: int,
        hook: Hook,
        final_div: bool = True,
        blur: bool = False,
        leaky: float = None,
        self_attention: bool = False,
        **kwargs
    ):
        super().__init__()
        self.hook = hook

        half_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, half_out, blur=blur, leaky=leaky, **kwargs
        )
        self.bn = batchnorm_2d(x_in_c)

        # Channels after concatenating the upsampled map with the skip map.
        merged_c = half_out + x_in_c
        self.conv = custom_conv_layer(
            merged_c, half_out, leaky=leaky, self_attention=self_attention, **kwargs
        )
        self.relu = relu(leaky=leaky)

    def forward(self, up_in: Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        # Resize when the upsampled map does not match the skip connection.
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv(self.relu(merged))
|
|
|
|
|
class DynamicUnetWide(SequentialEx):
    """Create a U-Net from a given architecture, using `UnetBlockWide` decoder blocks.

    The encoder is probed through `model_sizes` / `dummy_eval` with
    ``imsize = (256, 256)`` to find the layers after which the activation size
    changes. A hook is placed on each such layer and one `UnetBlockWide` is
    stacked per hook, starting from the deepest one. Decoder width is driven
    by ``nf = 512 * nf_factor``; the final block gets ``nf // 2``.

    Args:
        encoder: Indexable module (``encoder[i]``) used as the downsampling path.
        n_classes: Number of output channels of the final 1x1 conv layer.
        blur: Enable blur in the upsampling of the decoder blocks.
        blur_final: When false, blur is disabled for the last decoder block
            even if `blur` is true. (This flag used to be computed but never
            forwarded; it is now honoured.)
        self_attention: Enable self-attention in the block at position
            ``len(sfs_idxs) - 3``.
        y_range: If given, a `SigmoidRange(*y_range)` is appended.
        last_cross: If true, append `MergeLayer(dense=True)` followed by a
            `res_block` over the merged channels.
        bottle: Forwarded to `res_block`.
        norm_type: Forwarded to the conv layers; `NormType.Spectral` also turns
            on `extra_bn`.
        nf_factor: Multiplier on the base decoder width of 512 channels.
        **kwargs: Forwarded to the conv layers and decoder blocks.
    """

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):
        # Channel counts must be integers even if a fractional factor is given.
        nf = int(512 * nf_factor)
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        # Deepest size change first, so decoder blocks are built bottom-up.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        last_block = len(sfs_idxs) - 1
        for i, idx in enumerate(sfs_idxs):
            not_final = i != last_block
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # Blur every block, except the final one when `blur_final` is off.
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)

            # The final block is half as wide as the others.
            n_out = nf if not_final else nf // 2

            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            # Run the dummy activation through the block so the next block can
            # read its input channel count from `x`.
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # `sfs` may be missing if construction failed before the hooks were set.
        if hasattr(self, "sfs"):
            self.sfs.remove()
|
|