"""Structural Similarity Index Measure metric.""" import datasets import numpy as np from skimage.metrics import structural_similarity from typing import Dict, Optional import evaluate _DESCRIPTION = """ Compute the mean Structural Similarity Index Measure (SSIM) between two images. Please pay attention to the `data_range` parameter with floating-point images. Notes: ----- If `data_range` is not specified, the range is automatically guessed based on the image data type. However for floating-point image data, this estimate yields a result double the value of the desired range, as the `dtype_range` in `skimage.util.dtype.py` has defined intervals from -1 to +1. This yields an estimate of 2, instead of 1, which is most oftenrequired when working with image data (as negative light intentsities are nonsensical). In case of working with YCbCr-like color data, note that these ranges are different per channel (Cb and Cr have double the range of Y), so one cannot calculate a channel-averaged SSIM with a single call to this function, as identical ranges are assumed for each channel. To match the implementation of Wang et al. [1]_, set `gaussian_weights` to True, `sigma` to 1.5, `use_sample_covariance` to False, and specify the `data_range` argument. """ _KWARGS_DESCRIPTION = """ Args: predictions (`list` of `np.array`): Predicted labels. references (`list` of `np.array`): Ground truth labels. sample_weight (`list` of `float`): Sample weights Defaults to None. Returns: ssim (`float`): Structural Similarity Index Measure. The SSIM values are in range (-1, 1], when pixels are non-negative. Examples: Example 1-A simple example >>> accuracy_metric = evaluate.load("accuracy") >>> results = accuracy_metric.compute(references=[[0, 0], [-1, -1]], predictions=[[0, 1], [0, 0]]) >>> print(results) 0.5 """ _CITATION = """ @article{boulogne2014scikit, title={Scikit-image: Image processing in Python}, author={Boulogne, Fran{\c{c}}ois and Warner, Joshua D and Neil Yager, Emmanuelle}, journal={J. PeerJ}, volume={2}, pages={453}, year={2014} } """ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) class StructuralSimilarityIndexMeasure(evaluate.Metric): def _info(self): return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features(self._get_feature_types()), reference_urls=["https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html"], ) def _get_feature_types(self): if self.config_name == "multilist": return { # 1st Seq - num_samples, 2nd Seq - Height, 3rd Seq - Width "predictions": datasets.Sequence( datasets.Sequence(datasets.Sequence(datasets.Value("float32"))) ), "references": datasets.Sequence( datasets.Sequence(datasets.Sequence(datasets.Value("float32"))) ), } else: return { # 1st Seq - Height, 2rd Seq - Width "predictions": datasets.Sequence( datasets.Sequence(datasets.Value("float32")) ), "references": datasets.Sequence( datasets.Sequence(datasets.Value("float32")) ), } def _compute( self, predictions, references, win_size: Optional[int] = None, gaussian_weights: Optional[bool] = False, data_range: Optional[float] = None, multichannel: Optional[bool] = False, sample_weight=None, **kwargs ) -> Dict[str, float]: if self.config_name == "multilist": def func_ssim(args): pred, target = args pred = np.array(pred) target = np.array(target) return structural_similarity( pred, target, win_size=win_size, gaussian_weights=gaussian_weights, data_range=data_range, multichannel=multichannel, **kwargs ) return np.average( list(map(func_ssim, zip(predictions, references))), weights=sample_weight ) else: return structural_similarity( np.array(predictions), np.array(references), win_size=win_size, gaussian_weights=gaussian_weights, data_range=data_range, multichannel=multichannel, **kwargs )