"""Contains the UNet decoder. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import List |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from sharp.models.blocks import (
    NormLayerName,
    norm_layer_2d,
    residual_block_2d,
)

from .base_decoder import BaseDecoder


class UNetDecoder(BaseDecoder):
    """Decoder of UNet model."""

    def __init__(
        self,
        dim_out: int,
        width: list[int] | int,
        steps: int = 5,
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        blocks_per_layer: int = 2,
    ) -> None:
        """Initialize the UNet decoder.

        Args:
            dim_out: The number of output channels.
            width: Width of the last input feature map from the encoder,
                or the list of widths of all input feature maps from the encoder.
            steps: The number of upsampling steps.
            norm_type: Which kind of normalization layer to use.
            norm_num_groups: How many groups to use for group norm (if relevant).
            blocks_per_layer: How many residual blocks to use per layer.
        """
        super().__init__()

        if blocks_per_layer < 1:
            raise ValueError("blocks_per_layer must be greater or equal to one.")

        self.dim_out = dim_out
        self.convs_up = nn.ModuleList()
        self.output_dims: list[int]

        # Encoder widths, ordered from the deepest feature map to the shallowest.
        if isinstance(width, int):
            # A single width is taken to halve at every shallower encoder level.
            self.input_dims = [width >> i for i in range(0, steps + 1)]
        else:
            self.input_dims = width[::-1][: steps + 1]

        for i_step in range(steps):
            input_width = self.input_dims[i_step]
            current_width = self.input_dims[i_step + 1]
            # After the first step, the block input is the upsampled output
            # concatenated with the encoder skip connection, which doubles its width.
            convs_up_i = nn.Sequential(
                nn.Upsample(scale_factor=2),
                residual_block_2d(
                    input_width * (1 if i_step == 0 else 2),
                    current_width,
                    norm_type=norm_type,
                    norm_num_groups=norm_num_groups,
                ),
                *[
                    residual_block_2d(
                        current_width,
                        current_width,
                        norm_type=norm_type,
                        norm_num_groups=norm_num_groups,
                    )
                    for _ in range(blocks_per_layer - 1)
                ],
            )
            self.convs_up.append(convs_up_i)

        last_width = self.input_dims[-1]
        self.conv_out = nn.Sequential(
            norm_layer_2d(last_width * 2, norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
            nn.Conv2d(last_width * 2, dim_out, 1),
            norm_layer_2d(dim_out, norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
        )

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Apply the UNet decoder to the encoder features.

        Args:
            features: The multi-level feature maps from the encoder,
                ordered from shallowest to deepest.

        Returns:
            The output feature map.
        """
        # Start from the deepest feature map and walk back towards the shallowest.
        i_feature_layer = len(features) - 1
        out = self.convs_up[0](features[i_feature_layer])
        i_feature_layer -= 1
        for conv_up in self.convs_up[1:]:
            # Concatenate the upsampled output with the matching skip connection.
            out = conv_up(torch.cat([out, features[i_feature_layer]], dim=1))
            i_feature_layer -= 1
        out = self.conv_out(torch.cat([out, features[i_feature_layer]], dim=1))

        return out
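

# A minimal usage sketch (not part of the original module): it instantiates the
# decoder from a single encoder width and runs it on dummy feature maps. The
# chosen shapes assume the encoder halves the channel width and doubles the
# spatial size at each shallower level, and that `residual_block_2d` preserves
# spatial resolution; adjust them if the actual encoder differs.
if __name__ == "__main__":
    decoder = UNetDecoder(dim_out=16, width=256, steps=3)
    # Encoder features ordered shallowest to deepest:
    # (1, 32, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8).
    features = [torch.randn(1, 256 >> i, 8 << i, 8 << i) for i in reversed(range(4))]
    print(decoder(features).shape)  # Expected: torch.Size([1, 16, 64, 64])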