|
import pickle |
|
import time |
|
|
|
import numpy as np |
|
import torch |
|
import tqdm |
|
|
|
from pcdet.models import load_data_to_gpu |
|
from pcdet.utils import common_utils |
|
|
|
|
|
def statistics_info(cfg, ret_dict, metric, disp_dict):
    """Fold one batch's recall counters into ``metric`` and refresh the display text.

    Args:
        cfg: Config object exposing ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST``.
        ret_dict: Per-batch counters. The keys ``'roi_<thresh>'``,
            ``'rcnn_<thresh>'`` and ``'gt'`` are read; a missing key counts as 0.
        metric: Running totals, updated in place. Must already contain
            ``'gt_num'`` plus ``'recall_roi_<thresh>'`` and
            ``'recall_rcnn_<thresh>'`` for every configured threshold.
        disp_dict: Display dictionary, updated in place with a
            ``'recall_<thresh>'`` summary string for the first configured
            threshold.

    Returns:
        None. ``metric`` and ``disp_dict`` are mutated.
    """
    thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST

    for cur_thresh in thresh_list:
        metric['recall_roi_%s' % str(cur_thresh)] += ret_dict.get('roi_%s' % str(cur_thresh), 0)
        metric['recall_rcnn_%s' % str(cur_thresh)] += ret_dict.get('rcnn_%s' % str(cur_thresh), 0)
    metric['gt_num'] += ret_dict.get('gt', 0)

    # Without any configured threshold there is nothing to summarise; indexing
    # element 0 below would raise IndexError. len() is used instead of
    # truthiness so array-like threshold lists also work.
    if len(thresh_list) == 0:
        return

    # The summary uses the FIRST configured threshold. Nothing here sorts the
    # list, so "min" only holds if the config lists thresholds in ascending order.
    min_thresh = thresh_list[0]
    disp_dict['recall_%s' % str(min_thresh)] = '(%d, %d) / %d' % (
        metric['recall_roi_%s' % str(min_thresh)],
        metric['recall_rcnn_%s' % str(min_thresh)],
        metric['gt_num'],
    )
