TTP / opencd /models /change_detectors /dual_input_encoder_decoder.py
KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
919 Bytes
# Copyright (c) Open-CD. All rights reserved.
from typing import List, Optional
import torch
from torch import Tensor
from opencd.registry import MODELS
from .siamencoder_decoder import SiamEncoderDecoder
@MODELS.register_module()
class DIEncoderDecoder(SiamEncoderDecoder):
    """Dual Input Encoder Decoder segmentors.

    DIEncoderDecoder typically consists of backbone, decode_head and
    auxiliary_head. The auxiliary_head is only used for deep supervision
    during training and can be dropped at inference time.

    In this class the backbone is called once with both images as separate
    positional arguments (see ``extract_feat``).
    """

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract features from the paired input images.

        Args:
            inputs (Tensor): Tensor holding both images stacked along
                dim 1. It is split along that dimension into chunks of
                size ``self.backbone_inchannels`` and must yield exactly
                two chunks.
                NOTE(review): presumably shaped (N, 2 * C, H, W) -- confirm
                against the data preprocessor.

        Returns:
            List[Tensor]: The backbone output, additionally passed through
            ``self.neck`` when ``self.with_neck`` is truthy.
        """
        # The chunk size is read from `self.backbone_inchannels` instead of
        # the backbone object, because some backbone classes do not expose
        # an `in_channels` attribute.
        image_pair = torch.split(inputs, self.backbone_inchannels, dim=1)
        first_img, second_img = image_pair

        # Both images go into the backbone in a single call.
        feats = self.backbone(first_img, second_img)

        return self.neck(feats) if self.with_neck else feats