from functools import cache

import torch
from einops import reduce
from jaxtyping import Float
from lpips import LPIPS
from skimage.metrics import structural_similarity
from torch import Tensor
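

# Peak signal-to-noise ratio in dB; both images are clipped to [0, 1] before computing the MSE.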
@torch.no_grad()
def compute_psnr(
ground_truth: Float[Tensor, "batch channel height width"],
predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, " batch"]:
ground_truth = ground_truth.clip(min=0, max=1)
predicted = predicted.clip(min=0, max=1)
mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean")
return -10 * mse.log10()
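

# Cache one LPIPS model (VGG backbone) per device so it is only constructed once.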
@cache
def get_lpips(device: torch.device) -> LPIPS:
return LPIPS(net="vgg").to(device)
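

# Per-image LPIPS; normalize=True tells LPIPS the inputs are in [0, 1] rather than [-1, 1].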
@torch.no_grad()
def compute_lpips(
ground_truth: Float[Tensor, "batch channel height width"],
predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, " batch"]:
value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True)
return value[:, 0, 0, 0]
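

# Per-image SSIM computed on CPU with scikit-image (Gaussian-weighted 11x11 window).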
@torch.no_grad()
def compute_ssim(
ground_truth: Float[Tensor, "batch channel height width"],
predicted: Float[Tensor, "batch channel height width"],
) -> Float[Tensor, " batch"]:
ssim = [
structural_similarity(
gt.detach().cpu().numpy(),
hat.detach().cpu().numpy(),
win_size=11,
gaussian_weights=True,
channel_axis=0,
data_range=1.0,
)
for gt, hat in zip(ground_truth, predicted)
]
return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device)
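

# Geodesic distance (rotation angle, in radians) between two batches of 3x3 rotation matrices.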
def compute_geodesic_distance_from_two_matrices(m1, m2):
    m = torch.bmm(m1, m2.transpose(1, 2))  # relative rotation, shape (batch, 3, 3)
    cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2
    # clamp: numerical error can push the cosine slightly outside [-1, 1]
    cos = torch.clamp(cos, min=-1.0, max=1.0)
    theta = torch.acos(cos)
# theta = torch.min(theta, 2*np.pi - theta)
return theta
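

# Angular error in degrees between two 3x3 rotation matrices.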
def angle_error_mat(R1, R2):
cos = (torch.trace(torch.mm(R1.T, R2)) - 1) / 2
cos = torch.clamp(cos, -1.0, 1.0) # numerical errors can make it out of bounds
return torch.rad2deg(torch.abs(torch.acos(cos)))
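

# Angle in degrees between two vectors.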
def angle_error_vec(v1, v2):
n = torch.norm(v1) * torch.norm(v2)
cos_theta = torch.dot(v1, v2) / n
cos_theta = torch.clamp(cos_theta, -1.0, 1.0) # numerical errors can make it out of bounds
return torch.rad2deg(torch.acos(cos_theta))
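

# Euclidean (L2) distance between two translation vectors.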
def compute_translation_error(t1, t2):
return torch.norm(t1 - t2)
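

# Pose error between two camera poses: translation direction error in degrees (folded to
# [0, 90] for the sign ambiguity of essential-matrix estimation), translation L2 error,
# and rotation error in degrees.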
@torch.no_grad()
def compute_pose_error(pose_gt, pose_pred):
R_gt = pose_gt[:3, :3]
t_gt = pose_gt[:3, 3]
R = pose_pred[:3, :3]
t = pose_pred[:3, 3]
error_t = angle_error_vec(t, t_gt)
error_t = torch.minimum(error_t, 180 - error_t) # ambiguity of E estimation
error_t_scale = compute_translation_error(t, t_gt)
error_R = angle_error_mat(R, R_gt)
return error_t, error_t_scale, error_R
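

# Absolute relative depth error (AbsRel): |pred - gt| / gt averaged over valid pixels, then over the batch.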
@torch.no_grad()
def abs_relative_difference(output, target, valid_mask=None):
    abs_relative_diff = torch.abs(output - target) / target
if valid_mask is not None:
abs_relative_diff[~valid_mask] = 0
n = valid_mask.sum((-1, -2))
else:
n = output.shape[-1] * output.shape[-2]
abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n
return abs_relative_diff.mean()
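

# Threshold accuracy: fraction of valid pixels where max(pred / gt, gt / pred) < threshold_val.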
# Adapted from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py
@torch.no_grad()
def threshold_percentage(output, target, threshold_val, valid_mask=None):
d1 = output / target
d2 = target / output
max_d1_d2 = torch.max(d1, d2)
zero = torch.zeros_like(output)
one = torch.ones_like(output)
bit_mat = torch.where(max_d1_d2 < threshold_val, one, zero)
if valid_mask is not None:
bit_mat[~valid_mask] = 0
n = valid_mask.sum((-1, -2))
else:
n = output.shape[-1] * output.shape[-2]
count_mat = torch.sum(bit_mat, (-1, -2))
threshold_mat = count_mat / n
return threshold_mat.mean()
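

# delta_1 accuracy: fraction of valid pixels with max(pred / gt, gt / pred) < 1.25.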
@torch.no_grad()
def delta1_acc(pred, gt, valid_mask):
    return threshold_percentage(pred, gt, 1.25, valid_mask)