# Mixtral_ether / cluster_preserve_quantize_registry_test.py
# Uploaded by jeduardogruiz ("Upload 22 files", commit 516a027, verified).
# File size: 5.82 kB.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ClusterPreserveQuantizeRegistry."""
import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
# Short module-level aliases: the QuantizeConfig base class (subclassed by the
# dummy config below) and the Keras layers namespace used to build test layers.
QuantizeConfig, layers = quantize_config.QuantizeConfig, keras.layers
class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
  """Tests ClusterPreserveQuantizeRegistry layer and config checks."""

  def setUp(self):
    super().setUp()
    # Test CQAT by default.
    # NOTE(review): the positional `False` is assumed to disable sparsity
    # preservation (plain CQAT rather than PCQAT) — confirm against the
    # ClusterPreserveQuantizeRegistry constructor.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

    # Layers which are supported by the registry.
    # Initialize and build a Conv2D layer.
    self.layer_conv2d = layers.Conv2D(10, (2, 2))
    self.layer_conv2d.build((2, 2))
    # Initialize and build a Dense layer.
    self.layer_dense = layers.Dense(10)
    self.layer_dense.build((2, 2))
    # Initialize and build a ReLU layer (a layer without trainable weights).
    self.layer_relu = layers.ReLU()
    self.layer_relu.build((2, 2))

    # A layer which is not supported by the registry.
    # Initialize and build a custom layer.
    self.layer_custom = self.CustomLayer()
    self.layer_custom.build()

  class CustomLayer(layers.Layer):
    """A simple custom layer with trainable weights."""

    def build(self, input_shape=(2, 2)):
      """Creates one trainable weight whose shape is `input_shape`.

      Args:
        input_shape: Shape of the weight to create. Defaults to (2, 2) so
          `setUp` can call `build()` without arguments.
      """
      self.add_weight(shape=input_shape,
                      initializer='random_normal',
                      trainable=True)

  class CustomQuantizeConfig(QuantizeConfig):
    """A dummy concrete class for testing unregistered configs."""

    # Every method is a no-op so the class is instantiable but unregistered.

    def get_weights_and_quantizers(self, layer):
      return []

    def get_activations_and_quantizers(self, layer):
      return []

    def set_quantize_weights(self, layer, quantize_weights):
      pass

    def set_quantize_activations(self, layer, quantize_activations):
      pass

    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}

  def testSupportsKerasLayer(self):
    """Registered built-in layers are reported as supported."""
    # Registered layers with trainable weights.
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_dense))
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_conv2d))
    # A registered layer without trainable weights.
    self.assertTrue(
        self.cluster_preserve_quantize_registry.supports(self.layer_relu))

  def testDoesNotSupportCustomLayer(self):
    """A user-defined layer is reported as unsupported."""
    self.assertFalse(
        self.cluster_preserve_quantize_registry.supports(self.layer_custom))

  def testApplyClusterPreserveWithQuantizeConfig(self):
    """A registered layer/config pair is accepted and yields a config."""
    conv_quantize_config = (
        default_8bit_quantize_registry.Default8BitConvQuantizeConfig(
            ['kernel'], ['activation'], False))
    result = (self.cluster_preserve_quantize_registry
              .apply_cluster_preserve_quantize_config(
                  self.layer_conv2d, conv_quantize_config))
    # Previously the result was discarded, so the test only verified that no
    # exception was raised; also check that a QuantizeConfig comes back.
    self.assertIsInstance(result, QuantizeConfig)

  def testRaisesErrorUnsupportedQuantizeConfigWithLayer(self):
    """Unregistered configs, or unregistered layers, raise ValueError."""
    # Use an *instance* of the dummy config. The previous code passed the
    # class object `self.CustomQuantizeConfig` itself, which is not a
    # QuantizeConfig instance and so does not model an unregistered config.
    unregistered_config = self.CustomQuantizeConfig()
    with self.assertRaises(
        ValueError, msg='Unregistered QuantizeConfigs should raise error.'):
      (self.cluster_preserve_quantize_registry
       .apply_cluster_preserve_quantize_config(
           self.layer_conv2d, unregistered_config))

    with self.assertRaises(ValueError,
                           msg='Unregistered layers should raise error.'):
      (self.cluster_preserve_quantize_registry
       .apply_cluster_preserve_quantize_config(
           self.layer_custom, unregistered_config))
class ClusterPreserveDefault8bitQuantizeRegistryTest(tf.test.TestCase):
  """Checks CQAT layer support agrees with clustering and default 8-bit QAT."""

  def setUp(self):
    super().setUp()
    self.default_8bit_quantize_registry = (
        default_8bit_quantize_registry.Default8BitQuantizeRegistry())
    self.cluster_registry = clustering_registry.ClusteringRegistry()
    # Test CQAT by default.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

  def testSupportsClusterDefault8bitQuantizeKerasLayers(self):
    """Every CQAT layer with weight and quantize attrs is supported twice.

    A layer supported by ClusterPreserveQuantize must be supported by both
    clustering and default 8-bit quantization.
    """
    cqat_layers_config_map = (
        self.cluster_preserve_quantize_registry._LAYERS_CONFIG_MAP)
    for cqat_support_layer, layer_config in cqat_layers_config_map.items():
      # Only entries that declare both weight attrs and quantize-config attrs
      # are checked; the others are skipped.
      if not (layer_config.weight_attrs and layer_config.quantize_config_attrs):
        continue
      self.assertIn(
          cqat_support_layer, self.cluster_registry._LAYERS_WEIGHTS_MAP,
          msg='Clustering doesn\'t support {}'.format(cqat_support_layer))
      self.assertIn(
          cqat_support_layer,
          self.default_8bit_quantize_registry._layer_quantize_map,
          msg='Default 8bit QAT doesn\'t support {}'.format(
              cqat_support_layer))
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()