import torch
import torch.nn as nn


class Loss(nn.Module):
    """Weighted sum of map/volume label, mask, consistency, and entropy losses.

    Each sub-loss is expected to return a dict containing a "loss" entry.
    """

    def __init__(
        self,
        map_label_loss,
        volume_label_loss,
        map_mask_loss,
        volume_mask_loss,
        consistency_loss,
        entropy_loss,
        map_label_weight,
        volume_label_weight,
        map_mask_weight,
        volume_mask_weight,
        consistency_weight,
        map_entropy_weight,
        volume_entropy_weight,
        consistency_source,
    ):
        super().__init__()

        # Sub-loss modules.
        self.map_label_loss = map_label_loss
        self.volume_label_loss = volume_label_loss
        self.map_mask_loss = map_mask_loss
        self.volume_mask_loss = volume_mask_loss
        self.consistency_loss = consistency_loss
        self.entropy_loss = entropy_loss

        # Scalar weights for the corresponding loss terms.
        self.map_label_weight = map_label_weight
        self.volume_label_weight = volume_label_weight
        self.map_mask_weight = map_mask_weight
        self.volume_mask_weight = volume_mask_weight
        self.consistency_weight = consistency_weight
        self.map_entropy_weight = map_entropy_weight
        self.volume_entropy_weight = volume_entropy_weight
        self.consistency_source = consistency_source

    def forward(self, output, label, mask):
        total_loss = 0.0
        loss_dict = {}

        label = label.float()

        # Label losses: the map branch is always computed, the volume branch
        # only when its weight is non-zero.
        map_label_loss = self.map_label_loss(
            output["map_pred"], output["out_map"], label
        )["loss"]
        total_loss = total_loss + self.map_label_weight * map_label_loss
        loss_dict.update({"map_label_loss": map_label_loss})

        if self.volume_label_weight != 0.0:
            volume_label_loss = self.volume_label_loss(
                output["vol_pred"], output["out_vol"], label
            )["loss"]
            total_loss = total_loss + self.volume_label_weight * volume_label_loss
            loss_dict.update({"vol_label_loss": volume_label_loss})

        # Mask losses on each branch.
        map_mask_loss = self.map_mask_loss(output["out_map"], mask)["loss"]
        total_loss = total_loss + self.map_mask_weight * map_mask_loss
        loss_dict.update({"map_mask_loss": map_mask_loss})

        if self.volume_mask_weight != 0.0:
            volume_mask_loss = self.volume_mask_loss(output["out_vol"], mask)["loss"]
            total_loss = total_loss + self.volume_mask_weight * volume_mask_loss
            loss_dict.update({"vol_mask_loss": volume_mask_loss})

        # Consistency between the volume and map outputs, driven by the
        # model's own predictions.
        if self.consistency_weight != 0.0 and self.consistency_source == "self":
            consistency_loss = self.consistency_loss(
                output["out_vol"], output["out_map"], label
            )["loss"]
            total_loss = total_loss + self.consistency_weight * consistency_loss
            loss_dict.update({"consistency_loss": consistency_loss})

        # Entropy regularization on each branch.
        if self.map_entropy_weight != 0.0:
            map_entropy_loss = self.entropy_loss(output["out_map"])["loss"]
            total_loss = total_loss + self.map_entropy_weight * map_entropy_loss
            loss_dict.update({"map_entropy_loss": map_entropy_loss})

        if self.volume_entropy_weight != 0.0:
            volume_entropy_loss = self.entropy_loss(output["out_vol"])["loss"]
            total_loss = total_loss + self.volume_entropy_weight * volume_entropy_loss
            loss_dict.update({"vol_entropy_loss": volume_entropy_loss})

        loss_dict.update({"total_loss": total_loss})
        return loss_dict
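

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original training code.
    # The `_DictLoss` adapter, tensor shapes, and weights below are assumptions;
    # the only contract `Loss` relies on is that each sub-loss is callable and
    # returns a dict with a "loss" key. The consistency and entropy terms are
    # disabled here because their real signatures are not modeled by this stub.

    class _DictLoss(nn.Module):
        """Hypothetical adapter: criterion between the first and last arguments."""

        def __init__(self, criterion):
            super().__init__()
            self.criterion = criterion

        def forward(self, pred, *rest):
            target = rest[-1] if rest else pred
            return {"loss": self.criterion(pred, target)}

    bce = nn.BCEWithLogitsLoss()
    mse = nn.MSELoss()
    loss_fn = Loss(
        map_label_loss=_DictLoss(bce),
        volume_label_loss=_DictLoss(bce),
        map_mask_loss=_DictLoss(bce),
        volume_mask_loss=_DictLoss(bce),
        consistency_loss=_DictLoss(mse),  # placeholder, weight kept at 0.0
        entropy_loss=_DictLoss(mse),      # placeholder, weights kept at 0.0
        map_label_weight=1.0,
        volume_label_weight=1.0,
        map_mask_weight=1.0,
        volume_mask_weight=1.0,
        consistency_weight=0.0,
        map_entropy_weight=0.0,
        volume_entropy_weight=0.0,
        consistency_source="self",
    )

    # Dummy batch of two samples with single-channel predictions.
    output = {
        "map_pred": torch.randn(2, 1, requires_grad=True),
        "out_map": torch.rand(2, 1, requires_grad=True),
        "vol_pred": torch.randn(2, 1, requires_grad=True),
        "out_vol": torch.rand(2, 1, requires_grad=True),
    }
    label = torch.randint(0, 2, (2, 1))
    mask = torch.rand(2, 1)

    losses = loss_fn(output, label, mask)
    losses["total_loss"].backward()
    print({k: float(v) for k, v in losses.items()})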