|
|
|
import torchmetrics |
|
import sketchers_v1 as usketchers |
|
from pytorch_v0 import * |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# (kernel name, CUDA source) pair for one row-wise pass of a squared-distance
# minimisation over a flat (bs, h, w) float buffer.
#
# Each thread handles one flat element index `idx`:
#   * threads with idx >= bs*h*w return without writing;
#   * idx is decoded into batch `pb`, row `pi` and column `pj`;
#   * the thread scans every column j of its own row and keeps
#       min(diam2, min_j data[pb, pi, j] + (pj - j)**2)
#     so `diam2` acts as the upper bound / initial value of the minimum;
#   * the result is written to output[idx].
# NOTE(review): nothing in the visible part of this file compiles or launches
# this kernel (batch_edt below goes through SciPy) -- confirm it is still needed.
_batch_edt_kernel = ('kernel_dt', '''
    extern "C" __global__ void kernel_dt(
        const int bs,
        const int h,
        const int w,
        const float diam2,
        float* data,
        float* output
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= bs*h*w) {
            return;
        }
        int pb = idx / (h*w);
        int pi = (idx - h*w*pb) / w;
        int pj = (idx - h*w*pb - w*pi);

        float cost;
        float mincost = diam2;
        for (int j = 0; j < w; j++) {
            cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
            if (cost < mincost) {
                mincost = cost;
            }
        }
        output[idx] = mincost;
        return;
    }
''')

# NOTE(review): presumably a slot for a lazily compiled kernel handle; in the
# visible code it is only ever set to None -- confirm against other users.
_batch_edt = None
|
def batch_edt(img, block=1024):
    """Compute a per-image Euclidean distance transform for a batch of masks.

    For each image, every pixel receives the Euclidean distance to the
    nearest pixel at which ``1 - img`` is zero (for a binary mask: the nearest
    pixel equal to 1); those pixels themselves get distance 0. An image whose
    pixels sum to zero has no such target pixel, so it is filled with the
    image diagonal ``sqrt(h**2 + w**2)`` instead.

    The transform itself runs in SciPy on the host; the input is detached and
    copied to the CPU for that, and the result is placed back on the input's
    device so it can be combined with the input directly.

    Args:
        img: Tensor of shape ``(bs, h, w)`` or ``(bs, 1, h, w)``.
        block: Unused. Kept so callers that pass a block size keep working.

    Returns:
        Tensor with the same shape, dtype and device as ``img``.

    Raises:
        AssertionError: If a 4-D input does not have exactly one channel.
    """
    # Local imports keep the function independent of whatever the module's
    # star-import happens to expose, and make sure the ndimage submodule is
    # actually loaded.
    import numpy as np
    from scipy import ndimage

    if len(img.shape) == 4:
        assert img.shape[1] == 1, 'batch_edt expects a single channel'
        img = img.squeeze(1)
        expand = True
    else:
        expand = False
    bs, h, w = img.shape
    diam2 = h**2 + w**2

    if bs == 0:
        # np.stack refuses an empty list; an empty batch maps to an empty result.
        ans = torch.empty((0, h, w), dtype=img.dtype, device=img.device)
    else:
        # SciPy needs host memory and tensors that do not track gradients.
        host = img.detach().cpu()
        sums = host.sum(dim=(1, 2)).tolist()
        inverted = (1 - host).numpy()
        planes = [
            ndimage.distance_transform_edt(inv)
            if total != 0 else
            np.full((h, w), np.sqrt(diam2))
            for inv, total in zip(inverted, sums)
        ]
        ans = torch.tensor(np.stack(planes), dtype=img.dtype, device=img.device)

    if expand:
        ans = ans.unsqueeze(1)
    return ans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    """Return the symmetric chamfer distance between ``gt`` and ``pred``.

    The result is the arithmetic mean of the two directed terms computed by
    ``batch_chamfer_distance_t`` and ``batch_chamfer_distance_p``.

    Args:
        gt: Ground-truth tensor, forwarded unchanged to both directed terms.
        pred: Prediction tensor, forwarded unchanged to both directed terms.
        block: Forwarded to both directed terms.
        return_more: Accepted but not used.

    Returns:
        Whatever ``(t + p) / 2`` yields for the two directed results.
    """
    directed_terms = (
        batch_chamfer_distance_t(gt, pred, block=block),
        batch_chamfer_distance_p(gt, pred, block=block),
    )
    return (directed_terms[0] + directed_terms[1]) / 2
