|
import math |
|
from typing import Dict, List, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class BundledLoss(nn.Module):
    """Aggregate per-modality, multi-view and volume-mask losses into one total.

    The module owns three loss callables and a handful of scalar settings.
    ``forward`` evaluates them and returns a flat dictionary holding the
    combined ``"total_loss"`` together with every individual term that was
    logged along the way.
    """

    def __init__(
        self,
        single_modality_loss,
        multi_view_consistency_loss,
        volume_mask_loss,
        multi_view_consistency_weight: float,
        mvc_time_dependent: bool,
        mvc_steepness: float,
        modality: List,
        consistency_weight: float,
        consistency_source: str,
    ):
        """Store the loss callables and weighting configuration.

        Args:
            single_modality_loss: Called as ``fn(output[m], label, mask)`` for
                each modality ``m``; must return a mapping that contains a
                ``"total_loss"`` entry.
            multi_view_consistency_loss: Called as
                ``fn(output, label, spixel, raw_image, mask)``; must return a
                mapping with ``"total_loss"`` (and ``"tgt_map"`` when the
                ensemble consistency term is enabled).
            volume_mask_loss: Called as ``fn(out_vol, tgt_map)``; must return
                a mapping with a ``"loss"`` entry.
            multi_view_consistency_weight: Base weight of the multi-view term.
            mvc_time_dependent: When true, the multi-view weight is scaled by
                the epoch-dependent ramp.
            mvc_steepness: Steepness factor of the exponential ramp.
            modality: Keys of ``output`` to iterate over.
            consistency_weight: Weight of the volume-mask consistency term;
                ``0.0`` disables it.
            consistency_source: The consistency term is only evaluated when
                this equals ``"ensemble"``.
        """
        super().__init__()

        # Loss callables.
        self.single_modality_loss = single_modality_loss
        self.multi_view_consistency_loss = multi_view_consistency_loss
        self.volume_mask_loss = volume_mask_loss

        # Multi-view consistency weighting.
        self.mvc_weight = multi_view_consistency_weight
        self.mvc_time_dependent = mvc_time_dependent
        self.mvc_steepness = mvc_steepness

        # Modalities and volume-mask consistency settings.
        self.modality = modality
        self.consistency_weight = consistency_weight
        self.consistency_source = consistency_source

    @staticmethod
    def _ramp(steepness, epoch, max_epoch):
        """Return ``exp(-steepness * (1 - epoch / max_epoch) ** 2)``."""
        remaining = 1 - epoch / max_epoch
        return math.exp(-steepness * remaining ** 2)

    def forward(
        self,
        output: Dict,
        label,
        mask,
        epoch: int = 1,
        max_epoch: int = 70,
        spixel=None,
        raw_image=None,
    ):
        """Evaluate all loss terms and return them with their weighted sum.

        Args:
            output: Mapping indexed by the configured modalities; each entry
                is handed to the single-modality loss, and its ``"out_vol"``
                item is used by the volume-mask consistency term.
            label: Forwarded unchanged to the single-modality and multi-view
                losses.
            mask: Forwarded unchanged to the single-modality and multi-view
                losses.
            epoch: Current epoch, used by the exponential ramp.
            max_epoch: Epoch count the ramp is normalised by.
            spixel: Forwarded unchanged to the multi-view loss.
            raw_image: Forwarded unchanged to the multi-view loss.

        Returns:
            A dict whose first key is ``"total_loss"``, followed by
            ``"<term>/<modality>"`` entries from the single-modality loss, the
            multi-view entries other than ``"total_loss"`` / ``"tgt_map"``,
            and, when enabled, ``"consistency_loss/<modality>"`` entries.
        """
        running_total = 0.0
        logged = {}

        # Per-modality supervised terms, logged under "<term>/<modality>".
        for name in self.modality:
            per_modality = self.single_modality_loss(output[name], label, mask)
            logged.update(
                {f"{key}/{name}": value for key, value in per_modality.items()}
            )
            running_total = running_total + per_modality["total_loss"]

        # Weight of the multi-view term, optionally ramped with the epoch.
        mvc_scale = self.mvc_weight
        if self.mvc_time_dependent:
            mvc_scale = mvc_scale * self._ramp(self.mvc_steepness, epoch, max_epoch)

        mvc_terms = self.multi_view_consistency_loss(
            output, label, spixel, raw_image, mask
        )
        not_logged = ("total_loss", "tgt_map")
        for key, value in mvc_terms.items():
            if key in not_logged:
                continue
            logged[key] = value

        # Volume-mask consistency against the multi-view target map. This term
        # is always scaled by the ramp, independent of ``mvc_time_dependent``.
        ensemble_enabled = (
            self.consistency_weight != 0.0 and self.consistency_source == "ensemble"
        )
        if ensemble_enabled:
            for name in self.modality:
                volume_terms = self.volume_mask_loss(
                    output[name]["out_vol"], mvc_terms["tgt_map"]
                )
                volume_loss = volume_terms["loss"]
                logged[f"consistency_loss/{name}"] = volume_loss
                weighted = (
                    self.consistency_weight
                    * volume_loss
                    * self._ramp(self.mvc_steepness, epoch, max_epoch)
                )
                running_total = running_total + weighted

        running_total = running_total + mvc_scale * mvc_terms["total_loss"]

        result = {"total_loss": running_total}
        result.update(logged)
        return result
|
|