Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
from typing import Optional, Sequence, Dict | |
import numpy as np | |
import torch | |
from mmengine import MMLogger, print_log | |
from mmengine.evaluator import BaseMetric | |
from prettytable import PrettyTable | |
from torchmetrics.functional.classification import multiclass_precision, multiclass_recall, multiclass_f1_score, \ | |
multiclass_jaccard_index, multiclass_accuracy, binary_accuracy | |
from opencd.registry import METRICS | |
# NOTE(review): ``METRICS`` is imported above but this class is not registered
# with it here. If configs refer to ``type='CDMetric'``, confirm it is
# registered elsewhere, or add ``@METRICS.register_module()``.
class CDMetric(BaseMetric):
    """Pixel-level change-detection metric for binary (2-class) segmentation.

    :meth:`process` collects the prediction / ground-truth map of every
    sample. :meth:`compute_metrics` then uses torchmetrics to compute
    per-class accuracy, precision, recall, F1 and IoU, plus an overall
    binary pixel accuracy. It also logs a per-class table.

    Args:
        ignore_index (int): Target label value excluded from every metric.
            Defaults to 255.
        collect_device (str): Device used to gather results across
            processes; forwarded to ``BaseMetric``. Defaults to 'cpu'.
        prefix (str, optional): Metric-name prefix forwarded to
            ``BaseMetric``. Defaults to None.
        **kwargs: Accepted for config compatibility and ignored.
    """

    default_prefix: Optional[str] = 'cd'

    def __init__(self,
                 ignore_index: int = 255,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.ignore_index = ignore_index

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Store the (prediction, ground truth) pair of every sample.

        Args:
            data_batch (dict): Batch from the dataloader (unused).
            data_samples (Sequence[dict]): Model outputs. Each must provide
                ``pred_sem_seg['data']`` and ``gt_sem_seg['data']``.
        """
        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
            # Ground truth is always required: no ``format_only`` mode for
            # label-free test sets is implemented here. The GT is cast to
            # the prediction's dtype/device so both can be stacked together.
            gt_label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)
            # Keep the stored maps on CPU. Otherwise full-resolution maps for
            # the whole dataset pile up in accelerator memory until
            # ``compute_metrics`` runs.
            self.results.append((pred_label.cpu(), gt_label.cpu()))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute metrics over all collected samples and log a class table.

        Args:
            results (list): ``(pred, gt)`` tensor pairs produced by
                :meth:`process`. All maps must share the same shape because
                they are stacked into a single tensor.

        Returns:
            Dict[str, float]: A ``'<metric>_<class name>'`` entry for each of
            ``acc``, ``p``, ``r``, ``f1`` and ``iou``, followed by ``'macc'``.
            ``'macc'`` is the value of torchmetrics ``binary_accuracy`` over
            all non-ignored pixels, not a mean of the per-class accuracies.

        Raises:
            ValueError: If the dataset does not define exactly two classes,
                or if ``results`` is empty.
        """
        class_names = self.dataset_meta['classes']
        num_classes = len(class_names)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if num_classes != 2:
            raise ValueError('Only support binary classification in CDMetric.')
        if not results:
            raise ValueError(
                'CDMetric received no results; check that `process` ran on '
                'the evaluation data.')
        logger: MMLogger = MMLogger.get_current_instance()

        pred_labels, gt_labels = zip(*results)
        preds = torch.stack(pred_labels, dim=0)
        target = torch.stack(gt_labels, dim=0)

        # Identical options for every per-class (average=None) metric.
        per_class_kwargs = dict(
            num_classes=num_classes,
            average=None,
            ignore_index=self.ignore_index)
        per_class = OrderedDict(
            (name, value.cpu().numpy()) for name, value in (
                ('acc', multiclass_accuracy(preds, target, **per_class_kwargs)),
                ('p', multiclass_precision(preds, target, **per_class_kwargs)),
                ('r', multiclass_recall(preds, target, **per_class_kwargs)),
                ('f1', multiclass_f1_score(preds, target, **per_class_kwargs)),
                ('iou', multiclass_jaccard_index(preds, target,
                                                 **per_class_kwargs)),
            ))
        overall_acc = binary_accuracy(
            preds, target, ignore_index=self.ignore_index)

        # Flatten per-class arrays into '<metric>_<class>' scalar entries.
        metrics = dict()
        for metric_name, values in per_class.items():
            for class_name, value in zip(class_names, values):
                metrics[f'{metric_name}_{class_name}'] = value.item()
        metrics['macc'] = overall_acc.item()

        self._log_class_table(class_names, per_class, logger)
        return metrics

    @staticmethod
    def _log_class_table(class_names, per_class: OrderedDict,
                         logger: MMLogger) -> None:
        """Log per-class metrics as percentages rounded to 2 decimals."""
        table = PrettyTable()
        table.add_column('Class', class_names)
        for metric_name, values in per_class.items():
            table.add_column(metric_name, np.round(values * 100, 2))
        print_log('per class results:', logger)
        print_log('\n' + table.get_string(), logger=logger)