import pickle import time import numpy as np import torch import tqdm from pcdet.models import load_data_to_gpu from pcdet.utils import common_utils def statistics_info(cfg, ret_dict, metric, disp_dict): for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST: metric['recall_roi_%s' % str(cur_thresh)] += ret_dict.get('roi_%s' % str(cur_thresh), 0) metric['recall_rcnn_%s' % str(cur_thresh)] += ret_dict.get('rcnn_%s' % str(cur_thresh), 0) metric['gt_num'] += ret_dict.get('gt', 0) min_thresh = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST[0] disp_dict['recall_%s' % str(min_thresh)] = \ '(%d, %d) / %d' % (metric['recall_roi_%s' % str(min_thresh)], metric['recall_rcnn_%s' % str(min_thresh)], metric['gt_num']) def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=False, result_dir=None): result_dir.mkdir(parents=True, exist_ok=True) final_output_dir = result_dir / 'final_result' / 'data' if args.save_to_file: final_output_dir.mkdir(parents=True, exist_ok=True) metric = { 'gt_num': 0, } for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST: metric['recall_roi_%s' % str(cur_thresh)] = 0 metric['recall_rcnn_%s' % str(cur_thresh)] = 0 dataset = dataloader.dataset class_names = dataset.class_names det_annos = [] if getattr(args, 'infer_time', False): start_iter = int(len(dataloader) * 0.1) infer_time_meter = common_utils.AverageMeter() logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id) if dist_test: num_gpus = torch.cuda.device_count() local_rank = cfg.LOCAL_RANK % num_gpus model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[local_rank], broadcast_buffers=False ) model.eval() if cfg.LOCAL_RANK == 0: progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True) start_time = time.time() for i, batch_dict in enumerate(dataloader): load_data_to_gpu(batch_dict) if getattr(args, 'infer_time', False): start_time = time.time() with torch.no_grad(): pred_dicts, ret_dict = model(batch_dict) disp_dict = {} if getattr(args, 'infer_time', False): inference_time = time.time() - start_time infer_time_meter.update(inference_time * 1000) # use ms to measure inference time disp_dict['infer_time'] = f'{infer_time_meter.val:.2f}({infer_time_meter.avg:.2f})' statistics_info(cfg, ret_dict, metric, disp_dict) annos = dataset.generate_prediction_dicts( batch_dict, pred_dicts, class_names, output_path=final_output_dir if args.save_to_file else None ) det_annos += annos if cfg.LOCAL_RANK == 0: progress_bar.set_postfix(disp_dict) progress_bar.update() if cfg.LOCAL_RANK == 0: progress_bar.close() if dist_test: rank, world_size = common_utils.get_dist_info() det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir') metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir') logger.info('*************** Performance of EPOCH %s *****************' % epoch_id) sec_per_example = (time.time() - start_time) / len(dataloader.dataset) logger.info('Generate label finished(sec_per_example: %.4f second).' % sec_per_example) if cfg.LOCAL_RANK != 0: return {} ret_dict = {} if dist_test: for key, val in metric[0].items(): for k in range(1, world_size): metric[0][key] += metric[k][key] metric = metric[0] gt_num_cnt = metric['gt_num'] for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST: cur_roi_recall = metric['recall_roi_%s' % str(cur_thresh)] / max(gt_num_cnt, 1) cur_rcnn_recall = metric['recall_rcnn_%s' % str(cur_thresh)] / max(gt_num_cnt, 1) logger.info('recall_roi_%s: %f' % (cur_thresh, cur_roi_recall)) logger.info('recall_rcnn_%s: %f' % (cur_thresh, cur_rcnn_recall)) ret_dict['recall/roi_%s' % str(cur_thresh)] = cur_roi_recall ret_dict['recall/rcnn_%s' % str(cur_thresh)] = cur_rcnn_recall total_pred_objects = 0 for anno in det_annos: total_pred_objects += anno['name'].__len__() logger.info('Average predicted number of objects(%d samples): %.3f' % (len(det_annos), total_pred_objects / max(1, len(det_annos)))) with open(result_dir / 'result.pkl', 'wb') as f: pickle.dump(det_annos, f) print(f"length of det_annos: {len(det_annos)}") print(dataset) result_str, result_dict = dataset.evaluation( det_annos, class_names, eval_metric=cfg.MODEL.POST_PROCESSING.EVAL_METRIC, output_path=final_output_dir ) print(f"result_dict: {result_dict.keys()}") logger.info(result_str) ret_dict.update(result_dict) logger.info('Result is saved to %s' % result_dir) logger.info('****************Evaluation done.*****************') return ret_dict if __name__ == '__main__': pass