from itertools import chain


def _count_equal(grid, value):
    """Count the elements of a two-level iterable that compare equal to *value*.

    ``grid`` is flattened by exactly one level (rows, then elements), so it
    is expected to be 2-D.
    """
    return sum(1 for val in chain.from_iterable(grid) if val == value)


def IOU(pred, label):
    """Return the intersection-over-union of two binary masks.

    Args:
        pred: predicted mask holding 0/1 values.
        label: ground-truth mask holding 0/1 values.

    Both arguments must support element-wise ``+`` and be iterable row by
    row -- presumably 2-D NumPy arrays or torch tensors (TODO confirm
    against callers). Plain Python lists are not suitable because ``+``
    would concatenate them instead of adding element-wise.

    Returns:
        ``overlap / (union + overlap)``, or ``0`` when neither mask has any
        foreground element.
    """
    combined = pred + label
    # After the element-wise sum, 2 marks positions set in both masks and
    # 1 marks positions set in exactly one of them.
    overlap = _count_equal(combined, 2)
    only_one = _count_equal(combined, 1)
    total = only_one + overlap
    if total == 0:
        return 0
    return overlap / total


def dice_score(pred, label):
    """Return the Dice coefficient of two binary masks.

    Args:
        pred: predicted mask holding 0/1 values.
        label: ground-truth mask holding 0/1 values.

    The same input assumptions as for ``IOU`` apply: element-wise ``+`` and
    row-by-row iteration (presumably 2-D arrays/tensors -- TODO confirm).

    Returns:
        ``2 * overlap / (pred_area + label_area)``, or ``0`` when both masks
        are empty.
    """
    overlap = _count_equal(pred + label, 2)
    pred_area = _count_equal(pred, 1)
    # The label area must be counted on ``label``; counting ``pred`` twice
    # makes the denominator wrong whenever the two masks differ in size.
    label_area = _count_equal(label, 1)
    total_area = pred_area + label_area
    if total_area == 0:
        return 0
    return (2 * overlap) / total_area


def getTargetSegmentation(batch, denom=0.33333334):
    """Convert a normalised 1-channel label batch to discrete class indices.

    Per the original author's notes, the input is a single channel with
    values between 0 and 1 (0, 0.33333334, 0.6666667 and 0.94117647) and the
    output is a single channel of discrete values 0, 1, 2 and 3.

    Args:
        batch: object supporting ``/`` and the ``.round().long().squeeze()``
            method chain -- presumably a torch tensor (TODO confirm).
        denom: spacing between consecutive class values in ``batch``. The
            default is the value the original code used for ACDC.

    Returns:
        ``(batch / denom)`` rounded, cast with ``.long()`` and squeezed.
    """
    return (batch / denom).round().long().squeeze()