|
|
|
|
|
def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=False, result_dir=None):
    """Run ``model`` over ``dataloader`` once, collect predictions and evaluate them.

    Args:
        cfg: Config object. ``cfg.LOCAL_RANK``,
            ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST`` and
            ``cfg.MODEL.POST_PROCESSING.EVAL_METRIC`` are read.
        args: Run options. ``args.save_to_file`` is required; ``args.infer_time``
            is optional and enables per-batch inference timing.
        model: Callable returning ``(pred_dicts, ret_dict)`` for a batch dict.
        dataloader: Iterable of batch dicts. Its ``dataset`` must provide
            ``class_names``, ``generate_prediction_dicts`` and ``evaluation``.
        epoch_id: Identifier used only in log messages.
        logger: Logger receiving progress and result messages.
        dist_test: When true, wrap the model in ``DistributedDataParallel`` and
            merge annotations and metrics across ranks.
        result_dir: Path-like directory (must support ``/`` and ``mkdir``) where
            ``result.pkl`` and other outputs are written. Despite the ``None``
            default it is required: the first statement calls
            ``result_dir.mkdir``.

    Returns:
        On local rank 0, a dict with ``'recall/roi_<thresh>'`` and
        ``'recall/rcnn_<thresh>'`` entries plus everything in the result dict
        returned by ``dataset.evaluation``. On every other rank, ``{}``.
    """
    result_dir.mkdir(parents=True, exist_ok=True)

    final_output_dir = result_dir / 'final_result' / 'data'
    if args.save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)

    thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    metric = {
        'gt_num': 0,
    }
    for cur_thresh in thresh_list:
        metric['recall_roi_%s' % str(cur_thresh)] = 0
        metric['recall_rcnn_%s' % str(cur_thresh)] = 0

    dataset = dataloader.dataset
    class_names = dataset.class_names
    det_annos = []

    # Looked up once; the flag is optional on ``args``.
    measure_infer_time = getattr(args, 'infer_time', False)
    if measure_infer_time:
        infer_time_meter = common_utils.AverageMeter()

    logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
    if dist_test:
        num_gpus = torch.cuda.device_count()
        local_rank = cfg.LOCAL_RANK % num_gpus
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False
        )
    model.eval()

    # Only local rank 0 owns a progress bar; every later use is guarded the same way.
    if cfg.LOCAL_RANK == 0:
        progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)

    # Wall-clock start of the whole evaluation loop. The per-batch timer below
    # uses its own variable so that enabling ``infer_time`` no longer overwrites
    # this value and corrupts ``sec_per_example``.
    start_time = time.time()
    for batch_dict in dataloader:
        load_data_to_gpu(batch_dict)

        if measure_infer_time:
            batch_start_time = time.time()

        with torch.no_grad():
            pred_dicts, ret_dict = model(batch_dict)

        disp_dict = {}

        if measure_infer_time:
            inference_time = time.time() - batch_start_time
            # The meter is fed milliseconds.
            infer_time_meter.update(inference_time * 1000)
            disp_dict['infer_time'] = f'{infer_time_meter.val:.2f}({infer_time_meter.avg:.2f})'

        statistics_info(cfg, ret_dict, metric, disp_dict)
        annos = dataset.generate_prediction_dicts(
            batch_dict, pred_dicts, class_names,
            output_path=final_output_dir if args.save_to_file else None
        )
        det_annos += annos
        if cfg.LOCAL_RANK == 0:
            progress_bar.set_postfix(disp_dict)
            progress_bar.update()

    if cfg.LOCAL_RANK == 0:
        progress_bar.close()

    if dist_test:
        _rank, world_size = common_utils.get_dist_info()
        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
        # NOTE(review): the merged ``metric`` is indexed per rank further down,
        # so it is presumably a list with one dict per rank - confirm against
        # ``merge_results_dist``.
        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')

    logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
    # Guard the division the same way the recall and average ratios below are guarded.
    sec_per_example = (time.time() - start_time) / max(len(dataloader.dataset), 1)
    logger.info('Generate label finished(sec_per_example: %.4f second).' % sec_per_example)

    # Only local rank 0 aggregates, saves and evaluates.
    if cfg.LOCAL_RANK != 0:
        return {}

    ret_dict = {}
    if dist_test:
        # Sum every counter from the other ranks into rank 0's dict. Only
        # existing keys are updated, so iterating the dict meanwhile is safe.
        for key in metric[0]:
            for k in range(1, world_size):
                metric[0][key] += metric[k][key]
        metric = metric[0]

    gt_num_cnt = metric['gt_num']
    for cur_thresh in thresh_list:
        cur_roi_recall = metric['recall_roi_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        cur_rcnn_recall = metric['recall_rcnn_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        logger.info('recall_roi_%s: %f' % (cur_thresh, cur_roi_recall))
        logger.info('recall_rcnn_%s: %f' % (cur_thresh, cur_rcnn_recall))
        ret_dict['recall/roi_%s' % str(cur_thresh)] = cur_roi_recall
        ret_dict['recall/rcnn_%s' % str(cur_thresh)] = cur_rcnn_recall

    total_pred_objects = 0
    for anno in det_annos:
        total_pred_objects += len(anno['name'])
    logger.info('Average predicted number of objects(%d samples): %.3f'
                % (len(det_annos), total_pred_objects / max(1, len(det_annos))))

    with open(result_dir / 'result.pkl', 'wb') as f:
        pickle.dump(det_annos, f)

    # Debug output kept as in the original (goes to stdout, not the logger).
    print(f"length of det_annos: {len(det_annos)}")
    print(dataset)
    result_str, result_dict = dataset.evaluation(
        det_annos, class_names,
        eval_metric=cfg.MODEL.POST_PROCESSING.EVAL_METRIC,
        output_path=final_output_dir
    )
    print(f"result_dict: {result_dict.keys()}")
    logger.info(result_str)
    ret_dict.update(result_dict)
    logger.info('Result is saved to %s' % result_dir)
    logger.info('****************Evaluation done.*****************')
    return ret_dict
|
|
|
|
|
# This module only provides evaluation helpers; running it directly does nothing.
if __name__ == '__main__':
    pass
|
|