File size: 1,048 Bytes
32278ee
e6f4cd4
32278ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6f4cd4
 
 
 
 
 
 
32278ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from itertools import chain

def IOU(pred, label):
    """Return the intersection-over-union of two binary masks.

    The masks are added together; in the summed mask an element equal to 2
    is counted as intersection, and an element equal to 1 is counted as
    belonging to exactly one of the two masks.

    Args:
        pred: predicted mask. Assumed to be a 2-D array-like of 0/1 values
            that supports element-wise ``+`` (e.g. a numpy array) -- TODO
            confirm against callers.
        label: ground-truth mask, same assumptions as ``pred``.

    Returns:
        ``intersection / (intersection + exclusive)``, or ``0`` when no
        element of the summed mask equals 1 or 2.
    """
    combined = pred + label
    # Only one nesting level is flattened: rows -> individual elements.
    flat = tuple(chain.from_iterable(combined))

    in_both = len([value for value in flat if value == 2])
    in_one_only = len([value for value in flat if value == 1])

    denominator = in_one_only + in_both
    if denominator == 0:
        # Nothing is marked in either mask; report 0 rather than dividing.
        return 0
    return in_both / denominator

def dice_score(pred, label):
    """Return the Dice coefficient ``2*|P & L| / (|P| + |L|)`` of two masks.

    Args:
        pred: predicted mask. Assumed to be a 2-D array-like of 0/1 values
            that supports element-wise ``+`` (e.g. a numpy array) -- TODO
            confirm against callers.
        label: ground-truth mask, same assumptions as ``pred``.

    Returns:
        The Dice score as a float, or ``0`` when neither mask contains an
        element equal to 1 (the denominator would be zero).
    """
    # Where both masks are 1 the element-wise sum is 2.
    # Only one nesting level is flattened: rows -> individual elements.
    overlap = sum(1 for val in chain.from_iterable(pred + label) if val == 2)
    pred_area = sum(1 for val in chain.from_iterable(pred) if val == 1)
    # The foreground size of the label must be counted on ``label`` itself;
    # counting ``pred`` twice makes the score wrong whenever the two masks
    # differ in size.
    label_area = sum(1 for val in chain.from_iterable(label) if val == 1)

    total_area = pred_area + label_area
    if total_area == 0:
        # Both masks are empty; keep the historical convention of returning 0.
        return 0
    return (2 * overlap) / total_area

def getTargetSegmentation(batch):
    """Convert a normalised intensity mask into discrete class indices.

    Per the original author's notes, the input is a 1-channel batch whose
    values lie between 0 and 1 and take the levels 0, 0.33333334, 0.6666667
    and 0.94117647; the output holds the discrete values 0, 1, 2 and 3.

    Args:
        batch: object supporting division by a float and the chained
            ``round()``, ``long()`` and ``squeeze()`` methods (presumably a
            torch tensor -- confirm against callers).

    Returns:
        ``batch`` divided by the level spacing, rounded, converted with
        ``long()`` and passed through ``squeeze()``.
    """
    # Spacing between consecutive intensity levels; this value is for ACDC.
    level_step = 0.33333334

    scaled = batch / level_step
    class_ids = scaled.round().long()
    return class_ids.squeeze()