# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from mmdet.evaluation.metrics.base_video_metric import BaseVideoMetric, collect_tracking_results
from mmdet.evaluation.metrics.coco_panoptic_metric import parse_pq_results, print_panoptic_table
from mmdet.registry import METRICS
from mmengine import mkdir_or_exist, print_log
from mmengine.dist import barrier, broadcast_object_list, is_main_process
from panopticapi.evaluation import PQStat

from seg.models.utils import INSTANCE_OFFSET_HB, NO_OBJ_ID, IoUObj, cal_pq, mmgt2hbpan, mmpan2hbpan


def parse_pan_map_hb(pan_map: np.ndarray, data_sample: dict, num_classes: int) -> dict:
    """Convert an HB-format panoptic map into per-segment binary masks."""
    result = dict()
    result['video_id'] = data_sample['video_id']
    result['frame_id'] = data_sample['frame_id']
    # For video evaluation each frame may be loaded several times, so saving
    # an extra png map per frame is not efficient, especially on machines
    # without a high-performance SSD; keep the binary masks in memory instead.
    pan_labels = np.unique(pan_map)
    segments_info = []
    for pan_label in pan_labels:
        sem_label = pan_label // INSTANCE_OFFSET_HB
        if sem_label >= num_classes:
            continue
        mask = (pan_map == pan_label).astype(np.uint8)
        area = mask.sum()
        segments_info.append({
            'id': int(pan_label),
            'category_id': sem_label,
            'area': int(area),
            'mask': mask
        })
    result['segments_info'] = segments_info
    return result
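

# --- Illustrative sketch (not part of the metric) ---------------------------
# A minimal example of the panoptic id convention assumed above: each pixel
# stores `category_id * INSTANCE_OFFSET_HB + instance_id`, so
# `parse_pan_map_hb` recovers one binary mask per unique id. The shapes and
# ids below are made up purely for illustration.
def _demo_parse_pan_map_hb() -> dict:
    num_classes = 5
    # Pixels whose semantic label falls outside the class range are ignored
    # by `parse_pan_map_hb` (sem_label >= num_classes).
    toy_map = np.full((4, 4), num_classes * INSTANCE_OFFSET_HB, dtype=np.int64)
    toy_map[:2, :2] = 0 * INSTANCE_OFFSET_HB + 1  # thing: class 0, instance 1
    toy_map[2:, 2:] = 3 * INSTANCE_OFFSET_HB      # stuff: class 3
    toy_sample = {'video_id': 0, 'frame_id': 0}
    return parse_pan_map_hb(toy_map, toy_sample, num_classes=num_classes)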


def parse_data_sample_gt(data_sample: dict, num_things: int, num_stuff: int) -> dict:
    """Collect ground-truth thing/stuff segments (plus a no-object entry)."""
    num_classes = num_things + num_stuff
    result = dict()
    result['video_id'] = data_sample['video_id']
    result['frame_id'] = data_sample['frame_id']
    # Same as in `parse_pan_map_hb`: keep binary masks in memory rather than
    # writing an extra png map per frame.
    gt_instances = data_sample['gt_instances']
    segments_info = []
    # Thing segments come from the tracked instance annotations.
    for thing_id in range(len(gt_instances['labels'])):
        mask = gt_instances['masks'].masks[thing_id].astype(np.uint8)
        area = mask.sum()
        pan_id = gt_instances['instances_ids'][thing_id]
        cat = int(gt_instances['labels'][thing_id])
        if cat >= num_things:
            raise ValueError(f'Expected a thing category id < {num_things}, but got {cat}.')
        segments_info.append({
            'id': int(pan_id),
            'category_id': cat,
            'area': int(area),
            'mask': mask
        })
    # Stuff segments (and the no-object region) come from the semantic map.
    gt_sem_seg = data_sample['gt_sem_seg']['sem_seg'][0].cpu().numpy()
    for stuff_id in np.unique(gt_sem_seg):
        if stuff_id < num_things:
            continue
        if stuff_id >= num_classes:
            # The only label allowed beyond the class range is the no-object id.
            assert stuff_id == NO_OBJ_ID // INSTANCE_OFFSET_HB
        _mask = (gt_sem_seg == stuff_id).astype(np.uint8)
        area = _mask.sum()
        cat = int(stuff_id)
        pan_id = cat * INSTANCE_OFFSET_HB
        segments_info.append({
            'id': int(pan_id),
            'category_id': cat,
            'area': int(area),
            'mask': _mask
        })
    # Always append a (possibly empty) no-object segment so that callers can
    # rely on `segments_info[-1]` being the no-object region.
    if segments_info[-1]['id'] != NO_OBJ_ID:
        segments_info.append({
            'id': int(NO_OBJ_ID),
            'category_id': NO_OBJ_ID // INSTANCE_OFFSET_HB,
            'area': 0,
            'mask': np.zeros_like(gt_sem_seg, dtype=np.uint8)
        })
    result['segments_info'] = segments_info
    return result
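

# --- Illustrative sketch (not part of the metric) ---------------------------
# How per-frame overlaps are meant to accumulate into a "tube" overlap for
# VPQ: the same (gt_id, pred_id) key is summed over the frames of a window.
# Only the IoUObj behaviour exercised elsewhere in this file (zero-argument
# construction, keyword construction and `+=`) is assumed here.
def _demo_tube_accumulation() -> IoUObj:
    per_frame_overlaps = [
        IoUObj(intersection=40, union=100),  # frame t
        IoUObj(intersection=35, union=90),   # frame t + 1
    ]
    tube = IoUObj()  # empty overlap, as produced by defaultdict(IoUObj)
    for overlap in per_frame_overlaps:
        tube += overlap
    return tube  # intersection/union summed over the whole window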


@METRICS.register_module()
class VIPSegMetric(BaseVideoMetric):
    """VPQ evaluation metric for the VIPSeg video panoptic segmentation task.

    Args:
        metric (str | list[str]): Metrics to be evaluated. Each entry takes
            the form ``'VPQ@k'``, where ``k`` is the clip (window) length over
            which tube IoUs are accumulated. Defaults to ``'VPQ@1'``.
        outfile_prefix (str | None): The prefix of json files. It includes
            the file path and the prefix of filename, e.g., "a/b/prefix".
            If not specified, a temp file will be created. Defaults to None.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
        format_only (bool): If True, only format the results to the official
            format without performing evaluation. Defaults to False.
    """
    default_prefix: Optional[str] = 'vip_seg'

    def __init__(self,
                 metric: Union[str, List[str]] = 'VPQ@1',
                 outfile_prefix: Optional[str] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 format_only: bool = False) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        # VPQ evaluation metrics, e.g. ['VPQ@1', 'VPQ@2', 'VPQ@4', 'VPQ@6'].
        self.metrics = metric if isinstance(metric, list) else [metric]
        self.format_only = format_only

        allowed_metrics = ['VPQ']
        for metric in self.metrics:
            if metric not in allowed_metrics and metric.split('@')[0] not in allowed_metrics:
                raise KeyError(
                    f"metric should be 'VPQ' or 'VPQ@k', but got {metric}.")

        self.outfile_prefix = outfile_prefix
        self.per_video_res = []
        self.categories = {}
        self._vis_meta_info = defaultdict(list)  # record video and image infos

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results are stored in ``self.per_video_res`` and
        ``self.results``, which will be used by ``compute_metrics``.
        """
        for track_data_sample in data_samples:
            video_data_samples = track_data_sample['video_data_samples']
            ori_video_len = video_data_samples[0].ori_video_length
            if ori_video_len == len(video_data_samples):
                # The whole video is forwarded at once.
                self.process_video(video_data_samples)
            else:
                # Clip-by-clip (image-level) evaluation is not supported yet.
                raise NotImplementedError

    def process_video(self, data_samples):
        video_length = len(data_samples)
        num_things = len(self.dataset_meta['thing_classes'])
        num_stuff = len(self.dataset_meta['stuff_classes'])
        num_classes = num_things + num_stuff

        for frame_id in range(video_length):
            img_data_sample = data_samples[frame_id].to_dict()
            # Index 0 selects the dummy dimension added by the fusion head,
            # not the batch dimension.
            pred = mmpan2hbpan(img_data_sample['pred_track_panoptic_seg']['sem_seg'][0], num_classes=num_classes)
            if self.format_only:
                # Dump ground-truth and prediction maps to disk and skip matching.
                vid_id = data_samples[frame_id].video_id
                gt = mmgt2hbpan(data_samples[frame_id])
                mkdir_or_exist('vipseg_output/gt/')
                mkdir_or_exist('vipseg_output/pred/')
                torch.save(gt.to(device='cpu'),
                           'vipseg_output/gt/{:06d}_{:06d}.pth'.format(vid_id, frame_id))
                torch.save(torch.tensor(pred, device='cpu'),
                           'vipseg_output/pred/{:06d}_{:06d}.pth'.format(vid_id, frame_id))
                continue
            pred_json = parse_pan_map_hb(pred, img_data_sample, num_classes=num_classes)
            gt_json = parse_data_sample_gt(img_data_sample, num_things=num_things, num_stuff=num_stuff)
            self.per_video_res.append((pred_json, gt_json))

        if self.format_only:
            return

        video_results = []
        for pred, gt in self.per_video_res:
            intersection_info = dict()
            # First record how much of each predicted segment overlaps the
            # no-object (void) region; this overlap is excluded from the union
            # when matching against real ground-truth segments.
            gt_no_obj_info = gt['segments_info'][-1]
            for pred_seg_info in pred['segments_info']:
                intersection = int((gt_no_obj_info['mask'] * pred_seg_info['mask']).sum())
                union = pred_seg_info['area']
                intersection_info[gt_no_obj_info['id'], pred_seg_info['id']] = IoUObj(
                    intersection=intersection,
                    union=union
                )
            # Then compute the intersection/union of every (gt, pred) pair.
            for pred_seg_info in pred['segments_info']:
                for gt_seg_info in gt['segments_info'][:-1]:
                    intersection = int((gt_seg_info['mask'] * pred_seg_info['mask']).sum())
                    union = gt_seg_info['area'] + pred_seg_info['area'] - \
                        intersection - intersection_info[NO_OBJ_ID, pred_seg_info['id']].intersection
                    intersection_info[gt_seg_info['id'], pred_seg_info['id']] = IoUObj(
                        intersection=intersection,
                        union=union
                    )
            video_results.append(intersection_info)

        self.per_video_res.clear()
        self.results.append(video_results)
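
    # Worked example of the void correction above: suppose a predicted
    # segment has area 100, of which 10 pixels fall on the ground-truth
    # no-object region, and it overlaps a ground-truth segment of area 80 by
    # 50 pixels. The pair is then stored as
    #     intersection = 50
    #     union        = 80 + 100 - 50 - 10 = 120
    # i.e. void pixels are discounted from the union before the IoU is formed.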

    def compute_metrics(self, results: List) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        eval_results = {}
        if self.format_only:
            return eval_results

        for metric in self.metrics:
            # 'VPQ@k' accumulates tube IoUs over every sliding window of k frames.
            seq_len = int(metric.split('@')[-1])
            pq_stat = PQStat()
            cnt = 0
            for vid_idx, video_instances in enumerate(results):
                for frame_x in range(len(video_instances)):
                    if frame_x + seq_len > len(video_instances):
                        break
                    # Sum the per-frame intersections/unions of each
                    # (gt_id, pred_id) pair over the current window.
                    global_intersection_info = defaultdict(IoUObj)
                    for frame_offset in range(seq_len):
                        frame_info = video_instances[frame_x + frame_offset]
                        for gt_id, pred_id in frame_info:
                            global_intersection_info[gt_id, pred_id] += frame_info[gt_id, pred_id]
                    pq_stat += cal_pq(global_intersection_info, classes=self.dataset_meta['classes'])
                    cnt += 1
                # Alternative streaming formulation of the same window accumulation:
                # global_intersection_info = defaultdict(IoUObj)
                # for frame_idx, frame_info in enumerate(video_instances):
                #     for gt_id, pred_id in frame_info:
                #         global_intersection_info[gt_id, pred_id] += frame_info[gt_id, pred_id]
                #     if frame_idx - seq_len >= 0:
                #         out_frame_info = video_instances[frame_idx - seq_len]
                #         for gt_id, pred_id in out_frame_info:
                #             global_intersection_info[gt_id, pred_id] -= out_frame_info[gt_id, pred_id]
                #             assert global_intersection_info[gt_id, pred_id].is_legal()
                #     if frame_idx - seq_len >= -1:
                #         pq_stat += cal_pq(global_intersection_info, classes=self.dataset_meta['classes'])
                #         cnt += 1
            print_log('Total calculated clips: ' + str(cnt), logger='current')

            sub_metrics = [('All', None), ('Things', True), ('Stuff', False)]
            pq_results = {}
            for name, isthing in sub_metrics:
                pq_results[name], classwise_results = pq_stat.pq_average(
                    self.categories, isthing=isthing)
                if name == 'All':
                    pq_results['classwise'] = classwise_results
            # classwise_results = {
            #     k: v
            #     for k, v in zip(self.dataset_meta['classes'],
            #                     pq_results['classwise'].values())
            # }
            print_panoptic_table(pq_results, None, logger='current')
            metric_results = parse_pq_results(pq_results)
            for key in metric_results:
                eval_results[metric + f'_{key}'] = metric_results[key]

        return eval_results

    def evaluate(self, size: int) -> dict:
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are the
            names of the metrics, and the values are corresponding results.
        """
        # wait for all processes to complete prediction.
        barrier()

        # Build the category table expected by `PQStat.pq_average`:
        # thing classes first, then stuff classes.
        cls_idx = 0
        for thing_cls in self.dataset_meta['thing_classes']:
            self.categories[cls_idx] = {'class': thing_cls, 'isthing': 1}
            cls_idx += 1
        for stuff_cls in self.dataset_meta['stuff_classes']:
            self.categories[cls_idx] = {'class': stuff_cls, 'isthing': 0}
            cls_idx += 1
        assert cls_idx == len(self.dataset_meta['classes'])

        if len(self.results) == 0:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in `process` method.')

        results = collect_tracking_results(self.results, self.collect_device)

        # # gather seq_info
        # gathered_seq_info = all_gather_object(self._vis_meta_info['videos'])
        # all_seq_info = []
        # for _seq_info in gathered_seq_info:
        #     all_seq_info.extend(_seq_info)
        # # update self._vis_meta_info
        # self._vis_meta_info = dict(videos=all_seq_info)

        if is_main_process():
            print_log(
                f'There are {len(results)} videos to be evaluated in total.',
                logger='current'
            )
            _metrics = self.compute_metrics(results)  # type: ignore
            # Add prefix to metric names
            if self.prefix:
                _metrics = {
                    '/'.join((self.prefix, k)): v
                    for k, v in _metrics.items()
                }
            metrics = [_metrics]
        else:
            metrics = [None]  # type: ignore

        broadcast_object_list(metrics)

        # reset the results list
        self.results.clear()
        # reset the vis_meta_info
        self._vis_meta_info.clear()
        return metrics[0]
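
# --- Illustrative usage sketch (assumed, not defined in this module) --------
# In an MMEngine-style config, this metric would typically be referenced as a
# registered evaluator; the window lengths below are an example choice rather
# than a value mandated by this file:
#
#     val_evaluator = dict(
#         type='VIPSegMetric',
#         metric=['VPQ@1', 'VPQ@2', 'VPQ@4', 'VPQ@6'],
#         format_only=False,
#     )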