# Spaces:
# Build error
# Build error
# based on https://gist.github.com/ranftlr/45f4c7ddeb1bbb88d606bc600cab6c8d
import torch
class DepthMetric:
    """Depth evaluation metrics computed after a per-image scale/shift alignment.

    The prediction is converted to disparity (inverse depth), aligned to the
    target disparity with a least-squares scale and shift fitted per batch
    element, converted back to depth, and then compared against the target.

    Attributes:
        thresholds: list of ratio thresholds used for the ``'d>{threshold}'``
            metrics.
        depth_cap: optional upper bound applied to the aligned depth (it is
            enforced as a lower bound ``1 / depth_cap`` on the aligned
            disparity); ``None`` disables the cap.
        metric_keys: names of all entries of the metrics dict, in order.
        prediction_type: ``'depth'`` or ``'disparity'``; tells
            ``compute_metrics`` how to interpret the prediction tensor.
    """

    def __init__(self, thresholds=(1.25, 1.25**2, 1.25**3), depth_cap=None, prediction_type='depth'):
        """Store the configuration and derive the metric key list.

        Args:
            thresholds: iterable of ratio thresholds. It is copied into a new
                list so instances never share (or mutate) the default value
                or the caller's own sequence.
            depth_cap: optional maximum depth for the aligned prediction.
            prediction_type: ``'depth'`` or ``'disparity'``. The value is only
                validated when ``compute_metrics`` is called.
        """
        self.thresholds = list(thresholds)
        self.depth_cap = depth_cap
        self.metric_keys = self.get_metric_keys()
        self.prediction_type = prediction_type

    def compute_scale_and_shift(self, prediction, target, mask):
        """Fit ``scale * prediction + shift ~= target`` by masked least squares.

        The 2x2 normal equations are solved in closed form, independently for
        every batch element.

        Args:
            prediction: tensor reduced over dims ``(1, 2)``, i.e. expected to
                be ``[batch, height, width]``.
            target: tensor with the same shape as ``prediction``.
            mask: tensor with the same shape; entries act as per-pixel
                weights (``compute_metrics`` passes a 0/1 mask).

        Returns:
            Tuple ``(scale, shift)``, each of shape ``[batch]``. Both are left
            at zero for batch elements whose system determinant is not
            strictly positive.
        """
        # system matrix: A = [[a_00, a_01], [a_10, a_11]], with a_10 == a_01
        a_00 = torch.sum(mask * prediction * prediction, (1, 2))
        a_01 = torch.sum(mask * prediction, (1, 2))
        a_11 = torch.sum(mask, (1, 2))

        # right hand side: b = [b_0, b_1]
        b_0 = torch.sum(mask * prediction * target, (1, 2))
        b_1 = torch.sum(mask * target, (1, 2))

        # solution: x = A^-1 . b
        #             = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
        x_0 = torch.zeros_like(b_0)
        x_1 = torch.zeros_like(b_1)

        det = a_00 * a_11 - a_01 * a_01
        # A needs to be a positive definite matrix; other entries keep x = 0.
        valid = det > 0

        x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
        x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

        return x_0, x_1

    def get_metric_keys(self):
        """Return the metric names: one ``'d>{t}'`` per threshold, then the error metrics."""
        metric_keys = ['d>{}'.format(threshold) for threshold in self.thresholds]
        metric_keys.extend(['rmse', 'l1_err', 'abs_rel'])
        return metric_keys

    def compute_metrics(self, prediction, target, mask):
        """Align the prediction to the target and compute per-image metrics.

        Args:
            prediction: ``[batch, 1, height, width]`` tensor holding depth or
                disparity, according to ``self.prediction_type``.
            target: ground-truth depth with the same shape. It is inverted at
                valid pixels, so it is presumably expected to be non-zero
                there.
            mask: same shape; pixels with a value above 0.5 are evaluated.

        Returns:
            Tuple ``(metrics, prediction_depth)``. ``metrics`` maps every key
            of ``get_metric_keys()`` to a ``[batch]`` tensor, and
            ``prediction_depth`` is the aligned depth of shape
            ``[batch, 1, height, width]``. A batch element without any valid
            pixel yields NaN metrics (0 / 0).

        Raises:
            AssertionError: if the shapes differ, are not 4-D, or the channel
                dimension is not 1.
            ValueError: if ``self.prediction_type`` is unknown.
        """
        # check inputs (the float() casts make every tensor float32)
        prediction = prediction.float()
        target = target.float()
        mask = mask.float()

        assert prediction.shape == target.shape == mask.shape
        assert len(prediction.shape) == 4
        assert prediction.shape[1] == 1

        # process inputs: drop the channel dimension and binarize the mask
        prediction = prediction.squeeze(1)
        target = target.squeeze(1)
        mask = (mask.squeeze(1) > 0.5).long()

        # computed once and reused by every masked read/write below
        valid = mask == 1
        num_valid = torch.sum(mask, (1, 2))

        # output dict
        metrics = {}

        # get the predicted disparity (stays zero outside the mask)
        prediction_disparity = torch.zeros_like(prediction)
        if self.prediction_type == 'depth':
            # small epsilon guards against division by a zero depth
            prediction_disparity[valid] = 1.0 / (prediction[valid] + 1.e-6)
        elif self.prediction_type == 'disparity':
            prediction_disparity[valid] = prediction[valid]
        else:
            raise ValueError('Unknown prediction type: {}'.format(self.prediction_type))

        # transform predicted disparity to align with the target disparity
        target_disparity = torch.zeros_like(target)
        target_disparity[valid] = 1.0 / target[valid]

        scale, shift = self.compute_scale_and_shift(prediction_disparity, target_disparity, mask)
        prediction_aligned = scale.view(-1, 1, 1) * prediction_disparity + shift.view(-1, 1, 1)

        if self.depth_cap is not None:
            # capping depth from above == flooring disparity from below
            disparity_cap = 1.0 / self.depth_cap
            prediction_aligned[prediction_aligned < disparity_cap] = disparity_cap

        prediction_depth = 1.0 / prediction_aligned

        valid_prediction = prediction_depth[valid]
        valid_target = target[valid]

        # delta > threshold, [batch_size, ]
        # the ratio does not depend on the threshold, so it is computed once
        ratio = torch.max(valid_prediction / valid_target, valid_target / valid_prediction)
        for threshold in self.thresholds:
            err = torch.zeros_like(prediction_depth, dtype=torch.float)
            err[valid] = (ratio > threshold).float()
            metrics['d>{}'.format(threshold)] = torch.sum(err, (1, 2)) / num_valid

        difference = valid_prediction - valid_target

        # rmse, [batch_size, ]
        squared_err = torch.zeros_like(prediction_depth, dtype=torch.float)
        squared_err[valid] = difference ** 2
        metrics['rmse'] = torch.sqrt(torch.sum(squared_err, (1, 2)) / num_valid)

        # l1 error, [batch_size, ]
        l1_err = torch.zeros_like(prediction_depth, dtype=torch.float)
        l1_err[valid] = torch.abs(difference)
        metrics['l1_err'] = torch.sum(l1_err, (1, 2)) / num_valid

        # abs_rel, [batch_size, ]
        abs_rel = torch.zeros_like(prediction_depth, dtype=torch.float)
        abs_rel[valid] = torch.abs(difference) / valid_target
        metrics['abs_rel'] = torch.sum(abs_rel, (1, 2)) / num_valid

        return metrics, prediction_depth.unsqueeze(1)