Spaces:
Runtime error
Runtime error
File size: 6,156 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing framework for Keras extending tf.TestCase.
Keras layers have several nuances. Usually, we expect:
1. A Keras layer to be stable. If no learning happens, they should return
the same result. See assertLayerStable.
2. A Keras layer to be savable. We should be able to save, then load
a Keras layer. It should have the same variables, and return the same result.
See assertLayerSavable.
Various nuances, such as defining a correct get_config function that doesn't
forget any variables, and making sure that any sub-layers are referenced as
fields (or in fields that are dictionaries), are necessary for these tests to
succeed.
"""
import json
from typing import Any, Mapping, Sequence
import tensorflow as tf, tf_keras
# pylint: disable=invalid-name
class KerasTestCase(tf.test.TestCase):
"""Adds methods for testing Keras layers."""
def assertNestedEqual(self, data1, data2):
"""Asserts that both inputs have identical structure and contents."""
tf.nest.assert_same_structure(data1, data2)
for a, b in zip(tf.nest.flatten(data1), tf.nest.flatten(data2)):
if isinstance(a, tf.SparseTensor):
self.assertAllEqual(a.indices, b.indices)
self.assertAllEqual(a.values, b.values)
self.assertAllEqual(a.dense_shape, b.dense_shape)
else:
self.assertAllEqual(a, b)
def assertLayerStable(self, inputs, layer, **kwargs):
"""Layer called twice with the same inputs returns the same result.
Layer must return a result appropriate for assertAllEqual.
Args:
inputs: inputs to the layer.
layer: layer to test.
**kwargs: auxiliary inputs passed to the layer.
"""
output1 = layer(inputs, **kwargs)
output2 = layer(inputs, **kwargs)
self.assertNestedEqual(output1, output2)
def toKerasInputs(self, inputs): # pylint:disable=[invalid-name]
"""Generate tf_keras.Input for inputs.
Args:
inputs: a StructuredTensor, Tensor, RecordTensor, RaggedTensor, or nested
structures of these.
Returns:
tf_keras.Input representing the input.
"""
def map_to_keras_input(x):
if isinstance(x, tf.Tensor):
return tf_keras.Input(x.shape[1:], dtype=x.dtype)
ts = tf.type_spec_from_value(x)
return tf_keras.Input(type_spec=ts)
return tf.nest.map_structure(map_to_keras_input, inputs)
def assertLayerSavable(
self,
inputs,
layer,
keras_inputs=None,
custom_objects=None,
save_format="tf",
**kwargs
):
"""Layer can be saved and loaded in a model.
Args:
inputs: an input to the layer.
layer: a layer to save.
keras_inputs: if inputs._type_spec won't create a tf_keras.Input.
custom_objects: passed to load_model.
save_format: save_format ("tf" or "h5")
**kwargs: auxiliary inputs passed to the layer.
"""
def _make_model(inputs, layer, keras_inputs):
if keras_inputs is None:
# TODO(martinz): This is not a generic solution.
keras_inputs = self.toKerasInputs(inputs)
keras_outputs = layer(keras_inputs, **kwargs)
return tf_keras.Model(keras_inputs, keras_outputs)
model = _make_model(inputs, layer, keras_inputs)
self.assertModelSavable(
inputs, model, custom_objects=custom_objects, save_format=save_format
)
def assertModelSavable(
self, inputs, model, custom_objects=None, save_format="tf"
):
"""Model can be saved and loaded.
Args:
inputs: an input to the layer.
model: a model to save.
custom_objects: passed to load_model.
save_format: save_format ("tf" or "h5")
"""
if custom_objects is None:
custom_objects = {}
src_output = model(inputs)
model_path = self.get_temp_dir() + "/tmp_model"
tf_keras.models.save_model(model, model_path, save_format=save_format)
reloaded_model = tf_keras.models.load_model(
model_path, custom_objects=custom_objects
)
self.assertEqual(
len(model.trainable_variables), len(reloaded_model.trainable_variables)
)
for src_v, loaded_v in zip(
model.trainable_variables, reloaded_model.trainable_variables
):
self.assertAllEqual(src_v, loaded_v)
loaded_output = reloaded_model(inputs)
self.assertNestedEqual(src_output, loaded_output)
def assertLayerConfigurable(
self, layer: tf_keras.layers.Layer, serializable: bool = True, **kwargs
):
"""Layer can be reconstructed using get_config and from_config.
Args:
layer: layer with get_config and from_config methods.
serializable: boolean to indicate if the config should be tested for json
serializability.
**kwargs: keyword arguments to pass to layer's `__call__` method. These
are used for testing the correctness of the call output. If no keywords
are passed then this part of the test will not be executed.
"""
config = layer.get_config()
from_config_layer = layer.__class__.from_config(
config if not serializable else json.loads(json.dumps(config))
)
self.assertDictEqual(config, from_config_layer.get_config())
if kwargs:
self.assertNestedEqual(layer(**kwargs), from_config_layer(**kwargs))
def layer_dict_from_classes(classes: Sequence[Any]) -> Mapping[str, Any]:
"""Construct a dictionary for custom_objects while saving keras layers.
Args:
classes: a sequence of layer classes.
Returns:
A dictionary of layer classes, keyed by name.
"""
return {x.__name__: x for x in classes}
|