|
def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    """Directed chamfer term: ``gt`` weighted by the distance transform of ``pred``.

    Multiplies ``gt`` element-wise with ``batch_edt(pred)``, averages over the
    last two axes and divides by the image diagonal ``sqrt(h**2 + w**2)``.

    Args:
        gt: Ground-truth tensor; must share device and shape with ``pred``.
        pred: Prediction tensor.
        block: Forwarded to ``batch_edt``.
        return_more: Accepted but not used.

    Returns:
        The per-sample values; a trailing singleton channel axis is squeezed
        away when the reduction leaves a 2-D result.
    """
    assert gt.device == pred.device and gt.shape == pred.shape
    height, width = gt.shape[-2:]
    diagonal = np.sqrt(height**2 + width**2)

    dist_to_pred = batch_edt(pred, block=block)
    weighted = (gt * dist_to_pred).float()
    cd = weighted.mean((-2, -1)) / diagonal

    if cd.dim() != 2:
        return cd
    # A 2-D result means the inputs carried a channel axis; it must be singleton.
    assert cd.shape[1] == 1
    return cd.squeeze(1)
|
def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    """Directed chamfer term: ``pred`` weighted by the distance transform of ``gt``.

    Multiplies ``pred`` element-wise with ``batch_edt(gt)``, averages over the
    last two axes and divides by the image diagonal ``sqrt(h**2 + w**2)``.

    Args:
        gt: Ground-truth tensor; must share device and shape with ``pred``.
        pred: Prediction tensor.
        block: Forwarded to ``batch_edt``.
        return_more: Accepted but not used.

    Returns:
        The per-sample values; a trailing singleton channel axis is squeezed
        away when the reduction leaves a 2-D result.
    """
    assert gt.device == pred.device and gt.shape == pred.shape
    height, width = gt.shape[-2:]
    diagonal = np.sqrt(height**2 + width**2)

    dist_to_gt = batch_edt(gt, block=block)
    weighted = (pred * dist_to_gt).float()
    cd = weighted.mean((-2, -1)) / diagonal

    if cd.dim() != 2:
        return cd
    # A 2-D result means the inputs carried a channel axis; it must be singleton.
    assert cd.shape[1] == 1
    return cd.squeeze(1)
|
|
|
|
|
|
|
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    """Hausdorff-style distance between ``gt`` and ``pred``, scaled by the diagonal.

    Takes the largest value of ``batch_edt(gt) * pred`` and of
    ``batch_edt(pred) * gt`` over the last two axes, keeps the larger of the
    two, and divides by ``sqrt(h**2 + w**2)``.

    Args:
        gt: Ground-truth tensor; must share device and shape with ``pred``.
        pred: Prediction tensor.
        block: Forwarded to ``batch_edt``.
        return_more: Accepted but not used.

    Returns:
        The per-sample values; a trailing singleton channel axis is squeezed
        away when the reduction leaves a 2-D result.
    """
    assert gt.device == pred.device and gt.shape == pred.shape
    height, width = gt.shape[-2:]

    dist_to_gt = batch_edt(gt, block=block)
    dist_to_pred = batch_edt(pred, block=block)

    worst_from_pred = (dist_to_gt * pred).amax(dim=(-2, -1))
    worst_from_gt = (dist_to_pred * gt).amax(dim=(-2, -1))
    both_directions = torch.stack([worst_from_pred, worst_from_gt])
    hd = both_directions.amax(dim=0).float() / np.sqrt(height**2 + width**2)

    if hd.dim() != 2:
        return hd
    # A 2-D result means the inputs carried a channel axis; it must be singleton.
    assert hd.shape[1] == 1
    return hd.squeeze(1)
