"""
Source url: https://github.com/Karel911/TRACER
Author: Min Seok Lee and Wooseok Shin
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
Changes:
- Refactored code
- Removed unused code
- Added comments
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from torch import Tensor

from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
from carvekit.ml.arch.tracerb7.att_modules import (
    RFB_Block,
    aggregation,
    ObjectAttention,
)


class TracerDecoder(nn.Module):
    """TRACER decoder: fuses multi-scale encoder features into a saliency mask."""

    def __init__(
        self,
        encoder: EfficientEncoderB7,
        features_channels: Optional[List[int]] = None,
        rfb_channel: Optional[List[int]] = None,
    ):
        """
        Initialize the TRACER decoder.

        Args:
            encoder: The encoder to use.
            features_channels: Channels of the backbone features at each stage.
                Default: [48, 80, 224, 640].
            rfb_channel: Output channels of the RFB blocks. Default: [32, 64, 128].
        """
        super().__init__()
        if rfb_channel is None:
            rfb_channel = [32, 64, 128]
        if features_channels is None:
            features_channels = [48, 80, 224, 640]
        self.encoder = encoder
        self.features_channels = features_channels

        # Receptive Field Blocks: reduce the three deepest backbone stages
        # to the RFB channel widths before aggregation.
        features_channels = rfb_channel
        self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0])
        self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1])
        self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2])

        # Multi-level aggregation
        self.agg = aggregation(features_channels)

        # Object Attention
        self.ObjectAttention2 = ObjectAttention(
            channel=self.features_channels[1], kernel_size=3
        )
        self.ObjectAttention1 = ObjectAttention(
            channel=self.features_channels[0], kernel_size=3
        )

    def forward(self, inputs: torch.Tensor) -> Tensor:
        """
        Forward pass of the TRACER decoder.

        Args:
            inputs: Preprocessed images.

        Returns:
            Tensor of per-pixel segmentation probabilities (after sigmoid).
        """
        features = self.encoder(inputs)

        # Receptive Field Blocks on the three deepest feature maps.
        x3_rfb = self.rfb2(features[1])
        x4_rfb = self.rfb3(features[2])
        x5_rfb = self.rfb4(features[3])

        # Multi-level aggregation produces the coarse saliency map D_0.
        D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
        ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")

        # First object-attention refinement at the resolution of features[1].
        D_1 = self.ObjectAttention2(D_0, features[1])
        ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")

        # Second refinement at the resolution of features[0].
        ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
        D_2 = self.ObjectAttention1(ds_map, features[0])
        ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")

        # Average the three upsampled maps and squash to probabilities.
        final_map = (ds_map2 + ds_map1 + ds_map0) / 3

        return torch.sigmoid(final_map)
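

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): wires the decoder to the EfficientNet-B7
# encoder and runs a dummy forward pass. Assumptions: EfficientEncoderB7 can
# be built with its default constructor, pretrained weights are loaded
# elsewhere, and a 640x640 input is used (the resolution TRACER-B7 is
# commonly run at). The expected output is a single-channel probability map
# at the input resolution.
if __name__ == "__main__":
    encoder = EfficientEncoderB7()  # assumption: default constructor, no args
    decoder = TracerDecoder(encoder=encoder)  # defaults: [48, 80, 224, 640] / [32, 64, 128]
    decoder.eval()

    with torch.no_grad():
        dummy_batch = torch.rand(1, 3, 640, 640)  # placeholder preprocessed RGB batch
        mask = decoder(dummy_batch)

    print(mask.shape)  # expected: torch.Size([1, 1, 640, 640])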