Spaces:
Runtime error
Runtime error
File size: 3,412 Bytes
7734d5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from loguru import logger
import torch
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel as DDP
from yolox.core import launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
from yolox.evaluators import MOTEvaluator
import argparse
import os
import random
import warnings
import glob
import motmetrics as mm
from collections import OrderedDict
from pathlib import Path
def compare_dataframes(gts, ts):
    """Pair each tracker result with its ground truth and accumulate matches.

    Args:
        gts: Mapping from sequence name to ground-truth data.
        ts: Mapping from sequence name to tracker-output data.

    Returns:
        A tuple ``(accs, names)``: the accumulators returned by
        ``mm.utils.compare_to_groundtruth`` (called with ``'iou'`` and
        ``distth=0.5``) and the matching sequence names, both in the
        iteration order of ``ts``. Sequences without ground truth are
        skipped with a warning.
    """
    accumulators = []
    matched_names = []
    for seq_name, tracked in ts.items():
        # Guard clause: nothing to compare against for this sequence.
        if seq_name not in gts:
            logger.warning('No ground truth for {}, skipping.'.format(seq_name))
            continue

        logger.info('Comparing {}...'.format(seq_name))
        accumulator = mm.utils.compare_to_groundtruth(
            gts[seq_name], tracked, 'iou', distth=0.5)
        accumulators.append(accumulator)
        matched_names.append(seq_name)
    return accumulators, matched_names
# Evaluate MOTA and related tracking metrics for the saved tracking results.
results_folder = 'YOLOX_outputs/yolox_x_ablation/track_results'
mm.lap.default_solver = 'lap'

# Suffix of the ground-truth file name: 'gt{suffix}.txt'.
# Setting this to '' makes the glob below look for plain 'gt.txt' instead.
gt_type = '_val_half'
print('gt_type', gt_type)

gt_pattern = os.path.join('datasets/mot/train', '*/gt/gt{}.txt'.format(gt_type))
gtfiles = glob.glob(gt_pattern)
print('gt_files', gtfiles)

# Collect every result file in the folder, except those whose name starts with 'eval'.
tsfiles = []
for candidate in glob.glob(os.path.join(results_folder, '*.txt')):
    if os.path.basename(candidate).startswith('eval'):
        continue
    tsfiles.append(candidate)

logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
logger.info("Default LAP solver '{}'".format(mm.lap.default_solver))
logger.info('Loading files.')

# Ground truth is keyed by the third-from-last path component, i.e. the
# directory matched by the '*' in '<train>/*/gt/gt….txt'.
gt = OrderedDict()
for gt_path in gtfiles:
    gt[Path(gt_path).parts[-3]] = mm.io.loadtxt(gt_path, fmt='mot15-2D', min_confidence=1)

# Tracker results are keyed by file name without its extension.
ts = OrderedDict()
for ts_path in tsfiles:
    ts_key = os.path.splitext(Path(ts_path).name)[0]
    ts[ts_key] = mm.io.loadtxt(ts_path, fmt='mot15-2D', min_confidence=0.6)

mh = mm.metrics.create()
accs, names = compare_dataframes(gt, ts)

logger.info('Running metrics')
metrics = [
    'recall',
    'precision',
    'num_unique_objects',
    'mostly_tracked',
    'partially_tracked',
    'mostly_lost',
    'num_false_positives',
    'num_misses',
    'num_switches',
    'num_fragmentations',
    'mota',
    'motp',
    'num_objects',
]
summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)

# First table: turn absolute counts into ratios by dividing each listed
# column by its divisor column (key -> columns divided by that key).
div_dict = {
    'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
    'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost'],
}
for divisor, divided_columns in div_dict.items():
    for column in divided_columns:
        summary[column] = summary[column] / summary[divisor]

# The ratio columns are rendered with the same formatter that 'mota' uses.
fmt = mh.formatters
change_fmt_list = [
    'num_false_positives',
    'num_misses',
    'num_switches',
    'num_fragmentations',
    'mostly_tracked',
    'partially_tracked',
    'mostly_lost',
]
fmt.update(dict.fromkeys(change_fmt_list, fmt['mota']))
print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))

# Second table: the MOTChallenge metric set plus 'num_objects', rendered
# with the metric host's formatters.
metrics = mm.metrics.motchallenge_metrics + ['num_objects']
summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))
logger.info('Completed')