thov's picture
Personal implementation of Dice and IoU
32278ee
raw
history blame
No virus
1.1 kB
from itertools import chain
from medpy.metric.binary import dc, hd, asd, assd
def IOU(pred, label):
    """Return the intersection-over-union of two binary masks.

    The masks are added element-wise; a cell equal to 2 is set in both
    masks (intersection) and a cell equal to 1 is set in exactly one of
    them. The union is the sum of both counts.

    Assumes ``pred`` and ``label`` are 2-D masks of 0/1 values whose ``+``
    is element-wise (e.g. NumPy arrays) -- TODO confirm against callers.

    Returns:
        ``intersection / union`` as a float, or ``0`` when neither mask
        has any cell set.
    """
    combined = pred + label
    in_both = 0
    in_one_only = 0
    for row in combined:
        for cell in row:
            if cell == 2:
                in_both += 1
            elif cell == 1:
                in_one_only += 1

    union_size = in_one_only + in_both
    # Both masks empty: the ratio is undefined, report 0 by convention.
    if union_size == 0:
        return 0
    return in_both / union_size
def dice_score(pred, label):
    """Return the Dice coefficient ``2*|P & L| / (|P| + |L|)`` of two masks.

    Assumes ``pred`` and ``label`` are 2-D masks of 0/1 values whose ``+``
    is element-wise (e.g. NumPy arrays) -- TODO confirm against callers.

    Returns:
        The Dice score as a float, or ``0`` when both masks are empty.
    """
    # A cell of the element-wise sum equals 2 only where both masks are set.
    overlap = sum(1 for val in chain.from_iterable(pred + label) if val == 2)
    pred_area = sum(1 for val in chain.from_iterable(pred) if val == 1)
    # The label area must be counted on ``label``; counting ``pred`` twice
    # gave a wrong denominator whenever the two masks differed in size.
    label_area = sum(1 for val in chain.from_iterable(label) if val == 1)

    total_area = pred_area + label_area
    # Both masks empty: the ratio is undefined, report 0 by convention.
    if total_area == 0:
        return 0
    return (2 * overlap) / total_area
def getTargetSegmentation(batch):
    """Convert a normalised intensity mask into integer class labels.

    Input is a 1-channel batch with values between 0 and 1, namely
    0, 0.33333334, 0.6666667 and 0.94117647. Output is 1 channel of
    discrete values: 0, 1, 2 and 3.

    ``batch`` must provide ``round()``, ``long()`` and ``squeeze()``
    (presumably a torch tensor -- confirm against callers).
    """
    level_step = 0.33333334  # for ACDC this value
    scaled = batch / level_step
    class_ids = scaled.round().long()
    # squeeze() removes every size-1 dimension, not only the channel axis.
    return class_ids.squeeze()