File size: 6,214 Bytes
d4b77ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# --------------------------------------------------------
# Python Single Object Tracking Evaluation
# Licensed under The MIT License [see LICENSE for details]
# Written by Fangyi Zhang
# @author fangyi.zhang@vipl.ict.ac.cn
# @project https://github.com/StrangerZhang/pysot-toolkit.git
# Revised for SiamMask by foolwood
# --------------------------------------------------------

import warnings
import itertools
import numpy as np

from colorama import Style, Fore
from ..utils import calculate_failures, calculate_accuracy


class AccuracyRobustnessBenchmark:
    """Accuracy / robustness evaluation of trackers over a dataset.

    Args:
        dataset: dataset object. This class takes ``len()`` of it, indexes it
            by integer position to obtain videos, and reads its
            ``tracker_names`` and ``tracker_path`` attributes.
        burnin: value forwarded as the ``burnin`` argument of
            ``calculate_accuracy`` (presumably the number of frames ignored
            after a re-initialisation -- confirm against ``..utils``).
    """
    def __init__(self, dataset, burnin=10):
        self.dataset = dataset
        self.burnin = burnin

    def eval(self, eval_trackers=None):
        """Compute per-video overlaps and failures for the given trackers.

        Args:
            eval_trackers: tracker name or list of tracker names; when None,
                ``self.dataset.tracker_names`` is used.
        Returns:
            ret: dict mapping each tracker name to
                ``{'overlaps': {video_name: ...}, 'failures': {video_name: ...}}``
        """
        if eval_trackers is None:
            eval_trackers = self.dataset.tracker_names
        if isinstance(eval_trackers, str):
            # A bare string would otherwise be iterated character by character.
            eval_trackers = [eval_trackers]

        result = {}
        for tracker_name in eval_trackers:
            accuracy, failures = self._calculate_accuracy_robustness(tracker_name)
            result[tracker_name] = {'overlaps': accuracy,
                                    'failures': failures}
        return result

    def show_result(self, result, eao_result=None, show_video_level=False, helight_threshold=0.5):
        """Pretty print a summary table, optionally followed by a per-video table.

        Args:
            result: returned dict from function eval
            eao_result: optional dict mapping tracker name to a dict holding
                an ``'all'`` entry; when given, an EAO column is added and
                only the 20 trackers with the highest ``'all'`` value are
                listed, in descending order.
            show_video_level: also print per-video accuracy / lost number;
                only honoured when ``result`` holds fewer than 10 trackers.
            helight_threshold: per-video accuracies below this value are
                printed in red (parameter name kept as-is for compatibility).
        """
        # ``default=0`` keeps an empty ``result`` from raising in ``max``.
        longest_name = max((len(name) for name in result), default=0)
        tracker_name_len = max(longest_name + 2, 12)
        if eao_result is not None:
            header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|{:^7}|"
            header = header.format('Tracker Name',
                    'Accuracy', 'Robustness', 'Lost Number', 'EAO')
            formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|{:^7.3f}|"
        else:
            header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|"
            header = header.format('Tracker Name',
                    'Accuracy', 'Robustness', 'Lost Number')
            formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|"
        bar = '-'*len(header)
        print(bar)
        print(header)
        print(bar)
        if eao_result is not None:
            tracker_eao = sorted(eao_result.items(),
                                 key=lambda x: x[1]['all'],
                                 reverse=True)[:20]
            tracker_names = [x[0] for x in tracker_eao]
        else:
            tracker_names = list(result.keys())
        for tracker_name in tracker_names:
            ret = result[tracker_name]
            overlaps = list(itertools.chain.from_iterable(ret['overlaps'].values()))
            # NaN overlaps are excluded from the accuracy average.
            accuracy = np.nanmean(overlaps)
            length = sum(len(x) for x in ret['overlaps'].values())
            # Rows are videos; summing over axis 0 totals the failures of
            # each column across all videos (columns presumably correspond
            # to repeated runs -- confirm against the dataset loader).
            failures_per_run = np.sum(np.array(list(ret['failures'].values())), axis=0)
            lost_number = np.mean(failures_per_run)
            robustness = np.mean(failures_per_run / length) * 100
            if eao_result is None:
                print(formatter.format(tracker_name, accuracy, robustness, lost_number))
            else:
                print(formatter.format(tracker_name, accuracy, robustness, lost_number, eao_result[tracker_name]['all']))
        print(bar)

        if show_video_level and result and len(result) < 10:
            print('\n\n')
            header1 = "|{:^14}|".format("Tracker name")
            header2 = "|{:^14}|".format("Video name")
            for tracker_name in result.keys():
                header1 += ("{:^17}|").format(tracker_name)
                header2 += "{:^8}|{:^8}|".format("Acc", "LN")
            print('-'*len(header1))
            print(header1)
            print('-'*len(header1))
            print(header2)
            print('-'*len(header1))
            # The row order follows the videos of the last tracker in
            # ``result``; every tracker is expected to hold the same videos.
            reference = list(result.values())[-1]
            videos = list(reference['overlaps'].keys())
            for video in videos:
                row = "|{:^14}|".format(video)
                for tracker_name in result.keys():
                    overlaps = result[tracker_name]['overlaps'][video]
                    accuracy = np.nanmean(overlaps)
                    failures = result[tracker_name]['failures'][video]
                    lost_number = np.mean(failures)

                    accuracy_str = "{:^8.3f}".format(accuracy)
                    if accuracy < helight_threshold:
                        row += f'{Fore.RED}{accuracy_str}{Style.RESET_ALL}|'
                    else:
                        row += accuracy_str+'|'
                    lost_num_str = "{:^8.3f}".format(lost_number)
                    if lost_number > 0:
                        row += f'{Fore.RED}{lost_num_str}{Style.RESET_ALL}|'
                    else:
                        row += lost_num_str+'|'
                print(row)
            print('-'*len(header1))

    def _calculate_accuracy_robustness(self, tracker_name):
        """Collect overlaps and failure counts of one tracker on every video.

        Args:
            tracker_name: name of the tracker to evaluate
        Returns:
            overlaps: dict mapping video name to the NaN-ignoring mean, taken
                across the tracker's trajectories, of the overlaps returned
                by ``calculate_accuracy`` (converted with ``tolist()``)
            failures: dict mapping video name to a list with one failure
                count per trajectory
        """
        overlaps = {}
        failures = {}
        for i in range(len(self.dataset)):
            video = self.dataset[i]
            gt_traj = video.gt_traj
            if tracker_name not in video.pred_trajs:
                # Results not attached to the video yet: load them from disk.
                tracker_trajs = video.load_tracker(self.dataset.tracker_path, tracker_name, False)
            else:
                tracker_trajs = video.pred_trajs[tracker_name]
            overlaps_group = []
            num_failures_group = []
            for tracker_traj in tracker_trajs:
                num_failures = calculate_failures(tracker_traj)[0]
                # Honour the burn-in configured on the benchmark instead of
                # a hard-coded value.
                overlaps_ = calculate_accuracy(tracker_traj, gt_traj,
                        burnin=self.burnin, bound=(video.width, video.height))[1]
                overlaps_group.append(overlaps_)
                num_failures_group.append(num_failures)
            with warnings.catch_warnings():
                # nanmean emits a RuntimeWarning for all-NaN (or empty)
                # slices; the resulting NaN is kept and the warning silenced.
                warnings.simplefilter("ignore", category=RuntimeWarning)
                overlaps[video.name] = np.nanmean(overlaps_group, axis=0).tolist()
            failures[video.name] = num_failures_group
        return overlaps, failures