File size: 9,087 Bytes
aee59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d9d4d0
8bfa969
329abf4
aee59a3
 
 
6d18e2a
aee59a3
6d18e2a
aee59a3
 
 
 
 
ec13ff1
aee59a3
 
 
 
 
6d18e2a
 
 
aee59a3
6d18e2a
 
 
 
 
 
 
 
 
 
 
 
 
aee59a3
6d18e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
5d7bacb
6d18e2a
 
 
 
 
 
 
 
 
 
 
aee59a3
 
db9ef5e
 
a209af8
6d18e2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aee59a3
db9ef5e
c575cfd
 
db9ef5e
5d7bacb
db9ef5e
c575cfd
6d18e2a
5d7bacb
c575cfd
 
 
db9ef5e
 
5d7bacb
 
 
 
db9ef5e
aee59a3
6d18e2a
 
 
 
 
 
aee59a3
 
 
 
 
 
 
 
6d18e2a
490df31
 
 
6d18e2a
490df31
 
 
aee59a3
6d18e2a
db9ef5e
bc89e29
db9ef5e
6d18e2a
db9ef5e
6d18e2a
 
 
 
db9ef5e
6d18e2a
 
 
a3751fe
db9ef5e
 
 
dbcf938
db9ef5e
6d18e2a
db9ef5e
6d18e2a
 
db9ef5e
1033953
5d7bacb
 
 
 
 
 
 
 
 
 
 
 
 
db9ef5e
5d7bacb
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import evaluate
import datasets
import numpy as np

from seametrics.horizon.utils import *

_CITATION = """\
@InProceedings{huggingface:module,
title = {Horizon Metrics},
authors={huggingface, Inc.},
year={2024}
}
"""

# TODO: Add description of the module here
_DESCRIPTION = """\
This metric is intended to calculate horizon prediction metrics."""

# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
    predictions: list of predictions for each image. Each prediction
        should be a nested array like this:
        - [[x1, y1], [x2, y2]]

    references: list of references for each image. Each reference
        should be a nested array like this:
        - [[x1, y1], [x2, y2]]
Returns:
    dict containing following metrics:
    'average_slope_error': Measures the average difference in slope between the predicted and ground truth horizon.
    'average_midpoint_error': Calculates the average difference in midpoint position between the predicted and ground truth horizon.
    'stddev_slope_error': Indicates the variability of errors in slope between the predicted and ground truth horizon.
    'stddev_midpoint_error': Quantifies the variability of errors in midpoint position between the predicted and ground truth horizon.
    'max_slope_error': Represents the maximum difference in slope between the predicted and ground truth horizon.
    'max_midpoint_error': Indicates the maximum difference in midpoint position between the predicted and ground truth horizon.
    'num_slope_error_jumps': Calculates the differences between errors in successive frames for the slope. It then counts the number of jumps in these errors by comparing the absolute differences to a specified threshold.
    'num_midpoint_error_jumps': Calculates the differences between errors in successive frames for the midpoint. It then counts the number of jumps in these errors by comparing the absolute differences to a specified threshold.

