File size: 5,336 Bytes
62a2f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import pickle
import time
import numpy as np
import torch
import tqdm
from pcdet.models import load_data_to_gpu
from pcdet.utils import common_utils
def statistics_info(cfg, ret_dict, metric, disp_dict):
    """Fold one batch's recall counters into running totals and refresh the progress text.

    Args:
        cfg: config object; ``cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST`` gives the
            recall thresholds. Its first entry is the one summarised in ``disp_dict``.
        ret_dict: per-batch dict returned by the model. Counts are read from the keys
            ``'roi_<thresh>'``, ``'rcnn_<thresh>'`` and ``'gt'``; a missing key counts as 0.
        metric: running totals, updated in place. It must already hold ``'gt_num'`` and
            ``'recall_roi_<thresh>'`` / ``'recall_rcnn_<thresh>'`` for every threshold.
        disp_dict: progress-bar postfix dict, updated in place with a
            ``'recall_<first thresh>'`` entry formatted as ``'(roi, rcnn) / gt'``.
    """
    thresholds = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST

    for thresh in thresholds:
        suffix = str(thresh)
        # Update ROI-stage hits first, then RCNN-stage hits, for this threshold.
        for stage in ('roi', 'rcnn'):
            metric['recall_' + stage + '_' + suffix] += ret_dict.get(stage + '_' + suffix, 0)
    metric['gt_num'] += ret_dict.get('gt', 0)

    # Only the first (smallest-index) threshold is shown in the progress bar.
    first = str(thresholds[0])
    roi_hits = metric['recall_roi_' + first]
    rcnn_hits = metric['recall_rcnn_' + first]
    disp_dict['recall_' + first] = '(%d, %d) / %d' % (roi_hits, rcnn_hits, metric['gt_num'])
