|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for ClusterPreserveQuantizeRegistry.""" |
|
|
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry |
|
from tensorflow_model_optimization.python.core.keras.compat import keras |
|
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config |
|
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry |
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry |
|
|
|
|
|
# Short module-level aliases used by the test classes below.
layers = keras.layers
QuantizeConfig = quantize_config.QuantizeConfig
|
|
|
|
|
class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
  """Tests for `ClusterPreserveQuantizeRegistry` layer support and config use.

  The fixtures cover three built-in Keras layers (Conv2D, Dense, ReLU) plus a
  custom layer and a custom `QuantizeConfig` that are both defined in this
  class and are therefore expected to be unknown to the registry.
  """

  class CustomLayer(layers.Layer):
    """A simple custom layer with a single trainable weight."""

    def build(self, input_shape=(2, 2)):
      """Adds one trainable, randomly initialised weight of `input_shape`."""
      self.add_weight(
          shape=input_shape,
          initializer='random_normal',
          trainable=True)

  class CustomQuantizeConfig(QuantizeConfig):
    """A dummy concrete `QuantizeConfig` used as an unregistered config."""

    def get_weights_and_quantizers(self, layer):
      return []

    def get_activations_and_quantizers(self, layer):
      return []

    def set_quantize_weights(self, layer, quantize_weights):
      pass

    def set_quantize_activations(self, layer, quantize_activations):
      pass

    def get_output_quantizers(self, layer):
      return []

    def get_config(self):
      return {}

  def setUp(self):
    """Creates the registry under test and the pre-built layer fixtures."""
    super().setUp()

    # NOTE(review): the positional `False` is presumably the registry's
    # sparsity-preservation flag -- confirm against the constructor.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

    # Built-in Keras layers, built up front so their weights exist.
    self.layer_conv2d = layers.Conv2D(10, (2, 2))
    self.layer_conv2d.build((2, 2))

    self.layer_dense = layers.Dense(10)
    self.layer_dense.build((2, 2))

    self.layer_relu = layers.ReLU()
    self.layer_relu.build((2, 2))

    # Custom layer defined above; `build()` falls back to its default shape.
    self.layer_custom = self.CustomLayer()
    self.layer_custom.build()

  def testSupportsKerasLayer(self):
    """Dense, Conv2D and ReLU layers are reported as supported."""
    registry = self.cluster_preserve_quantize_registry
    self.assertTrue(registry.supports(self.layer_dense))
    self.assertTrue(registry.supports(self.layer_conv2d))
    self.assertTrue(registry.supports(self.layer_relu))

  def testDoesNotSupportCustomLayer(self):
    """The custom layer is reported as unsupported."""
    self.assertFalse(
        self.cluster_preserve_quantize_registry.supports(self.layer_custom))

  def testApplyClusterPreserveWithQuantizeConfig(self):
    """Applying a default 8-bit conv config to Conv2D completes without error.

    This is a smoke test: it makes no assertion on the result and passes as
    long as the call does not raise.
    """
    conv_quantize_config = (
        default_8bit_quantize_registry.Default8BitConvQuantizeConfig(
            ['kernel'], ['activation'], False))
    self.cluster_preserve_quantize_registry.apply_cluster_preserve_quantize_config(
        self.layer_conv2d, conv_quantize_config)

  def testRaisesErrorUnsupportedQuantizeConfigWithLayer(self):
    """Unregistered configs and unregistered layers both raise ValueError."""
    registry = self.cluster_preserve_quantize_registry

    # Pass a config *instance*, not the class object: the registry is handed
    # instances in real use, and passing the class would make this test pass
    # without ever exercising an unregistered config.
    with self.assertRaises(
        ValueError, msg='Unregistered QuantizeConfigs should raise error.'):
      registry.apply_cluster_preserve_quantize_config(
          self.layer_conv2d, self.CustomQuantizeConfig())

    with self.assertRaises(
        ValueError, msg='Unregistered layers should raise error.'):
      registry.apply_cluster_preserve_quantize_config(
          self.layer_custom, self.CustomQuantizeConfig())
|
|
|
|
|
class ClusterPreserveDefault8bitQuantizeRegistryTest(tf.test.TestCase):
  """Cross-checks the cluster-preserve registry against its two parents.

  Compares the cluster-preserve registry's layer map with the clustering
  registry's and the default 8-bit quantize registry's layer maps.
  """

  def setUp(self):
    """Instantiates the three registries compared by the test."""
    super().setUp()
    self.default_8bit_quantize_registry = (
        default_8bit_quantize_registry.Default8BitQuantizeRegistry())
    self.cluster_registry = clustering_registry.ClusteringRegistry()

    # NOTE(review): the positional `False` is presumably the registry's
    # sparsity-preservation flag -- confirm against the constructor.
    self.cluster_preserve_quantize_registry = (
        cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
            False))

  def testSupportsClusterDefault8bitQuantizeKerasLayers(self):
    """Layers with weight and config attrs appear in both other registries.

    Entries of `_LAYERS_CONFIG_MAP` whose `weight_attrs` or
    `quantize_config_attrs` is empty are skipped.
    """
    cqat_layers_config_map = (
        self.cluster_preserve_quantize_registry._LAYERS_CONFIG_MAP)

    for cqat_support_layer in cqat_layers_config_map:
      # Look the entry up once instead of re-indexing the map per attribute.
      layer_config = cqat_layers_config_map[cqat_support_layer]
      if not (layer_config.weight_attrs and
              layer_config.quantize_config_attrs):
        continue

      self.assertIn(
          cqat_support_layer,
          self.cluster_registry._LAYERS_WEIGHTS_MAP,
          msg='Clustering doesn\'t support {}'.format(cqat_support_layer))
      self.assertIn(
          cqat_support_layer,
          self.default_8bit_quantize_registry._layer_quantize_map,
          msg='Default 8bit QAT doesn\'t support {}'.format(
              cqat_support_layer))
|
|
|
|
|
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
|