File size: 1,925 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from .bundled_loss import BundledLoss
from .consisitency_loss import get_consistency_loss, get_volume_seg_map
from .entropy_loss import get_entropy_loss
from .loss import Loss
from .map_label_loss import get_map_label_loss
from .map_mask_loss import get_map_mask_loss
from .multi_view_consistency_loss import (
    get_multi_view_consistency_loss,
    get_spixel_tgt_map,
)
from .volume_label_loss import get_volume_label_loss
from .volume_mask_loss import get_volume_mask_loss


def get_bundled_loss(opt):
    """Build the loss used for overall training.

    Wraps the single-model loss returned by ``get_loss`` together with the
    multi-view consistency loss and a volume mask loss in a ``BundledLoss``.

    Args:
        opt: Options object. It is forwarded to every loss factory. The
            attributes ``mvc_weight``, ``mvc_time_dependent``,
            ``mvc_steepness``, ``modality``, ``consistency_weight`` and
            ``consistency_source`` are also read and passed to
            ``BundledLoss``.

    Returns:
        The constructed ``BundledLoss`` instance.
    """
    # Arguments are evaluated left to right, so the sub-loss factories run
    # in the same order BundledLoss receives them positionally.
    return BundledLoss(
        get_loss(opt),
        get_multi_view_consistency_loss(opt),
        # NOTE(review): get_loss(opt) above also builds its own volume mask
        # loss, so a second instance is created here -- confirm intended.
        get_volume_mask_loss(opt),
        opt.mvc_weight,
        opt.mvc_time_dependent,
        opt.mvc_steepness,
        opt.modality,
        opt.consistency_weight,
        opt.consistency_source,
    )


def get_loss(opt):
    """Build the loss for a single model, without the multi-view term.

    Each component loss is built from ``opt``, and the components are passed
    to ``Loss`` together with their weights and the consistency source.

    Args:
        opt: Options object. It is forwarded to every component-loss
            factory. Its ``*_weight`` attributes and ``consistency_source``
            are passed to ``Loss``.

    Returns:
        The constructed ``Loss`` instance.
    """
    # Loss takes its component losses positionally, in exactly this order.
    components = (
        get_map_label_loss(opt),
        get_volume_label_loss(opt),
        get_map_mask_loss(opt),
        get_volume_mask_loss(opt),
        get_consistency_loss(opt),
        get_entropy_loss(opt),
    )
    # The entropy loss is a single component but has two weights: map
    # entropy and volume entropy.
    weights = (
        opt.map_label_weight,
        opt.volume_label_weight,
        opt.map_mask_weight,
        opt.volume_mask_weight,
        opt.consistency_weight,
        opt.map_entropy_weight,
        opt.volume_entropy_weight,
    )
    return Loss(*components, *weights, opt.consistency_source)