File size: 4,856 Bytes
0106bed
869cc3a
 
 
 
 
 
 
 
 
 
a0c2170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869cc3a
 
 
 
 
 
 
 
 
a0c2170
 
869cc3a
 
 
df40628
869cc3a
df40628
869cc3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3f2132
869cc3a
 
 
c3f2132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869cc3a
 
 
 
 
 
 
 
 
 
 
c3f2132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869cc3a
c3f2132
 
869cc3a
 
 
0106bed
 
c3f2132
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Structural Similarity Index Measure metric."""

import datasets
import numpy as np

from skimage.metrics import structural_similarity
from typing import Dict, Optional
import evaluate


_DESCRIPTION = """
Compute the mean Structural Similarity Index Measure (SSIM) between two images.

Please pay attention to the `data_range` parameter with floating-point images.

Notes:
-----

If `data_range` is not specified, the range is automatically guessed based on the image 
data type. However for floating-point image data, this estimate yields a result double
the value of the desired range, as the `dtype_range` in `skimage.util.dtype.py` has
defined intervals from -1 to +1. This yields an estimate of 2, instead of 1, which is
most oftenrequired when working with image data (as negative light intentsities are
nonsensical). In case of working with YCbCr-like color data, note that these ranges are
different per channel (Cb and Cr have double the range of Y), so one cannot calculate a
channel-averaged SSIM with a single call to this function, as identical ranges are
assumed for each channel. To match the implementation of Wang et al. [1]_, set
`gaussian_weights` to True, `sigma` to 1.5, `use_sample_covariance` to False, and
specify the `data_range` argument.
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `np.array`): Predicted labels.
    references (`list` of `np.array`): Ground truth labels.
    sample_weight (`list` of `float`): Sample weights Defaults to None.
Returns:
    ssim (`float`): Structural Similarity Index Measure. The SSIM values are in range
    (-1, 1], when pixels are non-negative.
Examples:
    Example 1-A simple example
        >>> accuracy_metric = evaluate.load("accuracy")
        >>> results = accuracy_metric.compute(references=[[0, 0], [-1, -1]], predictions=[[0, 1], [0, 0]])
        >>> print(results)
        0.5
"""


_CITATION = """
@article{boulogne2014scikit,
  title={Scikit-image: Image processing in Python},
  author={Boulogne, Fran{\c{c}}ois and Warner, Joshua D and Neil Yager, Emmanuelle},
  journal={J. PeerJ},
  volume={2},
  pages={453},
  year={2014}
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class StructuralSimilarityIndexMeasure(evaluate.Metric):
    # No inline class docstring on purpose: the decorator above receives
    # _DESCRIPTION and _KWARGS_DESCRIPTION to build the class documentation.

    def _info(self):
        """Build the `MetricInfo` (texts, citation, input schema, reference URL)."""
        input_schema = datasets.Features(self._get_feature_types())
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=input_schema,
            reference_urls=["https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html"],
        )

    def _get_feature_types(self):
        """Return the feature spec shared by ``predictions`` and ``references``.

        The nesting depth depends on ``self.config_name``: "multilist" uses
        three nested sequences (num_samples, height, width); every other
        configuration uses two (height, width). Leaf values are ``float32``.
        """
        nesting_depth = 3 if self.config_name == "multilist" else 2

        def nested_float_sequence():
            # Built from scratch on every call so the two inputs never share
            # a feature object.
            feature = datasets.Value("float32")
            for _ in range(nesting_depth):
                feature = datasets.Sequence(feature)
            return feature

        return {
            "predictions": nested_float_sequence(),
            "references": nested_float_sequence(),
        }

    def _compute(
        self,
        predictions,
        references,
        win_size: Optional[int] = None,
        gaussian_weights: Optional[bool] = False,
        data_range: Optional[float] = None,
        multichannel: Optional[bool] = False,
        sample_weight=None,
        **kwargs
    ) -> Dict[str, float]:
        """Score ``predictions`` against ``references`` with scikit-image's SSIM.

        Args:
            predictions: Predicted image data, converted with ``np.array``.
            references: Ground truth image data, converted with ``np.array``.
            win_size: Forwarded to ``structural_similarity``.
            gaussian_weights: Forwarded to ``structural_similarity``.
            data_range: Forwarded to ``structural_similarity``.
            multichannel: Forwarded to ``structural_similarity``.
            sample_weight: Weights passed to ``np.average`` in the "multilist"
                configuration; not used in any other configuration.
            **kwargs: Forwarded unchanged to ``structural_similarity``.

        Returns:
            With the "multilist" configuration, the (optionally weighted)
            average of the scores of each prediction/reference pair; otherwise
            the result of a single ``structural_similarity`` call on the whole
            inputs.

        NOTE(review): the return annotation says ``Dict[str, float]`` but the
        bare score is returned — confirm which of the two callers rely on.
        NOTE(review): recent scikit-image releases replaced ``multichannel``
        with ``channel_axis`` — confirm against the pinned version.
        """
        # Every call below shares the same keyword arguments.
        ssim_options = dict(
            win_size=win_size,
            gaussian_weights=gaussian_weights,
            data_range=data_range,
            multichannel=multichannel,
            **kwargs,
        )

        if self.config_name != "multilist":
            # Single comparison: the inputs are treated as one array each.
            return structural_similarity(
                np.array(predictions),
                np.array(references),
                **ssim_options,
            )

        # One score per prediction/reference pair, then a weighted mean.
        per_pair_scores = [
            structural_similarity(np.array(pred), np.array(target), **ssim_options)
            for pred, target in zip(predictions, references)
        ]
        return np.average(per_pair_scores, weights=sample_weight)