# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> metrics
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 16:32
=================================================='''
import torch
import numpy as np
import torch.nn.functional as F


class SeqIOU:
    '''Running per-class IoU and per-frame precision over a sequence of predictions.'''

    def __init__(self, n_class, ignored_sids=[]):
        self.n_class = n_class
        self.ignored_sids = ignored_sids
        self.class_iou = np.zeros(n_class)
        self.precisions = []
    def add(self, pred, target):
        # Per-class IoU for the current frame (overwrites the stored value for classes present here).
        for i in range(self.n_class):
            inter = np.sum((pred == target) * (target == i))
            union = np.sum(target == i) + np.sum(pred == i) - inter
            if union > 0:
                self.class_iou[i] = inter / union

        # Per-frame precision, optionally excluding ignored segment ids.
        acc = (pred == target)
        if len(self.ignored_sids) == 0:
            acc_ratio = np.sum(acc) / pred.shape[0]
        else:
            # Mark predictions/targets carrying an ignored id (union over all ignored ids).
            pred_mask = np.zeros_like(pred, dtype=bool)
            target_mask = np.zeros_like(target, dtype=bool)
            for i in self.ignored_sids:
                pred_mask = pred_mask | (pred == i)
                target_mask = target_mask | (target == i)
            acc = acc & (~pred_mask)
            tgt = ~target_mask
            if np.sum(tgt) == 0:
                acc_ratio = 0
            else:
                acc_ratio = np.sum(acc) / np.sum(tgt)
        self.precisions.append(acc_ratio)
    def get_mean_iou(self):
        return np.mean(self.class_iou)

    def get_mean_precision(self):
        return np.mean(self.precisions)

    def clear(self):
        self.precisions = []
        self.class_iou = np.zeros(self.n_class)
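
# Example usage (illustrative sketch; the toy label arrays below are assumptions,
# not data from the original project):
#   seq_metric = SeqIOU(n_class=3)
#   seq_metric.add(pred=np.array([0, 1, 2, 2]), target=np.array([0, 1, 1, 2]))
#   seq_metric.add(pred=np.array([1, 1, 0, 2]), target=np.array([1, 1, 0, 2]))
#   seq_metric.get_mean_iou()        # mean over the per-class IoU table
#   seq_metric.get_mean_precision()  # mean per-frame accuracy
#   seq_metric.clear()               # reset before the next sequence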


def compute_iou(pred: np.ndarray, target: np.ndarray, n_class: int, ignored_ids=[]) -> float:
    # Mean IoU over all classes; classes listed in ignored_ids are skipped and keep an IoU of 0.
    class_iou = np.zeros(n_class)
    for i in range(n_class):
        if i in ignored_ids:
            continue
        inter = np.sum((pred == target) * (target == i))
        union = np.sum(target == i) + np.sum(pred == i) - inter
        if union > 0:
            class_iou[i] = inter / union
    return np.mean(class_iou)
    # return class_iou
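
# Example (toy arrays chosen here purely for illustration):
#   compute_iou(np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2]), n_class=3)
#   -> (1.0 + 0.5 + 0.5) / 3 = 0.666...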


def compute_precision(pred: np.ndarray, target: np.ndarray, ignored_ids: list = []) -> float:
    acc = (pred == target)
    if len(ignored_ids) == 0:
        return np.sum(acc) / pred.shape[0]
    else:
        # Mark predictions/targets carrying an ignored id (union over all ignored ids).
        pred_mask = np.zeros_like(pred, dtype=bool)
        target_mask = np.zeros_like(target, dtype=bool)
        for i in ignored_ids:
            pred_mask = pred_mask | (pred == i)
            target_mask = target_mask | (target == i)
        acc = acc & (~pred_mask)
        tgt = ~target_mask
        if np.sum(tgt) == 0:
            return 0
        return np.sum(acc) / np.sum(tgt)
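
# Example (toy arrays; id 0 is treated as an ignored/background label purely for illustration):
#   compute_precision(np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2]))                    # -> 0.75
#   compute_precision(np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2]), ignored_ids=[0])   # -> 2/3, counted over non-ignored targets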


def compute_cls_corr(pred: torch.Tensor, target: torch.Tensor, k: int = 20) -> torch.Tensor:
    # Overlap ratio between the top-k predicted classes and the top-k target classes per sample.
    bs = pred.shape[0]
    _, target_ids = torch.topk(target, k=k, dim=1)
    target_ids = target_ids.cpu().numpy()
    _, top_ids = torch.topk(pred, k=k, dim=1)  # [B, k]
    top_ids = top_ids.cpu().numpy()
    acc = 0
    for i in range(bs):
        overlap = [v for v in top_ids[i] if v in target_ids[i] and v >= 0]
        acc = acc + len(overlap) / k
    acc = acc / bs
    return torch.from_numpy(np.array([acc])).to(pred.device)
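
# Example (illustrative; assumes pred and target are [B, C] score tensors with C >= k):
#   pred_scores = torch.randn(2, 100)
#   target_scores = torch.randn(2, 100)
#   compute_cls_corr(pred_scores, target_scores, k=20)  # 1-element tensor with a value in [0, 1]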


def compute_corr_incorr(pred: torch.Tensor, target: torch.Tensor, ignored_ids: list = []) -> tuple:
    '''
    :param pred: [B, N, C] class scores
    :param target: [B, N] class ids
    :param ignored_ids: class ids excluded from the correct count
    :return: (ratio of correct predictions, ratio of incorrect predictions)
    '''
    pred_ids = torch.max(pred, dim=-1)[1]
    if len(ignored_ids) == 0:
        acc = (pred_ids == target)
        inacc = torch.logical_not(acc)
        acc_ratio = torch.sum(acc) / torch.numel(target)
        inacc_ratio = torch.sum(inacc) / torch.numel(target)
    else:
        acc = (pred_ids == target)
        inacc = torch.logical_not(acc)
        # Mark entries whose target id is one of the ignored ids (union over ignored ids)
        # and drop them from the correct count.
        mask = torch.zeros_like(acc)
        for i in ignored_ids:
            mask = torch.logical_or(mask, (target == i))
        acc = torch.logical_and(acc, torch.logical_not(mask))
        acc_ratio = torch.sum(acc) / torch.numel(target)
        inacc_ratio = torch.sum(inacc) / torch.numel(target)
    return acc_ratio, inacc_ratio
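
# Example (illustrative shapes): B=1 sample, N=4 points, C=3 classes.
#   logits = torch.randn(1, 4, 3)
#   labels = torch.randint(0, 3, (1, 4))
#   acc_ratio, inacc_ratio = compute_corr_incorr(logits, labels)  # acc_ratio + inacc_ratio == 1
#   compute_corr_incorr(logits, labels, ignored_ids=[0])          # correct hits on ignored targets are not counted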


def compute_seg_loss_weight(pred: torch.Tensor,
                            target: torch.Tensor,
                            background_id: int = 0,
                            weight_background: float = 0.1) -> torch.Tensor:
    '''
    Weighted cross-entropy segmentation loss with a down-weighted background class.
    :param pred: [B, N, C] logits
    :param target: [B, N] class ids
    :param background_id: class id treated as background
    :param weight_background: class weight assigned to the background id
    :return: scalar loss
    '''
    pred = pred.transpose(-2, -1).contiguous()  # [B, N, C] -> [B, C, N]
    weight = torch.ones(size=(pred.shape[1],), device=pred.device).float()
    weight[background_id] = weight_background
    # F.cross_entropy applies log_softmax internally, so the raw logits are passed directly.
    seg_loss = F.cross_entropy(pred, target.long(), weight=weight)
    return seg_loss
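
# Example (illustrative shapes): B=2, N=64 points, C=5 classes, background id 0.
#   logits = torch.randn(2, 64, 5)
#   labels = torch.randint(0, 5, (2, 64))
#   loss = compute_seg_loss_weight(logits, labels, background_id=0, weight_background=0.1)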


def compute_cls_loss_ce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Cross-entropy summed over samples and normalized by the number of non-zero target ids.
    cls_loss = torch.zeros(size=[], device=pred.device)
    if len(pred.shape) == 2:
        n_valid = torch.sum(target > 0)
        cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred, target, reduction='sum')
        cls_loss = cls_loss / n_valid
    else:
        # Multiple predictions stacked along the last dimension.
        for i in range(pred.shape[-1]):
            cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred[..., i], target[..., i], reduction='sum')
        n_valid = torch.sum(target > 0)
        cls_loss = cls_loss / n_valid
    return cls_loss
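
# Example (illustrative; [N, C] logits with integer labels, normalized by the number of non-zero labels):
#   logits = torch.randn(32, 10)
#   labels = torch.randint(0, 10, (32,))
#   loss = compute_cls_loss_ce(logits, labels)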


def compute_cls_loss_kl(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # KL divergence between the predicted distribution and a softened target distribution.
    cls_loss = torch.zeros(size=[], device=pred.device)
    if len(pred.shape) == 2:
        cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred, dim=-1),
                                                         torch.softmax(target, dim=-1),
                                                         reduction='sum')
    else:
        # Multiple predictions stacked along the last dimension.
        for i in range(pred.shape[-1]):
            cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred[..., i], dim=-1),
                                                             torch.softmax(target[..., i], dim=-1),
                                                             reduction='sum')
        cls_loss = cls_loss / pred.shape[-1]
    return cls_loss
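
# Example (illustrative; target is a tensor of raw scores with the same shape as pred,
# softened by softmax inside the function):
#   pred_scores = torch.randn(32, 10)
#   target_scores = torch.randn(32, 10)
#   loss = compute_cls_loss_kl(pred_scores, target_scores)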


def compute_sc_loss_l1(pred: torch.Tensor, target: torch.Tensor, mean_xyz=None, scale_xyz=None, mask=None):
    '''
    Mean L1 error between predicted and target scene coordinates.
    :param pred: [B, N, C]
    :param target: [B, N, C]
    :param mean_xyz: not used in this loss
    :param scale_xyz: not used in this loss
    :param mask: optional boolean mask selecting valid entries
    :return: scalar loss
    '''
    loss = (pred - target)
    loss = torch.abs(loss).mean(dim=1)
    if mask is not None:
        return torch.mean(loss[mask])
    else:
        return torch.mean(loss)
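
# Example (illustrative shapes matching the docstring above):
#   pred_sc = torch.randn(2, 512, 3)
#   gt_sc = torch.randn(2, 512, 3)
#   loss = compute_sc_loss_l1(pred_sc, gt_sc)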


def compute_sc_loss_geo(pred: torch.Tensor, P, K, p2ds, mean_xyz, scale_xyz, max_value=20, mask=None):
    # pred: [B, 3, N] normalized scene coordinates; P/K: camera pose/intrinsics; p2ds: [B, N, 2] keypoints.
    b, c, n = pred.shape
    # De-normalize the predicted coordinates back to metric space.
    p3ds = (pred * scale_xyz[..., None].repeat(1, 1, n) + mean_xyz[..., None].repeat(1, 1, n))
    p3ds_homo = torch.cat(
        [p3ds, torch.ones(size=(p3ds.shape[0], 1, p3ds.shape[2]), dtype=p3ds.dtype, device=p3ds.device)],
        dim=1)  # [B, 4, N]
    # Project into the image: x = K * (P * X)[:3].
    p3ds = torch.matmul(K, torch.matmul(P, p3ds_homo)[:, :3, :])  # [B, 3, N]
    p2ds_ = p3ds[:, :2, :] / p3ds[:, 2:, :]
    # Clamped squared reprojection error per point.
    loss = ((p2ds_ - p2ds.permute(0, 2, 1)) ** 2).sum(1)
    loss = torch.clamp_max(loss, max=max_value)
    if mask is not None:
        return torch.mean(loss[mask])
    else:
        return torch.mean(loss)
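

# Minimal smoke test (illustrative only; the shapes and random inputs below are assumptions
# for a quick sanity check, not values from the original project):
if __name__ == '__main__':
    toy_pred = np.array([0, 1, 2, 2])
    toy_gt = np.array([0, 1, 1, 2])
    print('precision:', compute_precision(toy_pred, toy_gt))
    print('mIoU:', compute_iou(toy_pred, toy_gt, n_class=3))

    B, N = 2, 128
    sc_pred = torch.randn(B, 3, N)
    pose = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
    intrinsics = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)
    kpts = torch.rand(B, N, 2) * 100
    mean_xyz = torch.zeros(B, 3)
    scale_xyz = torch.ones(B, 3)
    print('geo loss:', compute_sc_loss_geo(sc_pred, pose, intrinsics, kpts, mean_xyz, scale_xyz).item())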