|
|
|
|
|
|
|
|
|
class ChamferDistance2dMetric(torchmetrics.Metric):
    """Running mean of the symmetric 2D chamfer distance.

    ``update`` adds the per-sample distances of a batch to ``running_sum`` and
    the batch size to ``running_count``; ``compute`` returns their ratio.

    Args:
        block: Forwarded to ``batch_chamfer_distance``.
        convert_dog: When true, both inputs are first passed through
            ``usketchers.batch_dog`` and thresholded at 0.5.
        t, sigma, k, epsilon, kernel_factor, clip: Keyword arguments forwarded
            verbatim to ``usketchers.batch_dog``.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    # update() only adds to the states, so forward() may use the cheaper
    # reduce-state path instead of running update() twice per batch.
    full_state_update = False

    def __init__(
        self, block=1024, convert_dog=True,
        t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.block = block
        self.convert_dog = convert_dog
        self.dog_params = {
            't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,
            'kernel_factor': kernel_factor, 'clip': clip,
        }
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def _binarize(self, preds: torch.Tensor, target: torch.Tensor):
        """Apply the optional batch_dog + 0.5 threshold step to both inputs."""
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params) > 0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params) > 0.5).float()
        return preds, target

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the chamfer distances of one batch.

        Returns the summed distance of this batch as a Python float (kept for
        callers that invoke the method directly).
        """
        preds, target = self._binarize(preds, target)
        dist = batch_chamfer_distance(target, preds, block=self.block)
        # Without these two lines the states never move and compute() is 0/0.
        self.running_sum += dist.sum()
        self.running_count += len(dist)
        return dist.sum().item()

    def calc(self, preds: torch.Tensor, target: torch.Tensor):
        """Return the summed chamfer distance of one batch without touching state."""
        preds, target = self._binarize(preds, target)
        dist = batch_chamfer_distance(target, preds, block=self.block)
        return dist.sum().item()

    def compute(self):
        """Return ``running_sum / running_count`` (NaN if nothing was accumulated)."""
        return self.running_sum.float() / self.running_count
|
|
|
class ChamferDistance2dTMetric(ChamferDistance2dMetric):
    """Running mean of the directed term from ``batch_chamfer_distance_t``.

    Only ``update`` is overridden; everything else comes from the parent.
    NOTE(review): the inherited ``calc`` still evaluates the symmetric
    ``batch_chamfer_distance`` -- confirm that is intended for this subclass.
    """

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate one batch and return its summed distance as a float."""
        if self.convert_dog:
            # Same pre-processing for both inputs: batch_dog, then threshold at 0.5.
            preds, target = (
                (usketchers.batch_dog(image, **self.dog_params) > 0.5).float()
                for image in (preds, target)
            )
        per_sample = batch_chamfer_distance_t(target, preds, block=self.block)
        batch_total = per_sample.sum()
        self.running_sum += batch_total
        self.running_count += len(per_sample)
        return batch_total.item()
|
|
|
class ChamferDistance2dPMetric(ChamferDistance2dMetric):
    """Running mean of the directed term from ``batch_chamfer_distance_p``.

    Only ``update`` is overridden; everything else comes from the parent.
    NOTE(review): the inherited ``calc`` still evaluates the symmetric
    ``batch_chamfer_distance`` -- confirm that is intended for this subclass.
    """

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate one batch and return its summed distance as a float."""
        if self.convert_dog:
            # Same pre-processing for both inputs: batch_dog, then threshold at 0.5.
            preds, target = (
                (usketchers.batch_dog(image, **self.dog_params) > 0.5).float()
                for image in (preds, target)
            )
        per_sample = batch_chamfer_distance_p(target, preds, block=self.block)
        batch_total = per_sample.sum()
        self.running_sum += batch_total
        self.running_count += len(per_sample)
        return batch_total.item()
|
|
|
class HausdorffDistance2dMetric(torchmetrics.Metric):
    """Running mean of ``batch_hausdorff_distance`` over all seen samples.

    ``update`` adds the per-sample distances of a batch to ``running_sum`` and
    the batch size to ``running_count``; ``compute`` returns their ratio.

    Args:
        block: Forwarded to ``batch_hausdorff_distance``.
        convert_dog: When true, both inputs are first passed through
            ``usketchers.batch_dog`` and thresholded at 0.5.
        t, sigma, k, epsilon, kernel_factor, clip: Keyword arguments forwarded
            verbatim to ``usketchers.batch_dog``.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    # Matches the Chamfer metrics in this file: update() only adds to the
    # states, so forward() can use the reduce-state path and run the
    # distance computation once per batch instead of twice.
    full_state_update = False

    def __init__(
        self, block=1024, convert_dog=True,
        t=2.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.block = block
        self.convert_dog = convert_dog
        self.dog_params = {
            't': t, 'sigma': sigma, 'k': k, 'epsilon': epsilon,
            'kernel_factor': kernel_factor, 'clip': clip,
        }
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the Hausdorff distances of one batch."""
        if self.convert_dog:
            preds = (usketchers.batch_dog(preds, **self.dog_params) > 0.5).float()
            target = (usketchers.batch_dog(target, **self.dog_params) > 0.5).float()
        dist = batch_hausdorff_distance(target, preds, block=self.block)
        self.running_sum += dist.sum()
        self.running_count += len(dist)

    def compute(self):
        """Return ``running_sum / running_count`` (NaN if nothing was accumulated)."""
        return self.running_sum.float() / self.running_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|