File size: 4,492 Bytes
e34aada |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import torch
from torch import nn
from typing import Optional
from modules.img2plane.deeplabv3.encoders import get_encoder
from modules.img2plane.deeplabv3.base import initialization as init
from .my_decoder import DeepLabV3Decoder
class DeepLabV3(nn.Module):
    """DeepLabV3 encoder-decoder network, adapted from "Rethinking Atrous Convolution
    for Semantic Image Segmentation" (https://arxiv.org/abs/1706.05587).

    Unlike the reference segmentation model, this variant has no segmentation head,
    no output activation and no auxiliary classification head: ``forward`` returns
    the raw output of ``DeepLabV3Decoder`` applied to the encoder features.

    Args:
        encoder_name: Backbone name passed to ``get_encoder``. Default is "resnet34".
        encoder_depth: Number of encoder stages, passed to ``get_encoder`` as
            ``depth``. Default is 5.
        encoder_weights: Pretrained-weights identifier passed to ``get_encoder`` as
            ``weights`` (e.g. "imagenet"; presumably ``None`` means random
            initialization -- confirm against the encoder factory).
            Default is "imagenet".
        decoder_channels: Number of output channels of ``DeepLabV3Decoder``.
            Default is 256.
        in_channels: Number of input channels. Default is 5 (3 RGB channels plus
            2 pixel-coordinate channels).

    The encoder is requested with ``output_stride=8``; ``forward`` rejects inputs
    whose height or width is not divisible by ``self.encoder.output_stride``.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: int = 256,
        in_channels: int = 5,  # 3 for rgb, and 2 for pixel coordinates
    ):
        super().__init__()
        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            output_stride=8,
        )
        # The decoder is sized from the channel count of the deepest encoder stage.
        self.decoder = DeepLabV3Decoder(
            in_channels=self.encoder.out_channels[-1],
            out_channels=decoder_channels,
        )

    def initialize(self):
        """Apply ``init.initialize_decoder`` to the decoder.

        NOTE(review): ``__init__`` does not call this method, so the decoder keeps
        whatever initialization its own constructor gives it unless a caller invokes
        ``initialize()`` explicitly -- confirm that this is intended.
        """
        init.initialize_decoder(self.decoder)

    def check_input_shape(self, x):
        """Raise ``RuntimeError`` unless the spatial size of ``x`` fits the encoder.

        Only the last two dimensions of ``x`` (height, width) are inspected; both
        must be divisible by ``self.encoder.output_stride``. The error message
        suggests the smallest padded shape that would be accepted.
        """
        h, w = x.shape[-2:]
        output_stride = self.encoder.output_stride
        if h % output_stride == 0 and w % output_stride == 0:
            return
        # Round each side up to the next multiple of the stride (ceil division);
        # sides that are already multiples are left unchanged.
        new_h = -(-h // output_stride) * output_stride
        new_w = -(-w // output_stride) * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
        )

    def forward(self, x):
        """Validate the input size, run the encoder, and feed its features to the decoder.

        Args:
            x: Input tensor; presumably (batch, in_channels, height, width) -- only
                the last two dimensions are checked here.

        Returns:
            The decoder output as-is (no head or activation is applied).

        Raises:
            RuntimeError: if height or width is not divisible by the encoder's
                output stride.
        """
        self.check_input_shape(x)
        features = self.encoder(x)
        # Every encoder feature map is passed positionally to the decoder.
        return self.decoder(*features)

    @torch.no_grad()
    def predict(self, x):
        """Inference helper: switch to ``eval`` mode and call ``forward(x)`` under
        ``torch.no_grad()``.

        The model is left in ``eval`` mode afterwards; call ``.train()`` to resume
        training.

        Args:
            x: Input tensor, as accepted by ``forward``.

        Returns:
            The output of ``forward(x)`` (the raw decoder output), without gradient
            tracking.
        """
        if self.training:
            self.eval()
        return self.forward(x)
|