Spaces:
Running
Running
# -*- coding: UTF-8 -*- | |
'''================================================= | |
@Project -> File pram -> metrics | |
@IDE PyCharm | |
@Author fx221@cam.ac.uk | |
@Date 29/01/2024 16:32 | |
==================================================''' | |
import torch | |
import numpy as np | |
import torch.nn.functional as F | |
class SeqIOU: | |
def __init__(self, n_class, ignored_sids=[]): | |
self.n_class = n_class | |
self.ignored_sids = ignored_sids | |
self.class_iou = np.zeros(n_class) | |
self.precisions = [] | |
def add(self, pred, target): | |
for i in range(self.n_class): | |
inter = np.sum((pred == target) * (target == i)) | |
union = np.sum(target == i) + np.sum(pred == i) - inter | |
if union > 0: | |
self.class_iou[i] = inter / union | |
acc = (pred == target) | |
if len(self.ignored_sids) == 0: | |
acc_ratio = np.sum(acc) / pred.shape[0] | |
else: | |
pred_mask = (pred >= 0) | |
target_mask = (target >= 0) | |
for i in self.ignored_sids: | |
pred_mask = pred_mask & (pred == i) | |
target_mask = target_mask & (target == i) | |
acc = acc & (1 - pred_mask) | |
tgt = (1 - target_mask) | |
if np.sum(tgt) == 0: | |
acc_ratio = 0 | |
else: | |
acc_ratio = np.sum(acc) / np.sum(tgt) | |
self.precisions.append(acc_ratio) | |
def get_mean_iou(self): | |
return np.mean(self.class_iou) | |
def get_mean_precision(self): | |
return np.mean(self.precisions) | |
def clear(self): | |
self.precisions = [] | |
self.class_iou = np.zeros(self.n_class) | |
def compute_iou(pred: np.ndarray, target: np.ndarray, n_class: int, ignored_ids=[]) -> float: | |
class_iou = np.zeros(n_class) | |
for i in range(n_class): | |
if i in ignored_ids: | |
continue | |
inter = np.sum((pred == target) * (target == i)) | |
union = np.sum(target == i) + np.sum(pred == i) - inter | |
if union > 0: | |
class_iou[i] = inter / union | |
return np.mean(class_iou) | |
# return class_iou | |
def compute_precision(pred: np.ndarray, target: np.ndarray, ignored_ids: list = []) -> float: | |
acc = (pred == target) | |
if len(ignored_ids) == 0: | |
return np.sum(acc) / pred.shape[0] | |
else: | |
pred_mask = (pred >= 0) | |
target_mask = (target >= 0) | |
for i in ignored_ids: | |
pred_mask = pred_mask & (pred == i) | |
target_mask = target_mask & (target == i) | |
acc = acc & (1 - pred_mask) | |
tgt = (1 - target_mask) | |
if np.sum(tgt) == 0: | |
return 0 | |
return np.sum(acc) / np.sum(tgt) | |
def compute_cls_corr(pred: torch.Tensor, target: torch.Tensor, k: int = 20) -> torch.Tensor: | |
bs = pred.shape[0] | |
_, target_ids = torch.topk(target, k=k, dim=1) | |
target_ids = target_ids.cpu().numpy() | |
_, top_ids = torch.topk(pred, k=k, dim=1) # [B, k, 1] | |
top_ids = top_ids.cpu().numpy() | |
acc = 0 | |
for i in range(bs): | |
# print('top_ids: ', i, top_ids[i], target_ids[i]) | |
overlap = [v for v in top_ids[i] if v in target_ids[i] and v >= 0] | |
acc = acc + len(overlap) / k | |
acc = acc / bs | |
return torch.from_numpy(np.array([acc])).to(pred.device) | |
def compute_corr_incorr(pred: torch.Tensor, target: torch.Tensor, ignored_ids: list = []) -> tuple: | |
''' | |
:param pred: [B, N, C] | |
:param target: [B, N] | |
:param ignored_ids: [] | |
:return: | |
''' | |
pred_ids = torch.max(pred, dim=-1)[1] | |
if len(ignored_ids) == 0: | |
acc = (pred_ids == target) | |
inacc = torch.logical_not(acc) | |
acc_ratio = torch.sum(acc) / torch.numel(target) | |
inacc_ratio = torch.sum(inacc) / torch.numel(target) | |
else: | |
acc = (pred_ids == target) | |
inacc = torch.logical_not(acc) | |
mask = torch.zeros_like(acc) | |
for i in ignored_ids: | |
mask = torch.logical_and(mask, (target == i)) | |
acc = torch.logical_and(acc, torch.logical_not(mask)) | |
acc_ratio = torch.sum(acc) / torch.numel(target) | |
inacc_ratio = torch.sum(inacc) / torch.numel(target) | |
return acc_ratio, inacc_ratio | |
def compute_seg_loss_weight(pred: torch.Tensor, | |
target: torch.Tensor, | |
background_id: int = 0, | |
weight_background: float = 0.1) -> torch.Tensor: | |
''' | |
:param pred: [B, C, N] | |
:param target: [B, N] | |
:param background_id: | |
:param weight_background: | |
:return: | |
''' | |
pred = pred.transpose(-2, -1).contiguous() # [B, N, C] -> [B, C, N] | |
weight = torch.ones(size=(pred.shape[1],), device=pred.device).float() | |
pred = torch.log_softmax(pred, dim=1) | |
weight[background_id] = weight_background | |
seg_loss = F.cross_entropy(pred, target.long(), weight=weight) | |
return seg_loss | |
def compute_cls_loss_ce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
cls_loss = torch.zeros(size=[], device=pred.device) | |
if len(pred.shape) == 2: | |
n_valid = torch.sum(target > 0) | |
cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred, target, reduction='sum') | |
cls_loss = cls_loss / n_valid | |
else: | |
for i in range(pred.shape[-1]): | |
cls_loss = cls_loss + torch.nn.functional.cross_entropy(pred[..., i], target[..., i], reduction='sum') | |
n_valid = torch.sum(target > 0) | |
cls_loss = cls_loss / n_valid | |
return cls_loss | |
def compute_cls_loss_kl(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
cls_loss = torch.zeros(size=[], device=pred.device) | |
if len(pred.shape) == 2: | |
cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred, dim=-1), | |
torch.softmax(target, dim=-1), | |
reduction='sum') | |
else: | |
for i in range(pred.shape[-1]): | |
cls_loss = cls_loss + torch.nn.functional.kl_div(torch.log_softmax(pred[..., i], dim=-1), | |
torch.softmax(target[..., i], dim=-1), | |
reduction='sum') | |
cls_loss = cls_loss / pred.shape[-1] | |
return cls_loss | |
def compute_sc_loss_l1(pred: torch.Tensor, target: torch.Tensor, mean_xyz=None, scale_xyz=None, mask=None): | |
''' | |
:param pred: [B, N, C] | |
:param target: [B, N, C] | |
:param mean_xyz: | |
:param scale_xyz: | |
:param mask: | |
:return: | |
''' | |
loss = (pred - target) | |
loss = torch.abs(loss).mean(dim=1) | |
if mask is not None: | |
return torch.mean(loss[mask]) | |
else: | |
return torch.mean(loss) | |
def compute_sc_loss_geo(pred: torch.Tensor, P, K, p2ds, mean_xyz, scale_xyz, max_value=20, mask=None): | |
b, c, n = pred.shape | |
p3ds = (pred * scale_xyz[..., None].repeat(1, 1, n) + mean_xyz[..., None].repeat(1, 1, n)) | |
p3ds_homo = torch.cat( | |
[pred, torch.ones(size=(p3ds.shape[0], 1, p3ds.shape[2]), dtype=p3ds.dtype, device=p3ds.device)], | |
dim=1) # [B, 4, N] | |
p3ds = torch.matmul(K, torch.matmul(P, p3ds_homo)[:, :3, :]) # [B, 3, N] | |
# print('p3ds: ', p3ds.shape, P.shape, K.shape, p2ds.shape) | |
p2ds_ = p3ds[:, :2, :] / p3ds[:, 2:, :] | |
loss = ((p2ds_ - p2ds.permute(0, 2, 1)) ** 2).sum(1) | |
loss = torch.clamp_max(loss, max=max_value) | |
if mask is not None: | |
return torch.mean(loss[mask]) | |
else: | |
return torch.mean(loss) | |