WSCL / losses /__init__.py
yhzhai's picture
release code
482ab8a
from .bundled_loss import BundledLoss
from .consisitency_loss import get_consistency_loss, get_volume_seg_map
from .entropy_loss import get_entropy_loss
from .loss import Loss
from .map_label_loss import get_map_label_loss
from .map_mask_loss import get_map_mask_loss
from .multi_view_consistency_loss import (
get_multi_view_consistency_loss,
get_spixel_tgt_map,
)
from .volume_label_loss import get_volume_label_loss
from .volume_mask_loss import get_volume_mask_loss
def get_bundled_loss(opt):
    """Build the overall training loss, multi-view consistency included.

    The single-model loss (see ``get_loss``), the multi-view consistency
    loss and the volume mask loss are each created from ``opt`` and then
    handed to ``BundledLoss`` together with the ``mvc_*``, ``modality`` and
    ``consistency_*`` settings read from ``opt``.
    """
    # Built in the same order in which BundledLoss receives them positionally.
    components = (
        get_loss(opt),
        get_multi_view_consistency_loss(opt),
        get_volume_mask_loss(opt),
    )
    settings = (
        opt.mvc_weight,
        opt.mvc_time_dependent,
        opt.mvc_steepness,
        opt.modality,
        opt.consistency_weight,
        opt.consistency_source,
    )
    return BundledLoss(*components, *settings)
def get_loss(opt):
    """Build the loss for a single model, without multi-view consistency.

    Six loss terms (map/volume label, map/volume mask, consistency and
    entropy) are created from ``opt`` and passed to ``Loss`` followed by
    their weights and the consistency source, all read from ``opt``.
    """
    # Loss terms, in the positional order expected by Loss.
    terms = (
        get_map_label_loss(opt),
        get_volume_label_loss(opt),
        get_map_mask_loss(opt),
        get_volume_mask_loss(opt),
        get_consistency_loss(opt),
        get_entropy_loss(opt),
    )
    # Weights follow the terms; the entropy term has separate map and
    # volume weights.
    weights = (
        opt.map_label_weight,
        opt.volume_label_weight,
        opt.map_mask_weight,
        opt.volume_mask_weight,
        opt.consistency_weight,
        opt.map_entropy_weight,
        opt.volume_entropy_weight,
    )
    return Loss(*terms, *weights, opt.consistency_source)