TTP / opencd /models /decode_heads /general_scd_head.py
KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame contribute delete
737 Bytes
# Copyright (c) Open-CD. All rights reserved.
from opencd.registry import MODELS
from .multi_head import MultiHeadDecoder
@MODELS.register_module()
class GeneralSCDHead(MultiHeadDecoder):
    """Decode head for general semantic change detection (SCD).

    Produces two per-input semantic predictions plus one binary change
    prediction, using the sub-modules set up by ``MultiHeadDecoder``.
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to the parent class.
        super().__init__(**kwargs)

    def forward(self, inputs):
        """Run the semantic branches and the binary change branch.

        Args:
            inputs: A two-element sequence ``(feats_from, feats_to)``.
                Presumably these are the features of the two temporal
                images (TODO confirm against the module that feeds this
                head).

        Returns:
            dict: ``seg_logits`` holds the output of ``binary_cd_head``
            applied to the neck-fused inputs. ``seg_logits_from`` holds
            ``semantic_cd_head`` applied to the first input.
            ``seg_logits_to`` holds ``semantic_cd_head_aux`` applied to
            the second input.
        """
        feats_from, feats_to = inputs

        # Semantic predictions, one per input. The two inputs are routed
        # through two distinct attributes; whether they share weights is
        # decided by the parent class, which is not visible here.
        logits_from = self.semantic_cd_head(feats_from)
        logits_to = self.semantic_cd_head_aux(feats_to)

        # The binary change branch first fuses both inputs in the neck.
        fused = self.binary_cd_neck(feats_from, feats_to)
        change_logits = self.binary_cd_head(fused)

        return {
            'seg_logits': change_logits,
            'seg_logits_from': logits_from,
            'seg_logits_to': logits_to,
        }