Examples:
    >>> ground_truth_points = [[[0.0, 0.5384765625], [1.0, 0.4931640625]],
                       [[0.0, 0.53796875], [1.0, 0.4928515625]],
                       [[0.0, 0.5374609375], [1.0, 0.4925390625]],
                       [[0.0, 0.536953125], [1.0, 0.4922265625]],
                       [[0.0, 0.5364453125], [1.0, 0.4919140625]]]

    >>> prediction_points = [[[0.0, 0.5428930956049597], [1.0, 0.4642497615378973]],
                     [[0.0, 0.5428930956049597], [1.0, 0.4642497615378973]],
                     [[0.0, 0.523573113510805], [1.0, 0.47642688648919496]],
                     [[0.0, 0.5200016849393765], [1.0, 0.4728554579177664]],
                     [[0.0, 0.523573113510805], [1.0, 0.47642688648919496]]]


    >>> module = evaluate.load("SEA-AI/horizon-metrics", roll_threshold=0.5, pitch_threshold=0.1, vertical_fov_degrees=25.6, height=512)
    >>> module.add(predictions=ground_truth_points, references=prediction_points)
    >>> module.compute()
    >>> {'average_slope_error': 0.014823194839790999,
         'average_midpoint_error': 0.014285714285714301,
         'stddev_slope_error': 0.01519178791378349,
         'stddev_midpoint_error': 0.0022661781575342445,
         'max_slope_error': 0.033526146567062376,
         'max_midpoint_error': 0.018161272321428612,
         'num_slope_error_jumps': 1,
         'num_midpoint_error_jumps': 1}
    """


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION,
                                                _KWARGS_DESCRIPTION)
class HorizonMetrics(evaluate.Metric):
    """
    HorizonMetrics is a metric class that calculates horizon prediction metrics.

    Args:
        roll_threshold (float, optional): Threshold for roll angle. Defaults to 0.5.
        pitch_threshold (float, optional): Threshold for pitch angle. Defaults to 0.1.
        vertical_fov_degrees (float, optional): Vertical field of view in degrees. Defaults to 25.6.
        height (int, optional): Image height (presumably in pixels), passed on to
            calculate_horizon_error_across_sequence. Defaults to 512.
        **kwargs: Additional keyword arguments forwarded to evaluate.Metric.

    Attributes:

        slope_threshold (float): Threshold for slope calculated from roll threshold.
        midpoint_threshold (float): Threshold for midpoint calculated from pitch threshold.
        predictions (list): Predicted horizons from the most recent add() call.
        ground_truth_det (list): Ground truth horizons from the most recent add() call.
        slope_error_list (list): Per-frame slope errors from the most recent compute().
        midpoint_error_list (list): Per-frame midpoint errors from the most recent compute().
        height (int): Value of the ``height`` constructor argument.
        vertical_fov_degrees (float): Value of the ``vertical_fov_degrees`` constructor argument.

    Methods:

        _info(): Returns the metric information.
        add(predictions, references, **kwargs): Updates the predictions and ground truth detections.
        _compute(predictions, references, **kwargs): Computes the horizon error across the sequence.
    """

    def __init__(self,
                 roll_threshold=0.5,
                 pitch_threshold=0.1,
                 vertical_fov_degrees=25.6,
                 height=512,
                 **kwargs):

        super().__init__(**kwargs)

        self.slope_threshold = roll_to_slope(roll_threshold)
        self.midpoint_threshold = pitch_to_midpoint(pitch_threshold,
                                                    vertical_fov_degrees)
        self.predictions = None
        self.ground_truth_det = None
        self.slope_error_list = []
        self.midpoint_error_list = []
        self.height = height
        self.vertical_fov_degrees = vertical_fov_degrees

    def _info(self):
        """
        Returns the metric information.

        Returns:
            MetricInfo: The metric information.
        """
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Each example is a whole sequence: a list of horizons, each a
            # list of points, each a list of floats.
            features=datasets.Features({
                'predictions':
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float")))),
                'references':
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float")))),
            }),
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"])

    def add(self, *, predictions, references, **kwargs):
        """
        Records one sequence of predicted and ground truth horizons.

        The whole sequence is written as a single example through the base
        class ``add`` and is also kept on the instance for ``_compute``.
        A later call replaces the sequence stored on the instance.

        Parameters:
            predictions (list): List of predicted horizons (entries may be None).
            references (list): List of ground truth horizons (entries may be None).
            **kwargs: Additional keyword arguments forwarded to the base class.
        """
        # Use the base class's documented keyword names instead of relying on
        # ``references`` leaking through **kwargs.
        super().add(prediction=predictions, reference=references, **kwargs)

        self.predictions = predictions
        self.ground_truth_det = references

    def _compute(self, *, predictions, references, **kwargs):
        """
        Computes the horizon error across the sequence stored by add().

        The ``predictions``/``references`` arguments supplied by the evaluate
        framework are not used. The frames are read from ``self.predictions``
        and ``self.ground_truth_det``.

        Returns:
            dict: The metrics returned by calculate_horizon_error_across_sequence,
            plus 'detection_rate'. 'detection_rate' is the number of non-None
            predictions divided by the number of non-None references, or NaN
            when there are no references.

        Raises:
            ValueError: If no sequence was recorded with add() beforehand.
        """
        if self.predictions is None or self.ground_truth_det is None:
            raise ValueError(
                "No horizons recorded; call add(predictions=..., "
                "references=...) before compute().")

        # Collect per-frame errors in fresh lists. The original code appended
        # to the instance lists, so repeated compute() calls double counted.
        slope_errors = []
        midpoint_errors = []
        # NOTE(review): zip() silently truncates if the two sequences differ
        # in length.
        for annotated_horizon, proposed_horizon in zip(self.ground_truth_det,
                                                       self.predictions):
            # Frames with a missing horizon contribute no error.
            if annotated_horizon is None or proposed_horizon is None:
                continue
            slope_error, midpoint_error = calculate_horizon_error(
                annotated_horizon, proposed_horizon)
            slope_errors.append(slope_error)
            midpoint_errors.append(midpoint_error)

        # Keep the public attributes populated for callers that inspect them.
        self.slope_error_list = slope_errors
        self.midpoint_error_list = midpoint_errors

        # Calculate slope errors, midpoint errors and jumps.
        result = calculate_horizon_error_across_sequence(
            slope_errors, midpoint_errors,
            self.slope_threshold, self.midpoint_threshold,
            self.vertical_fov_degrees, self.height)

        # Calculate the detection rate.
        detected_horizon_count = self._count_present(self.predictions)
        detected_gt_count = self._count_present(self.ground_truth_det)
        # Guard against division by zero when no frame has a reference.
        result['detection_rate'] = (detected_horizon_count / detected_gt_count
                                    if detected_gt_count else float('nan'))

        return result

    @staticmethod
    def _count_present(horizons):
        """Return how many entries of ``horizons`` are not None."""
        return sum(horizon is not None for horizon in horizons)