WSCL / losses /loss.py
yhzhai's picture
release code
482ab8a
raw history blame
No virus
3.65 kB
import torch
import torch.nn as nn
class Loss(nn.Module):
def __init__(
self,
map_label_loss,
volume_label_loss,
map_mask_loss,
volume_mask_loss,
consistency_loss,
entropy_loss,
map_label_weight,
volume_label_weight,
map_mask_weight,
volume_mask_weight,
consistency_weight,
map_entropy_weight,
volume_entropy_weight,
consistency_source,
):
super().__init__()
self.map_label_loss = map_label_loss
self.volume_label_loss = volume_label_loss
self.map_mask_loss = map_mask_loss
self.volume_mask_loss = volume_mask_loss
self.consistency_loss = consistency_loss
self.entropy_loss = entropy_loss
self.map_label_weight = map_label_weight
self.volume_label_weight = volume_label_weight
self.map_mask_weight = map_mask_weight
self.volume_mask_weight = volume_mask_weight
self.consistency_weight = consistency_weight
self.map_entropy_weight = map_entropy_weight
self.volume_entropy_weight = volume_entropy_weight
self.consistency_source = consistency_source
def forward(self, output, label, mask):
total_loss = 0.0
loss_dict = {}
# --- label loss ---
label = label.float()
# compute map label loss anyway
map_label_loss = self.map_label_loss(
output["map_pred"], output["out_map"], label
)["loss"]
total_loss = total_loss + self.map_label_weight * map_label_loss
loss_dict.update({"map_label_loss": map_label_loss})
if self.volume_label_weight != 0.0:
volume_label_loss = self.volume_label_loss(
output["vol_pred"], output["out_vol"], label
)["loss"]
total_loss = total_loss + self.volume_label_weight * volume_label_loss
loss_dict.update({"vol_label_loss": volume_label_loss})
# --- mask loss ---
# compute map mask loss anyway
map_mask_loss = self.map_mask_loss(output["out_map"], mask)["loss"]
total_loss = total_loss + self.map_mask_weight * map_mask_loss
loss_dict.update({"map_mask_loss": map_mask_loss})
if self.volume_mask_weight != 0.0:
volume_mask_loss = self.volume_mask_loss(output["out_vol"], mask)["loss"]
total_loss = total_loss + self.volume_mask_weight * volume_mask_loss
loss_dict.update({"vol_mask_loss": volume_mask_loss})
# --- self-consistency loss ---
if self.consistency_weight != 0.0 and self.consistency_source == "self":
consistency_loss = self.consistency_loss(
output["out_vol"], output["out_map"], label
)
consistency_loss = consistency_loss["loss"]
total_loss = total_loss + self.consistency_weight * consistency_loss
loss_dict.update({"consistency_loss": consistency_loss})
# --- entropy loss ---
if self.map_entropy_weight != 0.0:
map_entropy_loss = self.entropy_loss(output["out_map"])["loss"]
total_loss = total_loss + self.map_entropy_weight * map_entropy_loss
loss_dict.update({"map_entropy_loss": map_entropy_loss})
if self.volume_entropy_weight != 0:
volume_entropy_loss = self.entropy_loss(output["out_vol"])["loss"]
total_loss = total_loss + self.volume_entropy_weight * volume_entropy_loss
loss_dict.update({"vol_entropy_loss": volume_entropy_loss})
loss_dict.update({"total_loss": total_loss})
return loss_dict