File size: 2,646 Bytes
be2715b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch import Tensor
import numpy as np
import glob
import pandas as pd

def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Return the Dice coefficient between ``input`` and ``target``.

    If ``input`` is 2-D, or ``reduce_batch_first`` is True, the whole tensor is
    flattened and a single score ``(2 * <A, B> + eps) / (sum(A) + sum(B) + eps)``
    is computed. Otherwise the function recurses over the leading dimension and
    returns the mean of the per-element scores.

    Args:
        input: predicted mask (values presumably in [0, 1] -- not checked here).
        target: reference mask; must have exactly the same size as ``input``.
        reduce_batch_first: pool every element of the leading dimension into one
            score instead of averaging per-element scores.
        epsilon: smoothing term added to numerator and denominator. It is
            forwarded to the per-element recursion so a custom value is honoured.

    Returns:
        A 0-dim tensor with the score (or the mean over the leading dimension).

    Raises:
        ValueError: if ``reduce_batch_first`` is requested for a 2-D tensor.
    """
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        # When sets_sum is 0 (e.g. both masks empty), substitute 2 * inter so the
        # ratio is eps / eps. torch.where avoids the device sync that
        # ``sets_sum.item()`` would force.
        sets_sum = torch.where(sets_sum == 0, 2 * inter, sets_sum)
        return (2 * inter + epsilon) / (sets_sum + epsilon)

    # Compute and average the metric for each element of the leading dimension,
    # passing epsilon through so it is not reset to the default.
    dice = 0
    for i in range(input.shape[0]):
        dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)
    return dice / input.shape[0]


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Return the Dice coefficient averaged with equal weight over dimension 1.

    A 3-D tensor is scored directly with ``dice_coeff`` (no per-channel split).
    For other ranks, each slice ``[:, c, ...]`` is scored with ``dice_coeff``
    using the same ``reduce_batch_first`` and ``epsilon``, and the scores are
    averaged.
    """
    assert input.size() == target.size()
    if input.dim() == 3:
        return dice_coeff(input, target, reduce_batch_first, epsilon)

    num_channels = input.shape[1]
    per_channel_scores = (
        dice_coeff(input[:, c, ...], target[:, c, ...], reduce_batch_first, epsilon)
        for c in range(num_channels)
    )
    return sum(per_channel_scores) / num_channels


def iou_2d(outputs: torch.Tensor, labels: torch.Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Return the (soft) intersection-over-union of ``outputs`` and ``labels``.

    If ``outputs`` is 2-D, or ``reduce_batch_first`` is True, the whole tensor is
    flattened and a single score ``(<A, B> + eps) / (sum(A) + sum(B) - <A, B> + eps)``
    is computed. Otherwise the function recurses over the leading dimension and
    returns the mean of the per-element scores.

    Unlike ``dice_coeff``, a 2-D input with ``reduce_batch_first=True`` is not
    rejected; it is simply scored as a whole.

    Args:
        outputs: predicted mask (values presumably in [0, 1] -- not checked here).
        labels: reference mask; must have exactly the same size as ``outputs``.
        reduce_batch_first: pool every element of the leading dimension into one
            score instead of averaging per-element scores.
        epsilon: smoothing term added to numerator and denominator. It is
            forwarded to the per-element recursion so a custom value is honoured.

    Returns:
        A 0-dim tensor with the score (or the mean over the leading dimension).
    """
    assert outputs.size() == labels.size()
    if outputs.dim() == 2 or reduce_batch_first:
        inter = torch.dot(outputs.reshape(-1), labels.reshape(-1))
        union = outputs.sum() + labels.sum() - inter
        return (inter + epsilon) / (union + epsilon)

    # Average over the leading dimension, keeping the caller's epsilon.
    iou = 0
    for idx in range(outputs.size(0)):
        iou += iou_2d(outputs[idx], labels[idx], epsilon=epsilon)
    return iou / outputs.size(0)

def multiclass_iou(outputs: torch.Tensor, labels: torch.Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Return the IoU averaged with equal weight over dimension 1.

    A 3-D tensor is scored directly with ``iou_2d`` (no per-channel split).
    For other ranks, each slice ``[:, c, ...]`` is scored with ``iou_2d`` and the
    scores are averaged.

    Args:
        outputs: predicted masks; must have exactly the same size as ``labels``.
        labels: reference masks.
        reduce_batch_first: forwarded to ``iou_2d``.
        epsilon: smoothing term forwarded to ``iou_2d``. It mirrors the
            parameter of ``multiclass_dice_coeff``; the default 1e-6 matches
            ``iou_2d``'s own default, so existing callers see no change.

    Returns:
        The mean IoU as returned by ``iou_2d`` (a 0-dim tensor).
    """
    assert outputs.size() == labels.size()
    if outputs.dim() == 3:
        return iou_2d(outputs, labels, reduce_batch_first, epsilon)
    iou = 0
    for cidx in range(outputs.size(1)):
        iou += iou_2d(outputs[:, cidx, ...], labels[:, cidx, ...], reduce_batch_first, epsilon)
    return iou / outputs.size(1)

def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Return ``1 - Dice`` as a quantity to minimise.

    The coefficient is always computed with ``reduce_batch_first=True``, i.e.
    pooled over the leading dimension. With ``multiclass`` set, the
    channel-averaged ``multiclass_dice_coeff`` is used; otherwise ``dice_coeff``.
    The result lies in [0, 1] when the masks are presumably in [0, 1].
    """
    assert input.size() == target.size()
    if multiclass:
        coefficient = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coefficient = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coefficient