import torch.nn as nn

from typing import Optional, Union, List

from ...encoders.create import create_encoder
from ...base import (
    SegmentationModel,
    SegmentationHead,
    ClassificationHead,
)
from .decoder import UnetDecoder


class Unet(SegmentationModel):
"""Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
for fusing decoder blocks with skip connections.
    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
            to extract features of different spatial resolutions
        encoder_params: Dictionary of keyword arguments forwarded to ``create_encoder``. Supported keys include
            **pretrained** (load pretrained weights, default **True**) and **depth** (number of encoder stages,
            in range [3, 5], default 5). Each stage generates features two times smaller in spatial dimensions
            than the previous one (e.g. for depth 0 we will have features with shapes [(N, C, H, W),],
            for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on)
        decoder_use_batchnorm: If **True**, a BatchNorm2d layer is used between the Conv2d and Activation layers.
            If **"inplace"**, InplaceABN is used instead, which decreases memory consumption.
            Available options are **True, False, "inplace"**
        decoder_channels: List of integers which specify the **in_channels** parameter for convolutions used in
            the decoder. Length of the list should be the same as the encoder depth
        decoder_attention_type: Attention module used in the decoder of the model. Available options are
            **None** and **"scse"** (https://arxiv.org/abs/1808.08127)
        deep_supervision: If **True**, auxiliary segmentation heads are attached to the two intermediate
            decoder outputs so that deep-supervision losses can be computed. Default is **False**
        dropout: Dropout factor in [0, 1) applied inside the segmentation head(s). Default is 0.2
        in_channels: Number of input channels for the model, default is 3 (RGB images)
        classes: Number of classes for the output mask (i.e. the number of channels of the output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        upsampling: Final upsampling factor applied in the segmentation head. Default is 1
        aux_params: Dictionary with parameters of the auxiliary output (classification head). The auxiliary output
            is built on top of the encoder if **aux_params** is not **None** (default is **None**); see the
            Example below. Supported params:

            - classes (int): Number of classes
            - pooling (str): One of "max", "avg". Default is "avg"
            - dropout (float): Dropout factor in [0, 1)
            - activation (str): An activation function to apply, "sigmoid"/"softmax"
              (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: Unet
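
    Example:
        A minimal usage sketch (assuming a "resnet34" encoder is registered with ``create_encoder``;
        any available encoder name can be used instead)::

            model = Unet(
                encoder_name="resnet34",  # illustrative; any encoder supported by create_encoder works
                encoder_params={"pretrained": True, "depth": 5},
                in_channels=3,
                classes=2,
                aux_params={"classes": 2, "pooling": "avg", "dropout": 0.2, "activation": None},
            )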

    .. _Unet:
        https://arxiv.org/abs/1505.04597

    """

    def __init__(
        self,
        encoder_name: str,
        encoder_params: dict = {"pretrained": True, "depth": 5},
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        deep_supervision: bool = False,
        dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        upsampling: int = 1,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        # Copy before popping "depth" so neither the caller's dict nor the shared
        # default dict is mutated.
        encoder_params = dict(encoder_params)
        encoder_depth = encoder_params.pop("depth", 5)

        self.encoder = create_encoder(
            name=encoder_name,
            encoder_params=encoder_params,
            in_channels=in_channels,
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=encoder_name.startswith("vgg"),
            deep_supervision=deep_supervision,
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            dropout=dropout,
            kernel_size=3,
            upsampling=upsampling,
        )

        self.deep_supervision = deep_supervision
        if self.deep_supervision:
            # One auxiliary head per intermediate decoder output; nn.ModuleList registers
            # the heads as sub-modules while still allowing them to be indexed individually.
            self.supervisor_heads = nn.ModuleList(
                [
                    SegmentationHead(
                        in_channels=head_in_channels,
                        out_channels=classes,
                        dropout=dropout,
                        kernel_size=3,
                        upsampling=upsampling,
                    )
                    for head_in_channels in (decoder_channels[-2], decoder_channels[-3])
                ]
            )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "u-{}".format(encoder_name)
        self.initialize()