thov's picture
Update src/utils.py
f9c2643
raw
history blame contribute delete
No virus
1.05 kB
from itertools import chain
def IOU(pred, label):
    """Return the intersection-over-union of two binary masks.

    ``pred`` and ``label`` are added elementwise. Afterwards a cell equal to
    2 is in both masks (intersection) and a cell equal to 1 is in exactly one
    of them. IoU = |intersection| / (|intersection| + |in exactly one|).

    Args:
        pred: Predicted mask. Assumed (not checked) to be a 2-D array-like of
            0/1 values that supports elementwise ``+`` and iterates row by
            row, e.g. a NumPy array or torch tensor.
        label: Ground-truth mask, same shape and encoding as ``pred``.

    Returns:
        The IoU as a float, or ``0`` when both masks are empty.
    """
    summed = pred + label
    overlap = 0    # cells set in both masks
    exclusive = 0  # cells set in exactly one mask
    # Flatten one level (rows -> cells) and classify every cell in one pass.
    for val in chain.from_iterable(summed):
        if val == 2:
            overlap += 1
        elif val == 1:
            exclusive += 1
    union = overlap + exclusive
    if union == 0:
        # Both masks are empty; keep the historical return value of 0.
        return 0
    return overlap / union
def dice_score(pred, label):
    """Return the Dice coefficient of two binary masks.

    Dice = 2 * |pred AND label| / (|pred| + |label|). The intersection is
    counted as the cells where ``pred + label == 2``. Each mask's area is
    the number of its cells equal to 1.

    Args:
        pred: Predicted mask. Assumed (not checked) to be a 2-D array-like of
            0/1 values that supports elementwise ``+`` and iterates row by
            row, e.g. a NumPy array or torch tensor.
        label: Ground-truth mask, same shape and encoding as ``pred``.

    Returns:
        The Dice score as a float, or ``0`` when both masks are empty.
    """
    overlap = sum(1 for val in chain.from_iterable(pred + label) if val == 2)
    pred_area = sum(1 for val in chain.from_iterable(pred) if val == 1)
    # Fixed: the label area was previously computed from `pred` by mistake.
    label_area = sum(1 for val in chain.from_iterable(label) if val == 1)
    total_area = pred_area + label_area
    if total_area == 0:
        # Both masks are empty; keep the historical return value of 0.
        return 0
    return (2 * overlap) / total_area
def getTargetSegmentation(batch, *, denom=0.33333334):
    """Map normalized single-channel label values to integer class ids.

    Divides ``batch`` by ``denom``, rounds to the nearest integer, converts
    to int64 (``.long()``) and drops size-1 dimensions (``.squeeze()``).

    With the default ``denom`` (the ACDC value), inputs
    0, 0.33333334, 0.6666667 and 0.94117647 map to classes 0, 1, 2 and 3.
    The last one maps to 3 because 0.94117647 / 0.33333334 ~= 2.82 rounds to 3.

    Args:
        batch: A floating-point torch tensor; ``.round()``, ``.long()`` and
            ``.squeeze()`` are called on it.
        denom: Step between consecutive normalized label values. Override it
            for datasets whose labels were normalized differently.

    Returns:
        An int64 tensor of class ids with all size-1 dimensions removed.
    """
    return (batch / denom).round().long().squeeze()