from collections import OrderedDict
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from mmengine import MMLogger, print_log
from mmengine.evaluator import BaseMetric
from prettytable import PrettyTable
from torchmetrics.functional.classification import (
    binary_accuracy, multiclass_accuracy, multiclass_f1_score,
    multiclass_jaccard_index, multiclass_precision, multiclass_recall)

from opencd.registry import METRICS


@METRICS.register_module()
class CDMetric(BaseMetric):
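    """Change detection evaluation metric.

    Accumulates (prediction, ground-truth) label-map pairs in
    :meth:`process` and, in :meth:`compute_metrics`, reports per-class
    accuracy, precision, recall, F1 and IoU, plus a global binary
    accuracy ('macc'). Only binary (change / no-change) datasets are
    supported.

    Args:
        ignore_index (int): Label index excluded from all metrics.
            Defaults to 255.
        collect_device (str): Device used when collecting results across
            processes, 'cpu' or 'gpu'. Defaults to 'cpu'.
        prefix (str, optional): Prefix prepended to metric names to
            disambiguate metrics of different evaluators. Defaults to None.
    """
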
    default_prefix: Optional[str] = 'cd'

    def __init__(self,
                 ignore_index: int = 255,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.ignore_index = ignore_index

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
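        """Append each (pred, gt) label-map pair to ``self.results``.

        Args:
            data_batch (dict): A batch of data from the dataloader (unused).
            data_samples (Sequence[dict]): Outputs of the model, each
                carrying 'pred_sem_seg' and 'gt_sem_seg' label maps.
        """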
        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
            # Move the ground truth onto the prediction's device and dtype
            # so the pair can be stacked in `compute_metrics`.
            gt_label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)
            self.results.append((pred_label, gt_label))

    def compute_metrics(self, results: list) -> Dict[str, float]:
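        """Compute all metrics from the collected (pred, gt) pairs.

        Args:
            results (list): The (pred, gt) pairs collected by `process`.

        Returns:
            Dict[str, float]: Per-class 'acc', 'p', 'r', 'f1' and 'iou'
            entries keyed by class name, plus a global 'macc' entry.
        """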
        num_classes = len(self.dataset_meta['classes'])
        class_names = self.dataset_meta['classes']
        assert num_classes == 2, 'Only support binary classification in CDMetric.'

        logger: MMLogger = MMLogger.get_current_instance()

        pred_label, label = zip(*results)
        preds = torch.stack(pred_label, dim=0)
        target = torch.stack(label, dim=0)
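
        # Per-class scores: average=None yields one value per class, and
        # pixels labelled `ignore_index` are excluded from every metric.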
        multiclass_precision_ = multiclass_precision(
            preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index)
        multiclass_recall_ = multiclass_recall(
            preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index)
        multiclass_f1_score_ = multiclass_f1_score(
            preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index)
        multiclass_jaccard_index_ = multiclass_jaccard_index(
            preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index)
        accuracy_ = multiclass_accuracy(
            preds, target, num_classes=num_classes, average=None, ignore_index=self.ignore_index)
        binary_accuracy_ = binary_accuracy(preds, target, ignore_index=self.ignore_index)
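
        # Keys: 'acc' per-class accuracy, 'p' precision, 'r' recall,
        # 'f1' F1 score, 'iou' Jaccard index, 'macc' global binary accuracy.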
        ret_metrics = OrderedDict({
            'acc': accuracy_.cpu().numpy(),
            'p': multiclass_precision_.cpu().numpy(),
            'r': multiclass_recall_.cpu().numpy(),
            'f1': multiclass_f1_score_.cpu().numpy(),
            'iou': multiclass_jaccard_index_.cpu().numpy(),
            'macc': binary_accuracy_.cpu().numpy(),
        })
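
        # Flatten to scalar log entries: per-class values get a class-name
        # suffix; the global binary accuracy is reported once as 'macc'.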
        metrics = dict()
        for k, v in ret_metrics.items():
            if k == 'macc':
                metrics[k] = v.item()
            else:
                for i in range(num_classes):
                    metrics[k + '_' + class_names[i]] = v[i].item()
        # Build the per-class table; the scalar 'macc' has no per-class
        # column, so it is dropped from the table only (it is already
        # recorded in `metrics` above).
        ret_metrics.pop('macc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        print_log('per class results:', logger=logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)

        return metrics
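

# A minimal sketch of how this metric can be wired into an mmengine config
# and exercised standalone. The evaluator wiring and the `dataset_meta`
# injection below are illustrative assumptions, not part of this file:
#
#     val_evaluator = dict(type='CDMetric', ignore_index=255)
#     test_evaluator = val_evaluator
#
#     metric = CDMetric()
#     metric.dataset_meta = dict(classes=('unchanged', 'changed'))
#     sample = dict(
#         pred_sem_seg=dict(data=torch.randint(0, 2, (1, 8, 8))),
#         gt_sem_seg=dict(data=torch.randint(0, 2, (1, 8, 8))))
#     metric.process(data_batch={}, data_samples=[sample])
#     print(metric.compute_metrics(metric.results))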