def eval_one_epoch(cfg, args, model, dataloader, epoch_id, logger, dist_test=False, result_dir=None):
    """Run one evaluation pass over ``dataloader`` and return the aggregated metrics.

    The model is put in eval mode (wrapped in ``DistributedDataParallel`` when
    ``dist_test`` is set). Each batch is moved to the GPU and run under
    ``torch.no_grad()``. Predictions are converted with
    ``dataset.generate_prediction_dicts``, and recall counters are accumulated with
    ``statistics_info``. In distributed mode, predictions and counters are gathered
    through ``common_utils.merge_results_dist``. Only the process with
    ``cfg.LOCAL_RANK == 0`` computes, logs and returns the final metrics.

    Args:
        cfg: experiment config. Reads ``LOCAL_RANK``,
            ``MODEL.POST_PROCESSING.RECALL_THRESH_LIST`` and
            ``MODEL.POST_PROCESSING.EVAL_METRIC``.
        args: CLI namespace. Reads ``save_to_file`` and the optional ``infer_time``
            flag (per-batch inference timing shown in the progress bar, in ms).
        model: detector, called as ``model(batch_dict)`` and returning
            ``(pred_dicts, ret_dict)``.
        dataloader: evaluation loader. ``dataloader.dataset`` must provide
            ``class_names``, ``generate_prediction_dicts`` and ``evaluation``.
        epoch_id: label used only in log messages.
        logger: logger that receives all progress and result messages.
        dist_test: whether evaluation runs across multiple processes.
        result_dir: Path-like output directory (must support ``/`` and
            ``mkdir(parents=..., exist_ok=...)``). Required despite the ``None``
            default. ``result.pkl`` is written into it, and per-frame outputs go to
            ``final_result/data`` beneath it.

    Returns:
        dict: on rank 0, the recall values (``'recall/roi_<t>'``, ``'recall/rcnn_<t>'``)
        merged with the dataset's ``evaluation`` result dict. On other ranks, ``{}``.

    Raises:
        ValueError: if ``result_dir`` is not given.
    """
    if result_dir is None:
        raise ValueError('eval_one_epoch requires a result_dir')
    result_dir.mkdir(parents=True, exist_ok=True)

    final_output_dir = result_dir / 'final_result' / 'data'
    if args.save_to_file:
        final_output_dir.mkdir(parents=True, exist_ok=True)

    recall_thresh_list = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST
    metric = {'gt_num': 0}
    for cur_thresh in recall_thresh_list:
        metric['recall_roi_%s' % str(cur_thresh)] = 0
        metric['recall_rcnn_%s' % str(cur_thresh)] = 0

    dataset = dataloader.dataset
    class_names = dataset.class_names
    det_annos = []

    measure_infer_time = getattr(args, 'infer_time', False)
    if measure_infer_time:
        infer_time_meter = common_utils.AverageMeter()

    logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id)
    if dist_test:
        num_gpus = torch.cuda.device_count()
        local_rank = cfg.LOCAL_RANK % num_gpus
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            broadcast_buffers=False
        )
    model.eval()

    if cfg.LOCAL_RANK == 0:
        progress_bar = tqdm.tqdm(total=len(dataloader), leave=True, desc='eval', dynamic_ncols=True)

    # Wall-clock start of the whole pass. It is deliberately separate from the
    # per-batch inference timer below, so enabling infer_time does not reset it
    # and corrupt sec_per_example.
    start_time = time.time()
    for batch_dict in dataloader:
        load_data_to_gpu(batch_dict)

        if measure_infer_time:
            batch_start = time.perf_counter()

        with torch.no_grad():
            pred_dicts, ret_dict = model(batch_dict)

        disp_dict = {}
        if measure_infer_time:
            # NOTE(review): CUDA work may still be queued when the timer stops
            # (no torch.cuda.synchronize()), so this may under-report GPU time.
            infer_time_meter.update((time.perf_counter() - batch_start) * 1000)  # ms
            disp_dict['infer_time'] = f'{infer_time_meter.val:.2f}({infer_time_meter.avg:.2f})'

        statistics_info(cfg, ret_dict, metric, disp_dict)
        annos = dataset.generate_prediction_dicts(
            batch_dict, pred_dicts, class_names,
            output_path=final_output_dir if args.save_to_file else None
        )
        det_annos += annos

        if cfg.LOCAL_RANK == 0:
            progress_bar.set_postfix(disp_dict)
            progress_bar.update()

    if cfg.LOCAL_RANK == 0:
        progress_bar.close()

    if dist_test:
        _, world_size = common_utils.get_dist_info()
        det_annos = common_utils.merge_results_dist(det_annos, len(dataset), tmpdir=result_dir / 'tmpdir')
        # Gathers one metric dict per process; they are summed on rank 0 below.
        metric = common_utils.merge_results_dist([metric], world_size, tmpdir=result_dir / 'tmpdir')

    logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
    sec_per_example = (time.time() - start_time) / max(len(dataset), 1)
    logger.info('Generate label finished(sec_per_example: %.4f second).' % sec_per_example)

    if cfg.LOCAL_RANK != 0:
        return {}

    ret_dict = {}
    if dist_test:
        merged_metric = metric[0]
        for key in merged_metric:
            for k in range(1, world_size):
                merged_metric[key] += metric[k][key]
        metric = merged_metric

    gt_num_cnt = metric['gt_num']
    for cur_thresh in recall_thresh_list:
        # max(..., 1) avoids division by zero when no ground-truth boxes were seen.
        cur_roi_recall = metric['recall_roi_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        cur_rcnn_recall = metric['recall_rcnn_%s' % str(cur_thresh)] / max(gt_num_cnt, 1)
        logger.info('recall_roi_%s: %f' % (cur_thresh, cur_roi_recall))
        logger.info('recall_rcnn_%s: %f' % (cur_thresh, cur_rcnn_recall))
        ret_dict['recall/roi_%s' % str(cur_thresh)] = cur_roi_recall
        ret_dict['recall/rcnn_%s' % str(cur_thresh)] = cur_rcnn_recall

    total_pred_objects = sum(len(anno['name']) for anno in det_annos)
    logger.info('Average predicted number of objects(%d samples): %.3f'
                % (len(det_annos), total_pred_objects / max(1, len(det_annos))))

    with open(result_dir / 'result.pkl', 'wb') as f:
        pickle.dump(det_annos, f)

    # Former debug prints, routed through the logger at debug level.
    logger.debug('length of det_annos: %d', len(det_annos))
    logger.debug('dataset: %s', dataset)
    result_str, result_dict = dataset.evaluation(
        det_annos, class_names,
        eval_metric=cfg.MODEL.POST_PROCESSING.EVAL_METRIC,
        output_path=final_output_dir
    )
    logger.debug('result_dict keys: %s', list(result_dict.keys()))

    logger.info(result_str)
    ret_dict.update(result_dict)

    logger.info('Result is saved to %s' % result_dir)
    logger.info('****************Evaluation done.*****************')
    return ret_dict
if __name__ == '__main__':
    # Executing this module directly is intentionally a no-op: it only defines
    # functions that are meant to be imported.
    ...
|