DuyTa's picture
d2f9d5b1696d723607b469880b3d5616ee5b225a5296662d3b494a9f93762c27
864c14f verified
raw
history blame
1.86 kB
import torch
from torch import nn
class Loss_VAE(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss(reduction='sum')
def forward(self, recon_x, x, mu, log_var):
mse = self.mse(recon_x, x)
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
loss = mse + kld
return loss
def DiceScore(
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
) -> torch.Tensor:
"""Computes Dice score metric from full size Tensor and collects average.
Args:
y_pred: input data to compute, typical segmentation model output.
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
should be binarized.
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
The values should be binarized.
include_background: whether to skip Dice computation on the first channel of
the predicted output. Defaults to True.
Returns:
Dice scores per batch and per class, (shape [batch_size, num_classes]).
Raises:
ValueError: when `y_pred` and `y` have different shapes.
"""
y = y.float()
y_pred = y_pred.float()
if y.shape != y_pred.shape:
raise ValueError("y_pred and y should have same shapes.")
# reducing only spatial dimensions (not batch nor channels)
n_len = len(y_pred.shape)
reduce_axis = list(range(2, n_len))
intersection = torch.sum(y * y_pred, dim=reduce_axis)
y_o = torch.sum(y, reduce_axis)
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
denominator = y_o + y_pred_o
return torch.where(
denominator > 0,
(2.0 * intersection) / denominator,
torch.tensor(float("1"), device=y_o.device),
)