Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for variance metric.""" | |
from typing import Optional | |
from absl.testing import parameterized | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from official.recommendation.uplift import keras_test_case | |
from official.recommendation.uplift.metrics import variance | |
class VarianceTest(keras_test_case.KerasTestCase, parameterized.TestCase):
  """Tests for `variance.Variance`.

  Correctness tests compare the metric against `_compute_variance`, a NumPy
  reference implementation of the (optionally weighted) population variance.
  """

  def _compute_variance(
      self, values: tf.Tensor, weights: Optional[tf.Tensor] = None
  ) -> float:
    """Computes the reference variance with NumPy.

    Args:
      values: Eager tensor of values. All elements are pooled together, since
        no axis is passed to NumPy.
      weights: Optional eager tensor of sample weights. It is broadcast to the
        shape of `values` before averaging.

    Returns:
      The population variance (`ddof=0`) of `values`, weighted by `weights`
      when they are given.
    """
    values = values.numpy()
    if weights is None:
      return values.var()
    # `np.average` without an `axis` needs weights shaped like `values`, so
    # broadcast explicitly (e.g. a scalar weight becomes a full array).
    weights = np.broadcast_to(weights.numpy(), shape=values.shape)
    weighted_mean = np.average(values, weights=weights)
    return np.average((values - weighted_mean) ** 2, weights=weights)

  @parameterized.named_parameters(
      {
          "testcase_name": "unweighted",
          "values": tf.constant([1.0, 2.0, 3.0, 4.0]),
          "sample_weight": None,
      },
      {
          "testcase_name": "weighted",
          "values": tf.constant([1.0, 2.0, 3.0, 4.0]),
          "sample_weight": tf.constant([0.5, 1.0, 2.0, 0.0]),
      },
      {
          "testcase_name": "constant_values",
          "values": tf.constant([3.0, 3.0, 3.0]),
          "sample_weight": None,
      },
  )
  def test_single_batch_correctness(self, values, sample_weight):
    """A single update matches the NumPy reference variance."""
    metric = variance.Variance()
    metric(values=values, sample_weight=sample_weight)
    expected_variance = self._compute_variance(values, sample_weight)
    self.assertAllClose(expected_variance, metric.result())

  @parameterized.named_parameters(
      {
          "testcase_name": "unweighted",
          "values_batches": [
              tf.constant([1.0, 2.0]),
              tf.constant([3.0, 4.0, 10.0]),
          ],
          "sample_weight_batches": [None, None],
          "all_values": tf.constant([1.0, 2.0, 3.0, 4.0, 10.0]),
          "all_weights": None,
      },
      {
          "testcase_name": "weighted",
          "values_batches": [
              tf.constant([1.0, 2.0]),
              tf.constant([3.0, -4.0]),
          ],
          "sample_weight_batches": [
              tf.constant([1.0, 0.5]),
              tf.constant([2.0, 0.25]),
          ],
          "all_values": tf.constant([1.0, 2.0, 3.0, -4.0]),
          "all_weights": tf.constant([1.0, 0.5, 2.0, 0.25]),
      },
  )
  def test_multi_batch_correctness(
      self, values_batches, sample_weight_batches, all_values, all_weights
  ):
    """Accumulating several batches equals the variance of their union.

    Args:
      values_batches: Per-batch value tensors fed to the metric in order.
      sample_weight_batches: Per-batch weights (or None), aligned with
        `values_batches`.
      all_values: Concatenation of `values_batches`.
      all_weights: Concatenation of `sample_weight_batches`, or None.
    """
    # `zip` silently truncates, which would feed fewer batches than
    # `all_values` describes; fail loudly on a malformed test case instead.
    self.assertLen(sample_weight_batches, len(values_batches))
    metric = variance.Variance()
    for values, sample_weight in zip(values_batches, sample_weight_batches):
      metric(values=values, sample_weight=sample_weight)
    expected_variance = self._compute_variance(all_values, all_weights)
    self.assertAllClose(expected_variance, metric.result())
    self.assertAllGreaterEqual(metric.result(), 0.0)

  @parameterized.named_parameters(
      {
          "testcase_name": "weight_2",
          "values": tf.constant([1.0, 2.0, 1.0, 4.0]),
          "sample_weight": 2.0,
          "expected_variance": 1.5,
      },
      {
          "testcase_name": "weight_0_5",
          "values": tf.constant([0.0, 1.0, 2.0, 3.0]),
          "sample_weight": 0.5,
          "expected_variance": 1.25,
      },
  )
  def test_float_sample_weight(self, values, sample_weight, expected_variance):
    """A Python float sample weight is accepted.

    A uniform weight cancels out, so each expected value equals the unweighted
    variance of `values`. These are exactly representable, so exact equality
    is used.
    """
    metric = variance.Variance()
    metric(values, sample_weight=sample_weight)
    self.assertEqual(expected_variance, metric.result())

  def test_empty_input(self):
    """An empty batch leaves the accumulated variance unchanged."""
    metric = variance.Variance()
    values = tf.constant([0, 1, 2, 3])
    metric(values)
    self.assertEqual(1.25, metric.result())
    metric(tf.ones(shape=(0,)), sample_weight=None)
    self.assertEqual(1.25, metric.result())

  def test_initial_state(self):
    """A freshly constructed metric reports 0.0."""
    metric = variance.Variance()
    self.assertEqual(0.0, metric.result())

  def test_dtype_correctness(self):
    """The metric computes in its configured dtype."""
    # 1 << 128 overflows for float32 but fits in float64.
    value = tf.constant([1 << 128], dtype=tf.float64)
    metric = variance.Variance(dtype=tf.float32)
    metric(value)
    self.assertAllEqual(np.nan, metric.result().numpy())

    metric = variance.Variance(dtype=tf.float64)
    metric(value)
    self.assertAllEqual(0.0, metric.result().numpy())

  def test_invalid_dtype(self):
    """A string dtype is rejected with ValueError.

    The error may be raised at construction or on the first update; both are
    inside the context manager.
    """
    with self.assertRaises(ValueError):
      metric = variance.Variance(dtype=tf.string)
      metric(tf.constant(["hello, world!"], tf.string))

  @parameterized.named_parameters(
      {
          "testcase_name": "weights_shorter_than_values",
          "values": tf.constant([1.0, 2.0, 3.0]),
          "weights": tf.constant([1.0, 2.0]),
      },
      {
          "testcase_name": "weights_longer_than_values",
          "values": tf.constant([1.0, 2.0]),
          "weights": tf.constant([1.0, 2.0, 3.0, 4.0]),
      },
  )
  def test_invalid_weight_shape(self, values, weights):
    """Weights that cannot be matched to `values` raise InvalidArgumentError."""
    metric = variance.Variance()
    with self.assertRaises(tf.errors.InvalidArgumentError):
      metric(values, weights)

  def test_name(self):
    """The `name` constructor argument is exposed as `metric.name`."""
    metric = variance.Variance(name="test_name")
    self.assertEqual("test_name", metric.name)

  def test_multiple_result_calls(self):
    """`result()` does not alter state and reflects every prior update."""
    metric = variance.Variance()
    values = tf.constant([1, 2, 1, 4])
    metric.update_state(values)
    expected_variance = values.numpy().var()
    # Calling result() repeatedly must return the same value.
    self.assertAllClose(expected_variance, metric.result())
    self.assertAllClose(expected_variance, metric.result())

    metric.update_state(tf.constant([-1, -2, 0]))
    # The pooled variance is 164/49, which is not exactly representable in
    # float32 or float64. Compare with a tolerance rather than `assertEqual`.
    self.assertAllClose(
        np.array([1, 2, 1, 4, -1, -2, 0]).var(), metric.result()
    )

  def test_reset_state(self):
    """`reset_state()` discards previously accumulated values."""
    metric = variance.Variance()
    values = tf.constant([1, 2, 1, 4])
    metric.update_state(values)
    self.assertEqual(1.5, metric.result())

    metric.reset_state()
    # Only the two 1s carry weight now, so the variance is exactly zero.
    metric.update_state(values, sample_weight=tf.constant([1, 0, 1, 0]))
    self.assertEqual(0.0, metric.result())

  def test_numpy_correctness(self):
    """NumPy array inputs are accepted and match the reference variance."""
    metric = variance.Variance()
    values = np.array([-1.3, 2.4, 1, 4])
    weights = np.array([0.7, 0, 1.3, 1.0])
    metric.update_state(values, weights)
    expected_variance = self._compute_variance(
        tf.convert_to_tensor(values), tf.convert_to_tensor(weights)
    )
    self.assertAllClose(expected_variance, metric.result())

  def test_metric_config(self):
    """Checks the metric with `assertLayerConfigurable`.

    NOTE(review): presumably this is a get_config/from_config round trip;
    confirm against KerasTestCase.
    """
    metric = variance.Variance()
    self.assertLayerConfigurable(layer=metric)
# Run the tests above through TensorFlow's test entry point when this file is
# executed directly.
if __name__ == "__main__":
  tf.test.main()