jeduardogruiz
commited on
Commit
•
516a027
1
Parent(s):
92ee7ff
Upload 22 files
Browse files- README.md +2 -3
- __init__.py +14 -0
- botWallet.js +49 -0
- clipping.py +157 -0
- clipping_test.py +170 -0
- cluster_preserve_integration_test.py +709 -0
- cluster_preserve_quantize_registry.py +539 -0
- cluster_preserve_quantize_registry_test.py +150 -0
- collaborative_optimization.png +0 -0
- collaborative_optimization_dist.png +0 -0
- cripto.jpg +0 -0
- deep_crypto.py +18 -0
- default_n_bit_transforms.py +825 -0
- main.py +29 -0
- misc.py +173 -0
- misc_test.py +192 -0
- mnist_cnn.py +190 -0
- mnist_e2e_sparsity2x4.py +153 -0
- periodical_update_and_scheduling_test.py +222 -0
- prune_preserve_quantize_registry.py +339 -0
- readme.txt +204 -0
- same_training_and_inference_test.py +210 -0
README.md
CHANGED
@@ -1,3 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
---
|
|
|
1 |
+
This directory is modified based on default_8bit, which allows you to manually
|
2 |
+
change the number of bits of weight and activation in QAT.
|
|
__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
botWallet.js
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Código de tu aplicación aquí
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
os.system("python main.py")
|
7 |
+
const TelegramBot = require('node-telegram-bot-api');
|
8 |
+
const Web3 = require('web3');
|
9 |
+
const web3 = new Web3(new Web3.providers.HttpProvider('https://mainnet.infura.io/v3/YOUR_PROJECT_ID'));
|
10 |
+
|
11 |
+
const contractAddress = data;
|
12 |
+
const contractABI = [...]; // ABI del contrato inteligente
|
13 |
+
|
14 |
+
// Reemplaza 'YOUR_BOT_TOKEN' con el token de tu bot de Telegram
|
15 |
+
const bot(0x68749665FF8D2d112Fa859AA293F07A622782F38) = new TelegramBot('6616997752:AAEU4xrcNzdykjr1flv3BpqKNq1NZCHLEcE', {polling: true});
|
16 |
+
|
17 |
+
bot.on('message', async (msg) => {
|
18 |
+
const chatId = msg.chat.id;
|
19 |
+
const text = msg.text;
|
20 |
+
|
21 |
+
if (text === '/start') {
|
22 |
+
await bot.sendMessage(chatId, '¡Bienvenido al bot de pago de Ethereum!');
|
23 |
+
} else if (text === '/pay') {
|
24 |
+
const userWalletAddress = sessionStorage.getItem('userWalletAddress');
|
25 |
+
const amountToPay = sessionStorage.getItem('amountToPay');
|
26 |
+
|
27 |
+
if (!userWalletAddress || !amountToPay) {
|
28 |
+
await bot.sendMessage(chatId, 'Por favor, configure la billetera y el monto a pagar antes de realizar el pago.');
|
29 |
+
return;
|
30 |
+
}
|
31 |
+
|
32 |
+
const contract = new web3.eth.Contract(contractABI, contractAddress);
|
33 |
+
try {
|
34 |
+
await contract.methods.pay(userWalletAddress, amountToPay).send({ from: '0x...', gas: 1000000 }); // Dirección de la billetera que realiza el pago
|
35 |
+
await bot.sendMessage(chatId, '¡Transacción realizada exitosamente!');
|
36 |
+
} catch (error) {
|
37 |
+
await
|
38 |
+
bot.sendMessage(chatId, `Error al realizar la transacción: ${error.message}`);
|
39 |
+
}
|
40 |
+
} else {
|
41 |
+
await bot.sendMessage(chatId, 'Comando no válido. Por favor, utiliza /start o /pay.');
|
42 |
+
}
|
43 |
+
});
|
44 |
+
|
45 |
+
/startBot (i);
|
46 |
+
/openMiner(extract500000 XAUT"TETHER GOLD")
|
47 |
+
//time.colapse(1)
|
48 |
+
//time.sleep(1)
|
49 |
+
//time.wav(all)
|
clipping.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Encoding stages implementing various clipping strategies.
|
15 |
+
|
16 |
+
The base classes, `ClipByNormEncodingStage` and `ClipByValueEncodingStage`, are
|
17 |
+
expected to be subclassed as implementations of
|
18 |
+
`AdaptiveEncodingStageInterface`, to realize a variety of clipping strategies
|
19 |
+
that are adaptive to the data being processed in an iterative execution.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from __future__ import absolute_import
|
23 |
+
from __future__ import division
|
24 |
+
from __future__ import print_function
|
25 |
+
|
26 |
+
import collections
|
27 |
+
import tensorflow as tf
|
28 |
+
|
29 |
+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
|
30 |
+
|
31 |
+
|
32 |
+
@encoding_stage.tf_style_encoding_stage
|
33 |
+
class ClipByNormEncodingStage(encoding_stage.EncodingStageInterface):
|
34 |
+
"""Encoding stage applying clipping by norm (L-2 ball projection).
|
35 |
+
|
36 |
+
See `tf.clip_by_norm` for more information.
|
37 |
+
"""
|
38 |
+
|
39 |
+
ENCODED_VALUES_KEY = 'clipped_values'
|
40 |
+
NORM_PARAMS_KEY = 'norm_param'
|
41 |
+
|
42 |
+
def __init__(self, clip_norm):
|
43 |
+
"""Initializer for the `ClipByNormEncodingStage`.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
clip_norm: A scalar, norm of the ball onto which to project.
|
47 |
+
"""
|
48 |
+
self._clip_norm = clip_norm
|
49 |
+
|
50 |
+
@property
|
51 |
+
def name(self):
|
52 |
+
"""See base class."""
|
53 |
+
return 'clip_by_norm'
|
54 |
+
|
55 |
+
@property
|
56 |
+
def compressible_tensors_keys(self):
|
57 |
+
"""See base class."""
|
58 |
+
return [self.ENCODED_VALUES_KEY]
|
59 |
+
|
60 |
+
@property
|
61 |
+
def commutes_with_sum(self):
|
62 |
+
"""See base class."""
|
63 |
+
return True
|
64 |
+
|
65 |
+
@property
|
66 |
+
def decode_needs_input_shape(self):
|
67 |
+
"""See base class."""
|
68 |
+
return False
|
69 |
+
|
70 |
+
def get_params(self):
|
71 |
+
"""See base class."""
|
72 |
+
encode_params = collections.OrderedDict([(self.NORM_PARAMS_KEY,
|
73 |
+
self._clip_norm)])
|
74 |
+
decode_params = collections.OrderedDict()
|
75 |
+
return encode_params, decode_params
|
76 |
+
|
77 |
+
def encode(self, x, encode_params):
|
78 |
+
"""See base class."""
|
79 |
+
clipped_x = tf.clip_by_norm(
|
80 |
+
x, tf.cast(encode_params[self.NORM_PARAMS_KEY], x.dtype))
|
81 |
+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, clipped_x)])
|
82 |
+
|
83 |
+
def decode(self,
|
84 |
+
encoded_tensors,
|
85 |
+
decode_params,
|
86 |
+
num_summands=None,
|
87 |
+
shape=None):
|
88 |
+
"""See base class."""
|
89 |
+
del decode_params, num_summands, shape # Unused.
|
90 |
+
return tf.identity(encoded_tensors[self.ENCODED_VALUES_KEY])
|
91 |
+
|
92 |
+
|
93 |
+
@encoding_stage.tf_style_encoding_stage
|
94 |
+
class ClipByValueEncodingStage(encoding_stage.EncodingStageInterface):
|
95 |
+
"""Encoding stage applying clipping by value (L-infinity ball projection).
|
96 |
+
|
97 |
+
See `tf.clip_by_value` for more information.
|
98 |
+
"""
|
99 |
+
|
100 |
+
ENCODED_VALUES_KEY = 'clipped_values'
|
101 |
+
MIN_PARAMS_KEY = 'min_param'
|
102 |
+
MAX_PARAMS_KEY = 'max_param'
|
103 |
+
|
104 |
+
def __init__(self, clip_value_min, clip_value_max):
|
105 |
+
"""Initializer for the `ClipByValueEncodingStage`.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
clip_value_min: A scalar, the minimum value to which to clip.
|
109 |
+
clip_value_max: A scalar, the maximum value to which to clip.
|
110 |
+
"""
|
111 |
+
self._clip_value_min = clip_value_min
|
112 |
+
self._clip_value_max = clip_value_max
|
113 |
+
|
114 |
+
@property
|
115 |
+
def name(self):
|
116 |
+
"""See base class."""
|
117 |
+
return 'clip_by_value'
|
118 |
+
|
119 |
+
@property
|
120 |
+
def compressible_tensors_keys(self):
|
121 |
+
"""See base class."""
|
122 |
+
return [self.ENCODED_VALUES_KEY]
|
123 |
+
|
124 |
+
@property
|
125 |
+
def commutes_with_sum(self):
|
126 |
+
"""See base class."""
|
127 |
+
return True
|
128 |
+
|
129 |
+
@property
|
130 |
+
def decode_needs_input_shape(self):
|
131 |
+
"""See base class."""
|
132 |
+
return False
|
133 |
+
|
134 |
+
def get_params(self):
|
135 |
+
"""See base class."""
|
136 |
+
params = collections.OrderedDict([
|
137 |
+
(self.MIN_PARAMS_KEY, self._clip_value_min),
|
138 |
+
(self.MAX_PARAMS_KEY, self._clip_value_max)
|
139 |
+
])
|
140 |
+
return params, collections.OrderedDict()
|
141 |
+
|
142 |
+
def encode(self, x, encode_params):
|
143 |
+
"""See base class."""
|
144 |
+
clipped_x = tf.clip_by_value(
|
145 |
+
x,
|
146 |
+
tf.cast(encode_params[self.MIN_PARAMS_KEY], x.dtype),
|
147 |
+
tf.cast(encode_params[self.MAX_PARAMS_KEY], x.dtype))
|
148 |
+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, clipped_x)])
|
149 |
+
|
150 |
+
def decode(self,
|
151 |
+
encoded_tensors,
|
152 |
+
decode_params,
|
153 |
+
num_summands=None,
|
154 |
+
shape=None):
|
155 |
+
"""See base class."""
|
156 |
+
del decode_params, num_summands, shape # Unused.
|
157 |
+
return tf.identity(encoded_tensors[self.ENCODED_VALUES_KEY])
|
clipping_test.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import itertools
|
20 |
+
|
21 |
+
from absl.testing import parameterized
|
22 |
+
import numpy as np
|
23 |
+
import tensorflow as tf
|
24 |
+
|
25 |
+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import clipping
|
26 |
+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
|
27 |
+
|
28 |
+
|
29 |
+
if tf.executing_eagerly():
|
30 |
+
tf.compat.v1.disable_eager_execution()
|
31 |
+
|
32 |
+
|
33 |
+
class ClipByNormEncodingStageTest(test_utils.BaseEncodingStageTest):
|
34 |
+
|
35 |
+
def default_encoding_stage(self):
|
36 |
+
"""See base class."""
|
37 |
+
return clipping.ClipByNormEncodingStage(1.0)
|
38 |
+
|
39 |
+
def default_input(self):
|
40 |
+
"""See base class."""
|
41 |
+
return tf.random.normal([20])
|
42 |
+
|
43 |
+
@property
|
44 |
+
def is_lossless(self):
|
45 |
+
"""See base class."""
|
46 |
+
return False
|
47 |
+
|
48 |
+
def common_asserts_for_test_data(self, data):
|
49 |
+
"""See base class."""
|
50 |
+
encoded_x = data.encoded_x[
|
51 |
+
clipping.ClipByNormEncodingStage.ENCODED_VALUES_KEY]
|
52 |
+
# The encoding should not change the shape...
|
53 |
+
self.assertAllEqual(data.x.shape, encoded_x.shape)
|
54 |
+
# The decoding should be identity.
|
55 |
+
self.assertAllEqual(encoded_x, data.decoded_x)
|
56 |
+
|
57 |
+
def test_clipping_effective(self):
|
58 |
+
stage = clipping.ClipByNormEncodingStage(1.0)
|
59 |
+
test_data = self.run_one_to_many_encode_decode(
|
60 |
+
stage, lambda: tf.constant([1.0, 1.0, 1.0, 1.0]))
|
61 |
+
self.common_asserts_for_test_data(test_data)
|
62 |
+
self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
|
63 |
+
# The decoded values should have norm 1.
|
64 |
+
self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)
|
65 |
+
|
66 |
+
def test_clipping_large_norm_identity(self):
|
67 |
+
stage = clipping.ClipByNormEncodingStage(1000.0)
|
68 |
+
test_data = self.run_one_to_many_encode_decode(
|
69 |
+
stage, lambda: tf.constant([1.0, 1.0, 1.0, 1.0]))
|
70 |
+
self.common_asserts_for_test_data(test_data)
|
71 |
+
# The encoding should act as an identity, if input value has smaller norm.
|
72 |
+
self.assertAllEqual(test_data.x, test_data.decoded_x)
|
73 |
+
|
74 |
+
@parameterized.parameters(([2,],), ([2, 3],), ([2, 3, 4],))
|
75 |
+
def test_different_shapes(self, shape):
|
76 |
+
stage = clipping.ClipByNormEncodingStage(1.0)
|
77 |
+
test_data = self.run_one_to_many_encode_decode(
|
78 |
+
stage, lambda: tf.random.uniform(shape) + 1.0)
|
79 |
+
self.common_asserts_for_test_data(test_data)
|
80 |
+
self.assertAllClose(1.0, np.linalg.norm(test_data.decoded_x))
|
81 |
+
|
82 |
+
@parameterized.parameters(
|
83 |
+
itertools.product([tf.float32, tf.float64], [tf.float32, tf.float64]))
|
84 |
+
def test_input_types(self, x_dtype, clip_norm_dtype):
|
85 |
+
# Tests combinations of input dtypes.
|
86 |
+
stage = clipping.ClipByNormEncodingStage(
|
87 |
+
tf.constant(1.0, clip_norm_dtype))
|
88 |
+
x = tf.constant([1.0, 1.0, 1.0, 1.0], dtype=x_dtype)
|
89 |
+
encode_params, decode_params = stage.get_params()
|
90 |
+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
|
91 |
+
decode_params)
|
92 |
+
test_data = test_utils.TestData(x, encoded_x, decoded_x)
|
93 |
+
test_data = self.evaluate_test_data(test_data)
|
94 |
+
|
95 |
+
self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
|
96 |
+
# The decoded values should have norm 1.
|
97 |
+
self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)
|
98 |
+
|
99 |
+
|
100 |
+
class ClipByValueEncodingStageTest(test_utils.BaseEncodingStageTest):
|
101 |
+
|
102 |
+
def default_encoding_stage(self):
|
103 |
+
"""See base class."""
|
104 |
+
return clipping.ClipByValueEncodingStage(-1.0, 1.0)
|
105 |
+
|
106 |
+
def default_input(self):
|
107 |
+
"""See base class."""
|
108 |
+
return tf.random.normal([20])
|
109 |
+
|
110 |
+
@property
|
111 |
+
def is_lossless(self):
|
112 |
+
"""See base class."""
|
113 |
+
return False
|
114 |
+
|
115 |
+
def common_asserts_for_test_data(self, data):
|
116 |
+
"""See base class."""
|
117 |
+
encoded_x = data.encoded_x[
|
118 |
+
clipping.ClipByValueEncodingStage.ENCODED_VALUES_KEY]
|
119 |
+
# The encoding should not change the shape...
|
120 |
+
self.assertAllEqual(data.x.shape, encoded_x.shape)
|
121 |
+
# The decoding should be identity.
|
122 |
+
self.assertAllEqual(encoded_x, data.decoded_x)
|
123 |
+
|
124 |
+
def test_clipping_effective(self):
|
125 |
+
stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
|
126 |
+
test_data = self.run_one_to_many_encode_decode(
|
127 |
+
stage, lambda: tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0]))
|
128 |
+
self.common_asserts_for_test_data(test_data)
|
129 |
+
self.assertAllEqual([-2.0, -1.0, 0.0, 1.0, 2.0], test_data.x)
|
130 |
+
self.assertAllClose([-1.0, -1.0, 0.0, 1.0, 1.0], test_data.decoded_x)
|
131 |
+
|
132 |
+
def test_clipping_large_min_max_identity(self):
|
133 |
+
stage = clipping.ClipByValueEncodingStage(-1000.0, 1000.0)
|
134 |
+
test_data = self.run_one_to_many_encode_decode(stage, self.default_input)
|
135 |
+
self.common_asserts_for_test_data(test_data)
|
136 |
+
# The encoding should act as an identity, if input has smaller values.
|
137 |
+
self.assertAllEqual(test_data.x, test_data.decoded_x)
|
138 |
+
|
139 |
+
@parameterized.parameters(([2,],), ([2, 3],), ([2, 3, 4],))
|
140 |
+
def test_different_shapes(self, shape):
|
141 |
+
stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
|
142 |
+
test_data = self.run_one_to_many_encode_decode(
|
143 |
+
stage, lambda: tf.random.normal(shape))
|
144 |
+
self.common_asserts_for_test_data(test_data)
|
145 |
+
self.assertGreaterEqual(1.0, np.amax(test_data.decoded_x))
|
146 |
+
self.assertLessEqual(-1.0, np.amin(test_data.decoded_x))
|
147 |
+
|
148 |
+
@parameterized.parameters(
|
149 |
+
itertools.product([tf.float32, tf.float64], [tf.float32, tf.float64],
|
150 |
+
[tf.float32, tf.float64]))
|
151 |
+
def test_input_types(self, x_dtype, clip_value_min_dtype,
|
152 |
+
clip_value_max_dtype):
|
153 |
+
# Tests combinations of input dtypes.
|
154 |
+
stage = clipping.ClipByValueEncodingStage(
|
155 |
+
tf.constant(-1.0, clip_value_min_dtype),
|
156 |
+
tf.constant(1.0, clip_value_max_dtype))
|
157 |
+
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=x_dtype)
|
158 |
+
encode_params, decode_params = stage.get_params()
|
159 |
+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
|
160 |
+
decode_params)
|
161 |
+
test_data = test_utils.TestData(x, encoded_x, decoded_x)
|
162 |
+
test_data = self.evaluate_test_data(test_data)
|
163 |
+
|
164 |
+
self.common_asserts_for_test_data(test_data)
|
165 |
+
self.assertAllEqual([-2.0, -1.0, 0.0, 1.0, 2.0], test_data.x)
|
166 |
+
self.assertAllClose([-1.0, -1.0, 0.0, 1.0, 1.0], test_data.decoded_x)
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == '__main__':
|
170 |
+
tf.test.main()
|
cluster_preserve_integration_test.py
ADDED
@@ -0,0 +1,709 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Integration tests for CQAT, PCQAT cases."""
|
16 |
+
from absl.testing import parameterized
|
17 |
+
import numpy as np
|
18 |
+
import tensorflow as tf
|
19 |
+
|
20 |
+
from tensorflow_model_optimization.python.core.clustering.keras import cluster
|
21 |
+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
|
22 |
+
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
|
23 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
24 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantize
|
25 |
+
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import (
|
26 |
+
default_8bit_cluster_preserve_quantize_scheme,)
|
27 |
+
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import (
|
28 |
+
strip_clustering_cqat,)
|
29 |
+
|
30 |
+
|
31 |
+
layers = keras.layers
|
32 |
+
|
33 |
+
|
34 |
+
class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase):
|
35 |
+
|
36 |
+
def setUp(self):
|
37 |
+
super(ClusterPreserveIntegrationTest, self).setUp()
|
38 |
+
self.cluster_params = {
|
39 |
+
'number_of_clusters': 4,
|
40 |
+
'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR
|
41 |
+
}
|
42 |
+
|
43 |
+
def compile_and_fit(self, model):
|
44 |
+
"""Here we compile and fit the model."""
|
45 |
+
model.compile(
|
46 |
+
loss=keras.losses.categorical_crossentropy,
|
47 |
+
optimizer='adam',
|
48 |
+
metrics=['accuracy'],
|
49 |
+
)
|
50 |
+
model.fit(
|
51 |
+
np.random.rand(20, 10),
|
52 |
+
keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
|
53 |
+
batch_size=20,
|
54 |
+
)
|
55 |
+
|
56 |
+
def _get_number_of_unique_weights(self, stripped_model, layer_nr,
|
57 |
+
weight_name):
|
58 |
+
layer = stripped_model.layers[layer_nr]
|
59 |
+
if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
|
60 |
+
for weight_item in layer.trainable_weights:
|
61 |
+
if weight_name in weight_item.name:
|
62 |
+
weight = weight_item
|
63 |
+
else:
|
64 |
+
weight = getattr(layer, weight_name)
|
65 |
+
weights_as_list = weight.numpy().flatten()
|
66 |
+
nr_of_unique_weights = len(set(weights_as_list))
|
67 |
+
return nr_of_unique_weights
|
68 |
+
|
69 |
+
def _get_sparsity(self, model):
|
70 |
+
sparsity_list = []
|
71 |
+
for layer in model.layers:
|
72 |
+
for weights in layer.trainable_weights:
|
73 |
+
if 'kernel' in weights.name:
|
74 |
+
np_weights = keras.backend.get_value(weights)
|
75 |
+
sparsity = 1.0 - np.count_nonzero(np_weights) / float(
|
76 |
+
np_weights.size)
|
77 |
+
sparsity_list.append(sparsity)
|
78 |
+
|
79 |
+
return sparsity_list
|
80 |
+
|
81 |
+
def _get_clustered_model(self, preserve_sparsity):
|
82 |
+
"""Cluster the (sparse) model and return clustered_model."""
|
83 |
+
tf.random.set_seed(1)
|
84 |
+
original_model = keras.Sequential([
|
85 |
+
layers.Dense(5, activation='softmax', input_shape=(10,)),
|
86 |
+
layers.Flatten(),
|
87 |
+
])
|
88 |
+
|
89 |
+
# Manually set sparsity in the Dense layer if preserve_sparsity is on
|
90 |
+
if preserve_sparsity:
|
91 |
+
first_layer_weights = original_model.layers[0].get_weights()
|
92 |
+
first_layer_weights[0][:][0:2] = 0.0
|
93 |
+
original_model.layers[0].set_weights(first_layer_weights)
|
94 |
+
|
95 |
+
# Start the sparsity-aware clustering
|
96 |
+
clustering_params = {
|
97 |
+
'number_of_clusters': 4,
|
98 |
+
'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
|
99 |
+
'preserve_sparsity': True
|
100 |
+
}
|
101 |
+
|
102 |
+
clustered_model = experimental_cluster.cluster_weights(
|
103 |
+
original_model, **clustering_params)
|
104 |
+
|
105 |
+
return clustered_model
|
106 |
+
|
107 |
+
def _get_conv_model(self,
|
108 |
+
nr_of_channels,
|
109 |
+
data_format=None,
|
110 |
+
kernel_size=(3, 3)):
|
111 |
+
"""Returns functional model with Conv2D layer."""
|
112 |
+
inp = keras.layers.Input(shape=(32, 32), batch_size=100)
|
113 |
+
shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1)
|
114 |
+
x = keras.layers.Reshape(shape)(inp)
|
115 |
+
x = keras.layers.Conv2D(
|
116 |
+
filters=nr_of_channels,
|
117 |
+
kernel_size=kernel_size,
|
118 |
+
data_format=data_format,
|
119 |
+
activation='relu',
|
120 |
+
)(x)
|
121 |
+
x = keras.layers.MaxPool2D(2, 2)(x)
|
122 |
+
out = keras.layers.Flatten()(x)
|
123 |
+
model = keras.Model(inputs=inp, outputs=out)
|
124 |
+
return model
|
125 |
+
|
126 |
+
def _compile_and_fit_conv_model(self, model, nr_epochs=1):
|
127 |
+
"""Compile and fit conv model from _get_conv_model."""
|
128 |
+
x_train = np.random.uniform(size=(500, 32, 32))
|
129 |
+
y_train = np.random.randint(low=0, high=1024, size=(500,))
|
130 |
+
model.compile(
|
131 |
+
optimizer=keras.optimizers.Adam(learning_rate=1e-4),
|
132 |
+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
133 |
+
metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
|
134 |
+
)
|
135 |
+
|
136 |
+
model.fit(x_train, y_train, epochs=nr_epochs, batch_size=100, verbose=1)
|
137 |
+
|
138 |
+
return model
|
139 |
+
|
140 |
+
def _get_conv_clustered_model(self,
|
141 |
+
nr_of_channels,
|
142 |
+
nr_of_clusters,
|
143 |
+
data_format,
|
144 |
+
preserve_sparsity,
|
145 |
+
kernel_size=(3, 3)):
|
146 |
+
"""Returns clustered per channel model with Conv2D layer."""
|
147 |
+
tf.random.set_seed(42)
|
148 |
+
model = self._get_conv_model(nr_of_channels, data_format, kernel_size)
|
149 |
+
|
150 |
+
if preserve_sparsity:
|
151 |
+
# Make the convolutional layer sparse by nullifying half of weights
|
152 |
+
assert model.layers[2].name == 'conv2d'
|
153 |
+
|
154 |
+
conv_layer_weights = model.layers[2].get_weights()
|
155 |
+
shape = conv_layer_weights[0].shape
|
156 |
+
conv_layer_weights_flatten = conv_layer_weights[0].flatten()
|
157 |
+
|
158 |
+
nr_elems = len(conv_layer_weights_flatten)
|
159 |
+
conv_layer_weights_flatten[0:1 + nr_elems // 2] = 0.0
|
160 |
+
pruned_conv_layer_weights = tf.reshape(conv_layer_weights_flatten, shape)
|
161 |
+
conv_layer_weights[0] = pruned_conv_layer_weights
|
162 |
+
model.layers[2].set_weights(conv_layer_weights)
|
163 |
+
|
164 |
+
clustering_params = {
|
165 |
+
'number_of_clusters':
|
166 |
+
nr_of_clusters,
|
167 |
+
'cluster_centroids_init':
|
168 |
+
cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
|
169 |
+
'cluster_per_channel':
|
170 |
+
True,
|
171 |
+
'preserve_sparsity':
|
172 |
+
preserve_sparsity
|
173 |
+
}
|
174 |
+
|
175 |
+
clustered_model = experimental_cluster.cluster_weights(model,
|
176 |
+
**clustering_params)
|
177 |
+
clustered_model = self._compile_and_fit_conv_model(clustered_model)
|
178 |
+
|
179 |
+
# Returns un-stripped model
|
180 |
+
return clustered_model
|
181 |
+
|
182 |
+
def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model):
|
183 |
+
"""PCQAT training on the input model."""
|
184 |
+
quant_aware_model = quantize.quantize_apply(
|
185 |
+
quant_aware_annotate_model,
|
186 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
187 |
+
.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity))
|
188 |
+
|
189 |
+
self.compile_and_fit(quant_aware_model)
|
190 |
+
|
191 |
+
stripped_pcqat_model = strip_clustering_cqat(quant_aware_model)
|
192 |
+
|
193 |
+
# Check the unique weights of clustered_model and pcqat_model
|
194 |
+
# layer 0 is the quantize_layer
|
195 |
+
num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
|
196 |
+
stripped_pcqat_model, 1, 'kernel')
|
197 |
+
|
198 |
+
sparsity_pcqat = self._get_sparsity(stripped_pcqat_model)
|
199 |
+
|
200 |
+
return sparsity_pcqat, num_of_unique_weights_pcqat
|
201 |
+
|
202 |
+
def testEndToEndClusterPreserve(self):
|
203 |
+
"""Runs CQAT end to end and whole model is quantized."""
|
204 |
+
original_model = keras.Sequential(
|
205 |
+
[layers.Dense(5, activation='softmax', input_shape=(10,))]
|
206 |
+
)
|
207 |
+
clustered_model = cluster.cluster_weights(
|
208 |
+
original_model,
|
209 |
+
**self.cluster_params)
|
210 |
+
self.compile_and_fit(clustered_model)
|
211 |
+
clustered_model = cluster.strip_clustering(clustered_model)
|
212 |
+
num_of_unique_weights_clustering = self._get_number_of_unique_weights(
|
213 |
+
clustered_model, 0, 'kernel')
|
214 |
+
|
215 |
+
quant_aware_annotate_model = (
|
216 |
+
quantize.quantize_annotate_model(clustered_model))
|
217 |
+
|
218 |
+
quant_aware_model = quantize.quantize_apply(
|
219 |
+
quant_aware_annotate_model,
|
220 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
221 |
+
.Default8BitClusterPreserveQuantizeScheme())
|
222 |
+
|
223 |
+
self.compile_and_fit(quant_aware_model)
|
224 |
+
stripped_cqat_model = strip_clustering_cqat(quant_aware_model)
|
225 |
+
|
226 |
+
# Check the unique weights of a certain layer of
|
227 |
+
# clustered_model and pcqat_model
|
228 |
+
num_of_unique_weights_cqat = self._get_number_of_unique_weights(
|
229 |
+
stripped_cqat_model, 1, 'kernel')
|
230 |
+
self.assertAllEqual(num_of_unique_weights_clustering,
|
231 |
+
num_of_unique_weights_cqat)
|
232 |
+
|
233 |
+
def testEndToEndClusterPreservePerLayer(self):
|
234 |
+
"""Runs CQAT end to end and model is quantized per layers."""
|
235 |
+
original_model = keras.Sequential([
|
236 |
+
layers.Dense(5, activation='relu', input_shape=(10,)),
|
237 |
+
layers.Dense(5, activation='softmax', input_shape=(10,)),
|
238 |
+
])
|
239 |
+
clustered_model = cluster.cluster_weights(
|
240 |
+
original_model,
|
241 |
+
**self.cluster_params)
|
242 |
+
self.compile_and_fit(clustered_model)
|
243 |
+
clustered_model = cluster.strip_clustering(clustered_model)
|
244 |
+
num_of_unique_weights_clustering = self._get_number_of_unique_weights(
|
245 |
+
clustered_model, 1, 'kernel')
|
246 |
+
|
247 |
+
def apply_quantization_to_dense(layer):
|
248 |
+
if isinstance(layer, keras.layers.Dense):
|
249 |
+
return quantize.quantize_annotate_layer(layer)
|
250 |
+
return layer
|
251 |
+
|
252 |
+
quant_aware_annotate_model = keras.models.clone_model(
|
253 |
+
clustered_model,
|
254 |
+
clone_function=apply_quantization_to_dense,
|
255 |
+
)
|
256 |
+
|
257 |
+
quant_aware_model = quantize.quantize_apply(
|
258 |
+
quant_aware_annotate_model,
|
259 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
260 |
+
.Default8BitClusterPreserveQuantizeScheme())
|
261 |
+
|
262 |
+
self.compile_and_fit(quant_aware_model)
|
263 |
+
stripped_cqat_model = strip_clustering_cqat(
|
264 |
+
quant_aware_model)
|
265 |
+
|
266 |
+
# Check the unique weights of a certain layer of
|
267 |
+
# clustered_model and pcqat_model
|
268 |
+
num_of_unique_weights_cqat = self._get_number_of_unique_weights(
|
269 |
+
stripped_cqat_model, 2, 'kernel')
|
270 |
+
self.assertAllEqual(num_of_unique_weights_clustering,
|
271 |
+
num_of_unique_weights_cqat)
|
272 |
+
|
273 |
+
def testEndToEndClusterPreserveOneLayer(self):
|
274 |
+
"""Runs CQAT end to end and model is quantized only for a single layer."""
|
275 |
+
original_model = keras.Sequential([
|
276 |
+
layers.Dense(5, activation='relu', input_shape=(10,)),
|
277 |
+
layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'),
|
278 |
+
])
|
279 |
+
clustered_model = cluster.cluster_weights(
|
280 |
+
original_model,
|
281 |
+
**self.cluster_params)
|
282 |
+
self.compile_and_fit(clustered_model)
|
283 |
+
clustered_model = cluster.strip_clustering(clustered_model)
|
284 |
+
num_of_unique_weights_clustering = self._get_number_of_unique_weights(
|
285 |
+
clustered_model, 1, 'kernel')
|
286 |
+
|
287 |
+
def apply_quantization_to_dense(layer):
|
288 |
+
if isinstance(layer, keras.layers.Dense):
|
289 |
+
if layer.name == 'qat':
|
290 |
+
return quantize.quantize_annotate_layer(layer)
|
291 |
+
return layer
|
292 |
+
|
293 |
+
quant_aware_annotate_model = keras.models.clone_model(
|
294 |
+
clustered_model,
|
295 |
+
clone_function=apply_quantization_to_dense,
|
296 |
+
)
|
297 |
+
|
298 |
+
quant_aware_model = quantize.quantize_apply(
|
299 |
+
quant_aware_annotate_model,
|
300 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
301 |
+
.Default8BitClusterPreserveQuantizeScheme())
|
302 |
+
|
303 |
+
self.compile_and_fit(quant_aware_model)
|
304 |
+
|
305 |
+
stripped_cqat_model = strip_clustering_cqat(
|
306 |
+
quant_aware_model)
|
307 |
+
|
308 |
+
# Check the unique weights of a certain layer of
|
309 |
+
# clustered_model and pcqat_model
|
310 |
+
num_of_unique_weights_cqat = self._get_number_of_unique_weights(
|
311 |
+
stripped_cqat_model, 1, 'kernel')
|
312 |
+
self.assertAllEqual(num_of_unique_weights_clustering,
|
313 |
+
num_of_unique_weights_cqat)
|
314 |
+
|
315 |
+
def testEndToEndPruneClusterPreserveQAT(self):
|
316 |
+
"""Runs PCQAT end to end when we quantize the whole model."""
|
317 |
+
preserve_sparsity = True
|
318 |
+
clustered_model = self._get_clustered_model(preserve_sparsity)
|
319 |
+
# Save the kernel weights
|
320 |
+
first_layer_weights = clustered_model.layers[0].weights[1]
|
321 |
+
stripped_model_before_tuning = cluster.strip_clustering(
|
322 |
+
clustered_model)
|
323 |
+
nr_of_unique_weights_before = self._get_number_of_unique_weights(
|
324 |
+
stripped_model_before_tuning, 0, 'kernel')
|
325 |
+
|
326 |
+
self.compile_and_fit(clustered_model)
|
327 |
+
|
328 |
+
stripped_model_clustered = cluster.strip_clustering(clustered_model)
|
329 |
+
weights_after_tuning = stripped_model_clustered.layers[0].kernel
|
330 |
+
nr_of_unique_weights_after = self._get_number_of_unique_weights(
|
331 |
+
stripped_model_clustered, 0, 'kernel')
|
332 |
+
|
333 |
+
# Check after sparsity-aware clustering, despite zero centroid can drift,
|
334 |
+
# the final number of unique weights remains the same
|
335 |
+
self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after)
|
336 |
+
|
337 |
+
# Check that the zero weights stayed the same before and after tuning.
|
338 |
+
# There might be new weights that become zeros but sparsity-aware
|
339 |
+
# clustering preserves the original zero weights in the original positions
|
340 |
+
# of the weight array
|
341 |
+
self.assertTrue(
|
342 |
+
np.array_equal(first_layer_weights[:][0:2],
|
343 |
+
weights_after_tuning[:][0:2]))
|
344 |
+
|
345 |
+
# Check sparsity before the input of PCQAT
|
346 |
+
sparsity_pruning = self._get_sparsity(stripped_model_clustered)
|
347 |
+
|
348 |
+
# PCQAT: when the preserve_sparsity flag is True, the PCQAT should work
|
349 |
+
quant_aware_annotate_model = (
|
350 |
+
quantize.quantize_annotate_model(stripped_model_clustered)
|
351 |
+
)
|
352 |
+
|
353 |
+
# When preserve_sparsity is True in PCQAT, the final sparsity of
|
354 |
+
# the layer stays the same or larger than that of the input layer
|
355 |
+
preserve_sparsity = True
|
356 |
+
sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
|
357 |
+
preserve_sparsity, quant_aware_annotate_model)
|
358 |
+
self.assertAllGreaterEqual(np.array(sparsity_pcqat),
|
359 |
+
sparsity_pruning[0])
|
360 |
+
self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)
|
361 |
+
|
362 |
+
def testEndToEndClusterPreserveQATClusteredPerChannel(
|
363 |
+
self, data_format='channels_last'):
|
364 |
+
"""Runs CQAT end to end for the model that is clustered per channel."""
|
365 |
+
|
366 |
+
nr_of_channels = 12
|
367 |
+
nr_of_clusters = 4
|
368 |
+
|
369 |
+
clustered_model = self._get_conv_clustered_model(
|
370 |
+
nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=False)
|
371 |
+
stripped_model = cluster.strip_clustering(clustered_model)
|
372 |
+
|
373 |
+
# Save the kernel weights
|
374 |
+
conv2d_layer = stripped_model.layers[2]
|
375 |
+
self.assertEqual(conv2d_layer.name, 'conv2d')
|
376 |
+
|
377 |
+
# should be nr_of_channels * nr_of_clusters
|
378 |
+
nr_unique_weights = -1
|
379 |
+
|
380 |
+
for weight in conv2d_layer.weights:
|
381 |
+
if 'kernel' in weight.name:
|
382 |
+
nr_unique_weights = len(np.unique(weight.numpy()))
|
383 |
+
self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)
|
384 |
+
|
385 |
+
quant_aware_annotate_model = (
|
386 |
+
quantize.quantize_annotate_model(stripped_model)
|
387 |
+
)
|
388 |
+
|
389 |
+
quant_aware_model = quantize.quantize_apply(
|
390 |
+
quant_aware_annotate_model,
|
391 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
392 |
+
.Default8BitClusterPreserveQuantizeScheme())
|
393 |
+
|
394 |
+
# Lets train for more epochs to have a chance to scatter clusters
|
395 |
+
model = self._compile_and_fit_conv_model(quant_aware_model, 3)
|
396 |
+
|
397 |
+
stripped_cqat_model = strip_clustering_cqat(model)
|
398 |
+
|
399 |
+
# Check the unique weights of a certain layer of
|
400 |
+
# clustered_model and pcqat_model
|
401 |
+
layer_nr = 3
|
402 |
+
num_of_unique_weights_cqat = self._get_number_of_unique_weights(
|
403 |
+
stripped_cqat_model, layer_nr, 'kernel')
|
404 |
+
self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)
|
405 |
+
|
406 |
+
# We need to do tighter check: we check that the number of unique
|
407 |
+
# weights per channel is less than the given nr_of_channels
|
408 |
+
layer = stripped_cqat_model.layers[layer_nr]
|
409 |
+
weight_to_check = None
|
410 |
+
if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
|
411 |
+
for weight_item in layer.trainable_weights:
|
412 |
+
if 'kernel' in weight_item.name:
|
413 |
+
weight_to_check = weight_item
|
414 |
+
|
415 |
+
assert weight_to_check is not None
|
416 |
+
|
417 |
+
for i in range(nr_of_channels):
|
418 |
+
nr_unique_weights_per_channel = len(
|
419 |
+
np.unique(weight_to_check[:, :, :, i]))
|
420 |
+
assert nr_unique_weights_per_channel == nr_of_clusters
|
421 |
+
|
422 |
+
def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'):
|
423 |
+
"""Runs PCQAT end to end for the model that is clustered per channel."""
|
424 |
+
|
425 |
+
nr_of_channels = 12
|
426 |
+
nr_of_clusters = 4
|
427 |
+
|
428 |
+
clustered_model = self._get_conv_clustered_model(
|
429 |
+
nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=True)
|
430 |
+
stripped_model = cluster.strip_clustering(clustered_model)
|
431 |
+
|
432 |
+
# Save the kernel weights
|
433 |
+
conv2d_layer = stripped_model.layers[2]
|
434 |
+
self.assertEqual(conv2d_layer.name, 'conv2d')
|
435 |
+
|
436 |
+
# should be nr_of_channels * nr_of_clusters
|
437 |
+
nr_unique_weights = -1
|
438 |
+
|
439 |
+
for weight in conv2d_layer.weights:
|
440 |
+
if 'kernel' in weight.name:
|
441 |
+
nr_unique_weights = len(np.unique(weight.numpy()))
|
442 |
+
self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels)
|
443 |
+
|
444 |
+
# get sparsity before PCQAT training
|
445 |
+
# we expect that only one value will be returned
|
446 |
+
control_sparsity = self._get_sparsity(stripped_model)
|
447 |
+
self.assertGreater(control_sparsity[0], 0.5)
|
448 |
+
|
449 |
+
quant_aware_annotate_model = (
|
450 |
+
quantize.quantize_annotate_model(stripped_model)
|
451 |
+
)
|
452 |
+
|
453 |
+
quant_aware_model = quantize.quantize_apply(
|
454 |
+
quant_aware_annotate_model,
|
455 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
456 |
+
.Default8BitClusterPreserveQuantizeScheme())
|
457 |
+
|
458 |
+
# Lets train for more epochs to have a chance to scatter clusters
|
459 |
+
model = self._compile_and_fit_conv_model(quant_aware_model, 3)
|
460 |
+
|
461 |
+
stripped_cqat_model = strip_clustering_cqat(model)
|
462 |
+
|
463 |
+
# Check the unique weights of a certain layer of
|
464 |
+
# clustered_model and cqat_model
|
465 |
+
layer_nr = 3
|
466 |
+
num_of_unique_weights_cqat = self._get_number_of_unique_weights(
|
467 |
+
stripped_cqat_model, layer_nr, 'kernel')
|
468 |
+
self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights)
|
469 |
+
|
470 |
+
# We need to do tighter check: we check that the number of unique
|
471 |
+
# weights per channel is less than the given nr_of_channels
|
472 |
+
layer = stripped_cqat_model.layers[layer_nr]
|
473 |
+
weight_to_check = None
|
474 |
+
if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper):
|
475 |
+
for weight_item in layer.trainable_weights:
|
476 |
+
if 'kernel' in weight_item.name:
|
477 |
+
weight_to_check = weight_item
|
478 |
+
|
479 |
+
assert weight_to_check is not None
|
480 |
+
|
481 |
+
for i in range(nr_of_channels):
|
482 |
+
nr_unique_weights_per_channel = len(
|
483 |
+
np.unique(weight_to_check[:, :, :, i]))
|
484 |
+
assert nr_unique_weights_per_channel == nr_of_clusters
|
485 |
+
|
486 |
+
cqat_sparsity = self._get_sparsity(stripped_cqat_model)
|
487 |
+
self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
|
488 |
+
|
489 |
+
def testEndToEndPCQATClusteredPerChannelConv2d1x1(self,
|
490 |
+
data_format='channels_last'
|
491 |
+
):
|
492 |
+
"""Runs PCQAT for model containing a 1x1 Conv2D.
|
493 |
+
|
494 |
+
(with insufficient number of weights per channel).
|
495 |
+
|
496 |
+
Args:
|
497 |
+
data_format: Format of input data.
|
498 |
+
"""
|
499 |
+
nr_of_channels = 12
|
500 |
+
nr_of_clusters = 4
|
501 |
+
|
502 |
+
# Ensure a warning is given to the user that
|
503 |
+
# clustering is not implemented for this layer
|
504 |
+
with self.assertWarnsRegex(Warning,
|
505 |
+
r'Layer conv2d does not have enough weights'):
|
506 |
+
clustered_model = self._get_conv_clustered_model(
|
507 |
+
nr_of_channels,
|
508 |
+
nr_of_clusters,
|
509 |
+
data_format,
|
510 |
+
preserve_sparsity=True,
|
511 |
+
kernel_size=(1, 1))
|
512 |
+
stripped_model = cluster.strip_clustering(clustered_model)
|
513 |
+
|
514 |
+
# Save the kernel weights
|
515 |
+
conv2d_layer = stripped_model.layers[2]
|
516 |
+
self.assertEqual(conv2d_layer.name, 'conv2d')
|
517 |
+
|
518 |
+
for weight in conv2d_layer.weights:
|
519 |
+
if 'kernel' in weight.name:
|
520 |
+
# Original number of unique weights
|
521 |
+
nr_original_weights = len(np.unique(weight.numpy()))
|
522 |
+
self.assertLess(nr_original_weights, nr_of_channels * nr_of_clusters)
|
523 |
+
|
524 |
+
# Demonstrate unmodified test layer has less weights
|
525 |
+
# than requested clusters
|
526 |
+
for channel in range(nr_of_channels):
|
527 |
+
channel_weights = (
|
528 |
+
weight[:, channel, :, :]
|
529 |
+
if data_format == 'channels_first' else weight[:, :, :, channel])
|
530 |
+
nr_channel_weights = len(channel_weights)
|
531 |
+
self.assertGreater(nr_channel_weights, 0)
|
532 |
+
self.assertLessEqual(nr_channel_weights, nr_of_clusters)
|
533 |
+
|
534 |
+
# get sparsity before PCQAT training
|
535 |
+
# we expect that only one value will be returned
|
536 |
+
control_sparsity = self._get_sparsity(stripped_model)
|
537 |
+
self.assertGreater(control_sparsity[0], 0.5)
|
538 |
+
|
539 |
+
quant_aware_annotate_model = (
|
540 |
+
quantize.quantize_annotate_model(stripped_model))
|
541 |
+
|
542 |
+
with self.assertWarnsRegex(
|
543 |
+
Warning, r'No clustering performed on layer quant_conv2d'):
|
544 |
+
quant_aware_model = quantize.quantize_apply(
|
545 |
+
quant_aware_annotate_model,
|
546 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
547 |
+
.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))
|
548 |
+
|
549 |
+
# Lets train for more epochs to have a chance to scatter clusters
|
550 |
+
model = self._compile_and_fit_conv_model(quant_aware_model, 3)
|
551 |
+
|
552 |
+
stripped_cqat_model = strip_clustering_cqat(model)
|
553 |
+
|
554 |
+
# Check the unique weights of a certain layer of
|
555 |
+
# clustered_model and cqat_model, ensuring unchanged
|
556 |
+
layer_nr = 3
|
557 |
+
num_of_unique_weights_cqat = self._get_number_of_unique_weights(
|
558 |
+
stripped_cqat_model, layer_nr, 'kernel')
|
559 |
+
self.assertEqual(num_of_unique_weights_cqat, nr_original_weights)
|
560 |
+
|
561 |
+
cqat_sparsity = self._get_sparsity(stripped_cqat_model)
|
562 |
+
self.assertLessEqual(cqat_sparsity[0], control_sparsity[0])
|
563 |
+
|
564 |
+
def testPassingNonPrunedModelToPCQAT(self):
|
565 |
+
"""Runs PCQAT as CQAT if the input model is not pruned."""
|
566 |
+
preserve_sparsity = False
|
567 |
+
clustered_model = self._get_clustered_model(preserve_sparsity)
|
568 |
+
|
569 |
+
clustered_model = cluster.strip_clustering(clustered_model)
|
570 |
+
nr_of_unique_weights_after = self._get_number_of_unique_weights(
|
571 |
+
clustered_model, 0, 'kernel')
|
572 |
+
|
573 |
+
# Check after plain clustering, if there are no zero weights,
|
574 |
+
# PCQAT falls back to CQAT
|
575 |
+
quant_aware_annotate_model = (
|
576 |
+
quantize.quantize_annotate_model(clustered_model)
|
577 |
+
)
|
578 |
+
|
579 |
+
quant_aware_model = quantize.quantize_apply(
|
580 |
+
quant_aware_annotate_model,
|
581 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
582 |
+
.Default8BitClusterPreserveQuantizeScheme(True))
|
583 |
+
|
584 |
+
self.compile_and_fit(quant_aware_model)
|
585 |
+
stripped_pcqat_model = strip_clustering_cqat(
|
586 |
+
quant_aware_model)
|
587 |
+
|
588 |
+
# Check the unique weights of clustered_model and pcqat_model
|
589 |
+
num_of_unique_weights_pcqat = self._get_number_of_unique_weights(
|
590 |
+
stripped_pcqat_model, 1, 'kernel')
|
591 |
+
self.assertAllEqual(nr_of_unique_weights_after,
|
592 |
+
num_of_unique_weights_pcqat)
|
593 |
+
|
594 |
+
@parameterized.parameters((0.), (2.))
|
595 |
+
def testPassingModelWithUniformWeightsToPCQAT(self, uniform_weights):
|
596 |
+
"""If pruned_clustered_model has uniform weights, it won't break PCQAT."""
|
597 |
+
preserve_sparsity = True
|
598 |
+
original_model = keras.Sequential([
|
599 |
+
layers.Dense(5, activation='softmax', input_shape=(10,)),
|
600 |
+
layers.Flatten(),
|
601 |
+
])
|
602 |
+
|
603 |
+
# Manually set all weights to the same value in the Dense layer
|
604 |
+
first_layer_weights = original_model.layers[0].get_weights()
|
605 |
+
first_layer_weights[0][:] = uniform_weights
|
606 |
+
original_model.layers[0].set_weights(first_layer_weights)
|
607 |
+
|
608 |
+
# Start the sparsity-aware clustering
|
609 |
+
clustering_params = {
|
610 |
+
'number_of_clusters': 4,
|
611 |
+
'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR,
|
612 |
+
'preserve_sparsity': True
|
613 |
+
}
|
614 |
+
|
615 |
+
clustered_model = experimental_cluster.cluster_weights(
|
616 |
+
original_model, **clustering_params)
|
617 |
+
clustered_model = cluster.strip_clustering(clustered_model)
|
618 |
+
|
619 |
+
nr_of_unique_weights_after = self._get_number_of_unique_weights(
|
620 |
+
clustered_model, 0, 'kernel')
|
621 |
+
sparsity_pruning = self._get_sparsity(clustered_model)
|
622 |
+
|
623 |
+
quant_aware_annotate_model = (
|
624 |
+
quantize.quantize_annotate_model(clustered_model)
|
625 |
+
)
|
626 |
+
|
627 |
+
sparsity_pcqat, unique_weights_pcqat = self._pcqat_training(
|
628 |
+
preserve_sparsity, quant_aware_annotate_model)
|
629 |
+
self.assertAllGreaterEqual(np.array(sparsity_pcqat),
|
630 |
+
sparsity_pruning[0])
|
631 |
+
self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat)
|
632 |
+
|
633 |
+
def testTrainableWeightsBehaveCorrectlyDuringPCQAT(self):
|
634 |
+
"""PCQAT zero centroid masks stay the same and trainable variables are updating between epochs."""
|
635 |
+
preserve_sparsity = True
|
636 |
+
clustered_model = self._get_clustered_model(preserve_sparsity)
|
637 |
+
clustered_model = cluster.strip_clustering(clustered_model)
|
638 |
+
|
639 |
+
# Apply PCQAT
|
640 |
+
quant_aware_annotate_model = (
|
641 |
+
quantize.quantize_annotate_model(clustered_model)
|
642 |
+
)
|
643 |
+
|
644 |
+
quant_aware_model = quantize.quantize_apply(
|
645 |
+
quant_aware_annotate_model,
|
646 |
+
scheme=default_8bit_cluster_preserve_quantize_scheme
|
647 |
+
.Default8BitClusterPreserveQuantizeScheme(True))
|
648 |
+
|
649 |
+
quant_aware_model.compile(
|
650 |
+
loss=keras.losses.categorical_crossentropy,
|
651 |
+
optimizer='adam',
|
652 |
+
metrics=['accuracy'],
|
653 |
+
)
|
654 |
+
|
655 |
+
class CheckCentroidsAndTrainableVarsCallback(keras.callbacks.Callback):
|
656 |
+
"""Check the updates of trainable variables and centroid masks."""
|
657 |
+
|
658 |
+
def on_epoch_begin(self, batch, logs=None):
|
659 |
+
# Check cluster centroids have the zero in the right position
|
660 |
+
vars_dictionary = self.model.layers[1]._weight_vars[0][2]
|
661 |
+
self.centroid_mask = vars_dictionary['centroids_mask']
|
662 |
+
self.zero_centroid_index_begin = np.where(
|
663 |
+
self.centroid_mask == 0)[0]
|
664 |
+
|
665 |
+
# Check trainable weights before training
|
666 |
+
self.layer_kernel = (
|
667 |
+
self.model.layers[1].weights[3].numpy()
|
668 |
+
)
|
669 |
+
self.original_weight = vars_dictionary['ori_weights_vars_tf'].numpy()
|
670 |
+
self.centroids = vars_dictionary['cluster_centroids_tf'].numpy()
|
671 |
+
|
672 |
+
def on_epoch_end(self, batch, logs=None):
|
673 |
+
# Check the index of the zero centroids are not changed after training
|
674 |
+
vars_dictionary = self.model.layers[1]._weight_vars[0][2]
|
675 |
+
self.zero_centroid_index_end = np.where(
|
676 |
+
vars_dictionary['centroids_mask'] == 0)[0]
|
677 |
+
assert np.array_equal(
|
678 |
+
self.zero_centroid_index_begin,
|
679 |
+
self.zero_centroid_index_end
|
680 |
+
)
|
681 |
+
|
682 |
+
# Check trainable variables after training are updated
|
683 |
+
assert not np.array_equal(
|
684 |
+
self.layer_kernel,
|
685 |
+
self.model.layers[1].weights[3].numpy()
|
686 |
+
)
|
687 |
+
assert not np.array_equal(
|
688 |
+
self.original_weight,
|
689 |
+
vars_dictionary['ori_weights_vars_tf'].numpy()
|
690 |
+
)
|
691 |
+
assert not np.array_equal(
|
692 |
+
self.centroids,
|
693 |
+
vars_dictionary['cluster_centroids_tf'].numpy()
|
694 |
+
)
|
695 |
+
|
696 |
+
# Use many epochs to verify layer's kernel weights are updating because
|
697 |
+
# they can stay the same after being trained using only the first batch
|
698 |
+
# of data for instance
|
699 |
+
quant_aware_model.fit(
|
700 |
+
np.random.rand(20, 10),
|
701 |
+
keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
|
702 |
+
steps_per_epoch=5,
|
703 |
+
epochs=3,
|
704 |
+
callbacks=[CheckCentroidsAndTrainableVarsCallback()],
|
705 |
+
)
|
706 |
+
|
707 |
+
|
708 |
+
if __name__ == '__main__':
|
709 |
+
tf.test.main()
|
cluster_preserve_quantize_registry.py
ADDED
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Registry responsible for built-in keras classes."""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import warnings
|
19 |
+
|
20 |
+
import tensorflow as tf
|
21 |
+
|
22 |
+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
|
23 |
+
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
|
24 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
25 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
|
26 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
|
27 |
+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
|
28 |
+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
|
29 |
+
|
30 |
+
|
31 |
+
layers = keras.layers
|
32 |
+
K = keras.backend
|
33 |
+
|
34 |
+
CLUSTER_CENTROIDS = 'cluster_centroids_tf'
|
35 |
+
PULLING_INDICES = 'pulling_indices_tf'
|
36 |
+
ORIGINAL_WEIGHTS = 'ori_weights_vars_tf'
|
37 |
+
WEIGHT_NAME = 'weight_name'
|
38 |
+
CLUSTERING_IMPL = 'clst_impl'
|
39 |
+
CENTROIDS_MASK = 'centroids_mask'
|
40 |
+
SPARSITY_MASK = 'sparsity_mask'
|
41 |
+
|
42 |
+
|
43 |
+
def get_unique(t):
|
44 |
+
"""Get unique values and lookup index from N-D tensor.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
t: tensor
|
48 |
+
Returns:
|
49 |
+
centroids (unique values), lookup index (same shape as input tensor)
|
50 |
+
Example:
|
51 |
+
t:
|
52 |
+
([[1.0, 2.0],
|
53 |
+
[2.0, 3.0],
|
54 |
+
[3.0, 3.0],
|
55 |
+
[1.0, 2.0]]
|
56 |
+
)
|
57 |
+
centroids(unique values):
|
58 |
+
([1.0, 2.0, 3.0])
|
59 |
+
output final index:
|
60 |
+
([[0, 1],
|
61 |
+
[1, 2],
|
62 |
+
[2, 2],
|
63 |
+
[0, 1]]
|
64 |
+
)
|
65 |
+
"""
|
66 |
+
t_flatten = tf.reshape(t, shape=(-1,))
|
67 |
+
uniques, index = tf.unique(t_flatten)
|
68 |
+
return uniques, tf.reshape(index, shape=tf.shape(t))
|
69 |
+
|
70 |
+
|
71 |
+
def get_centroids(layer, weight, data_format):
|
72 |
+
"""Gets centroid infos from the weights of a layer.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
layer: The Keras layer from which the weight belong.
|
76 |
+
weight: The weight tensor to get the centroids info from.
|
77 |
+
data_format: string to indicate format: "channels_first" or "channels_last".
|
78 |
+
Returns:
|
79 |
+
A 4-tuple of centroids (unique values), number of centroids, lookup index,
|
80 |
+
whether to cluster per channel (boolean).
|
81 |
+
"""
|
82 |
+
cluster_per_channel = layer.layer and isinstance(
|
83 |
+
layer.layer, keras.layers.Conv2D
|
84 |
+
)
|
85 |
+
|
86 |
+
if not cluster_per_channel:
|
87 |
+
centroids, index = get_unique(weight)
|
88 |
+
return centroids, tf.size(centroids), index, False
|
89 |
+
|
90 |
+
# In case of cluster_per_channel we need to extract
|
91 |
+
# unique values (centroids) for each channel.
|
92 |
+
num_channels = weight.shape[1 if data_format == 'channels_first' else -1]
|
93 |
+
channel_centroids = []
|
94 |
+
channel_indices = []
|
95 |
+
num_centroids = []
|
96 |
+
|
97 |
+
for channel in range(num_channels):
|
98 |
+
channel_weights = weight[:, :, :, channel]
|
99 |
+
centroids, indices = get_unique(channel_weights)
|
100 |
+
|
101 |
+
channel_centroids.append(centroids)
|
102 |
+
channel_indices.append(indices)
|
103 |
+
num_centroids.append(tf.size(centroids))
|
104 |
+
|
105 |
+
max_centroid = max(num_centroids)
|
106 |
+
max_diff = max_centroid - min(num_centroids)
|
107 |
+
|
108 |
+
if max_diff > 1:
|
109 |
+
centroids, index = get_unique(weight)
|
110 |
+
return centroids, tf.size(centroids), index, False
|
111 |
+
|
112 |
+
for i, centroid in enumerate(channel_centroids):
|
113 |
+
if num_centroids[i] != max_centroid:
|
114 |
+
one_padding = tf.ones([max_centroid - num_centroids[i]])
|
115 |
+
channel_centroids[i] = tf.concat([centroid, one_padding], 0)
|
116 |
+
|
117 |
+
centroids = tf.convert_to_tensor(channel_centroids)
|
118 |
+
lookup = tf.convert_to_tensor(channel_indices)
|
119 |
+
|
120 |
+
lookup = tf.transpose(
|
121 |
+
lookup,
|
122 |
+
perm=(1, 0, 2, 3) if data_format == 'channels_first' else (1, 2, 3, 0))
|
123 |
+
|
124 |
+
return centroids, max_centroid, lookup, True
|
125 |
+
|
126 |
+
|
127 |
+
class _ClusterPreserveInfo(object):
|
128 |
+
"""ClusterPreserveInfo."""
|
129 |
+
|
130 |
+
def __init__(self, weight_attrs, quantize_config_attrs):
|
131 |
+
"""ClusterPreserveInfo.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
weight_attrs: list of cluster preservable weight attributes of layer.
|
135 |
+
quantize_config_attrs: list of quantization configuration class name.
|
136 |
+
"""
|
137 |
+
self.weight_attrs = weight_attrs
|
138 |
+
self.quantize_config_attrs = quantize_config_attrs
|
139 |
+
|
140 |
+
|
141 |
+
class ClusterPreserveQuantizeRegistry(object):
|
142 |
+
"""ClusterPreserveQuantizeRegistry is for built-in keras layers."""
|
143 |
+
# The keys represent built-in keras layers; the first values represent the
|
144 |
+
# the variables within the layers which hold the kernel weights, second
|
145 |
+
# values represent the class name of quantization configuration for layers.
|
146 |
+
# This decide the weights of layers with quantization configurations are
|
147 |
+
# cluster preservable.
|
148 |
+
_LAYERS_CONFIG_MAP = {
|
149 |
+
layers.Conv2D:
|
150 |
+
_ClusterPreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
|
151 |
+
layers.Dense:
|
152 |
+
_ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
|
153 |
+
|
154 |
+
# DepthwiseConv2D is supported with 8bit qat, but not with
|
155 |
+
# clustering, thus for DepthwiseConv2D CQAT,
|
156 |
+
# preserving clustered weights is disabled.
|
157 |
+
layers.DepthwiseConv2D:
|
158 |
+
_ClusterPreserveInfo(['depthwise_kernel'],
|
159 |
+
['Default8BitQuantizeConfig']),
|
160 |
+
|
161 |
+
# layers that are supported with clustering, but not yet with qat
|
162 |
+
# layers.Conv1D:
|
163 |
+
# _ClusterPreserveInfo(['kernel'], []),
|
164 |
+
# layers.Conv2DTranspose:
|
165 |
+
# _ClusterPreserveInfo(['kernel'], []),
|
166 |
+
# layers.Conv3D:
|
167 |
+
# _ClusterPreserveInfo(['kernel'], []),
|
168 |
+
# layers.Conv3DTranspose:
|
169 |
+
# _ClusterPreserveInfo(['kernel'], []),
|
170 |
+
# layers.LocallyConnected1D:
|
171 |
+
# _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
|
172 |
+
# layers.LocallyConnected2D:
|
173 |
+
# _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
|
174 |
+
|
175 |
+
# SeparableConv need verify from 8bit qat
|
176 |
+
# layers.SeparableConv1D:
|
177 |
+
# _ClusterPreserveInfo(['pointwise_kernel'],
|
178 |
+
# ['Default8BitConvQuantizeConfig']),
|
179 |
+
# layers.SeparableConv2D:
|
180 |
+
# _ClusterPreserveInfo(['pointwise_kernel'],
|
181 |
+
# ['Default8BitConvQuantizeConfig']),
|
182 |
+
|
183 |
+
# Embedding need verify from 8bit qat
|
184 |
+
# layers.Embedding: _ClusterPreserveInfo(['embeddings'], []),
|
185 |
+
}
|
186 |
+
|
187 |
+
_DISABLE_CLUSTER_PRESERVE = frozenset({
|
188 |
+
layers.DepthwiseConv2D,
|
189 |
+
})
|
190 |
+
|
191 |
+
def __init__(self, preserve_sparsity):
|
192 |
+
self._config_quantizer_map = {
|
193 |
+
'Default8BitQuantizeConfig':
|
194 |
+
ClusterPreserveDefault8BitWeightsQuantizer(preserve_sparsity),
|
195 |
+
'Default8BitConvQuantizeConfig':
|
196 |
+
ClusterPreserveDefault8BitConvWeightsQuantizer(preserve_sparsity),
|
197 |
+
}
|
198 |
+
|
199 |
+
@classmethod
|
200 |
+
def _no_trainable_weights(cls, layer):
|
201 |
+
"""Returns whether this layer has trainable weights.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
layer: The layer to check for trainable weights.
|
205 |
+
Returns:
|
206 |
+
True/False whether the layer has trainable weights.
|
207 |
+
"""
|
208 |
+
return not layer.trainable_weights
|
209 |
+
|
210 |
+
@classmethod
|
211 |
+
def _disable_cluster_preserve(cls, layer):
|
212 |
+
"""Returns whether to disable this layer for preserving clusters.
|
213 |
+
|
214 |
+
Args:
|
215 |
+
layer: The layer to check for disabling.
|
216 |
+
Returns:
|
217 |
+
True/False whether disabling this layer for preserving clusters.
|
218 |
+
"""
|
219 |
+
return layer.__class__ in cls._DISABLE_CLUSTER_PRESERVE
|
220 |
+
|
221 |
+
@classmethod
|
222 |
+
def supports(cls, layer):
|
223 |
+
"""Returns whether the registry supports this layer type.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
layer: The layer to check for support.
|
227 |
+
Returns:
|
228 |
+
True/False whether the layer type is supported.
|
229 |
+
"""
|
230 |
+
# layers without trainable weights are consider supported,
|
231 |
+
# e.g., ReLU, Softmax, and AveragePooling2D.
|
232 |
+
if cls._no_trainable_weights(layer):
|
233 |
+
return True
|
234 |
+
|
235 |
+
if layer.__class__ in cls._LAYERS_CONFIG_MAP:
|
236 |
+
return True
|
237 |
+
|
238 |
+
return False
|
239 |
+
|
240 |
+
@classmethod
|
241 |
+
def _weight_names(cls, layer):
|
242 |
+
|
243 |
+
if cls._no_trainable_weights(layer):
|
244 |
+
return []
|
245 |
+
|
246 |
+
return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs
|
247 |
+
|
248 |
+
def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
|
249 |
+
"""Applies cluster-preserve weight quantizer.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
layer: The layer to check for support.
|
253 |
+
quantize_config: quantization config for supporting cluster preservation
|
254 |
+
on clustered weights
|
255 |
+
Returns:
|
256 |
+
The quantize_config with addon cluster preserve weight_quantizer.
|
257 |
+
"""
|
258 |
+
if not self.supports(layer):
|
259 |
+
raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')
|
260 |
+
|
261 |
+
# Example: ReLU, Softmax, and AveragePooling2D (without trainable weights)
|
262 |
+
# DepthwiseConv2D (cluster_preserve is disabled)
|
263 |
+
if self._no_trainable_weights(layer) or self._disable_cluster_preserve(
|
264 |
+
layer):
|
265 |
+
return quantize_config
|
266 |
+
|
267 |
+
# Example: Conv2D, Dense layers
|
268 |
+
if quantize_config.__class__.__name__ in self._LAYERS_CONFIG_MAP[
|
269 |
+
layer.__class__].quantize_config_attrs:
|
270 |
+
quantize_config.weight_quantizer = self._config_quantizer_map[
|
271 |
+
quantize_config.__class__.__name__]
|
272 |
+
else:
|
273 |
+
raise ValueError('Configuration ' +
|
274 |
+
str(quantize_config.__class__.__name__) +
|
275 |
+
' is not supported for Layer ' + str(layer.__class__) +
|
276 |
+
'.')
|
277 |
+
|
278 |
+
return quantize_config
|
279 |
+
|
280 |
+
|
281 |
+
class Default8bitClusterPreserveQuantizeRegistry(
|
282 |
+
ClusterPreserveQuantizeRegistry):
|
283 |
+
"""Default 8 bit ClusterPreserveQuantizeRegistry."""
|
284 |
+
|
285 |
+
def get_quantize_config(self, layer):
|
286 |
+
"""Returns the quantization config with weight_quantizer for a given layer.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
layer: input layer to return quantize config for.
|
290 |
+
Returns:
|
291 |
+
Returns the quantization config for cluster preserve weight_quantizer.
|
292 |
+
"""
|
293 |
+
quantize_config = (default_8bit_quantize_registry.
|
294 |
+
Default8BitQuantizeRegistry().
|
295 |
+
get_quantize_config(layer))
|
296 |
+
cluster_aware_quantize_config = super(
|
297 |
+
Default8bitClusterPreserveQuantizeRegistry,
|
298 |
+
self).apply_cluster_preserve_quantize_config(layer, quantize_config)
|
299 |
+
|
300 |
+
return cluster_aware_quantize_config
|
301 |
+
|
302 |
+
|
303 |
+
class ClusterPreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
|
304 |
+
"""Quantize weights while preserving clusters."""
|
305 |
+
|
306 |
+
def __init__(
|
307 |
+
self, num_bits, per_axis, symmetric, narrow_range, preserve_sparsity):
|
308 |
+
"""ClusterPreserveDefaultWeightsQuantizer.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
num_bits: Number of bits for quantization
|
312 |
+
per_axis: Whether to apply per_axis quantization. The last dimension is
|
313 |
+
used as the axis.
|
314 |
+
symmetric: If true, use symmetric quantization limits instead of training
|
315 |
+
the minimum and maximum of each quantization range separately.
|
316 |
+
narrow_range: In case of 8 bits, narrow_range nudges the quantized range
|
317 |
+
to be [-127, 127] instead of [-128, 127]. This ensures symmetric
|
318 |
+
range has 0 as the centre.
|
319 |
+
preserve_sparsity: Whether to apply prune-cluster-preserving quantization
|
320 |
+
aware training.
|
321 |
+
"""
|
322 |
+
super(ClusterPreserveDefaultWeightsQuantizer, self).__init__(
|
323 |
+
num_bits=num_bits,
|
324 |
+
per_axis=per_axis,
|
325 |
+
symmetric=symmetric,
|
326 |
+
narrow_range=narrow_range,
|
327 |
+
)
|
328 |
+
self.preserve_sparsity = preserve_sparsity
|
329 |
+
|
330 |
+
def _build_clusters(self, name, layer):
|
331 |
+
"""Extracts the cluster centroids and cluster indices.
|
332 |
+
|
333 |
+
Extracts cluster centroids and cluster indices from the pretrained
|
334 |
+
clustered model when the input layer is clustered.
|
335 |
+
|
336 |
+
Args:
|
337 |
+
name: Name of weights in layer.
|
338 |
+
layer: Quantization wrapped keras layer.
|
339 |
+
Returns:
|
340 |
+
A dictionary of the initial values of the
|
341 |
+
cluster centroids, cluster indices, original weights,
|
342 |
+
the pretrained flag for marking the first training
|
343 |
+
epoch, and weight name.
|
344 |
+
"""
|
345 |
+
result = {}
|
346 |
+
weights = getattr(layer.layer, name)
|
347 |
+
if self.preserve_sparsity and not tf.reduce_any(weights == 0):
|
348 |
+
self.preserve_sparsity = False
|
349 |
+
logging.warning(
|
350 |
+
'Input layer does not contain zero weights, so apply CQAT instead.')
|
351 |
+
centroids_mask = None
|
352 |
+
|
353 |
+
# Detects whether layer is convolutional and is clustered per channel
|
354 |
+
data_format = getattr(layer.layer, 'data_format', None)
|
355 |
+
centroids, num_centroids, lookup, cluster_per_channel = get_centroids(
|
356 |
+
layer, weights, data_format)
|
357 |
+
|
358 |
+
if self.preserve_sparsity:
|
359 |
+
sparsity_mask = tf.math.divide_no_nan(weights, weights)
|
360 |
+
zero_idx = tf.argmin(tf.abs(centroids), axis=-1)
|
361 |
+
centroids_mask = 1.0 - tf.one_hot(zero_idx, num_centroids)
|
362 |
+
result = {SPARSITY_MASK: sparsity_mask}
|
363 |
+
|
364 |
+
# Prepare clustering variables for the Keras graph when clusters
|
365 |
+
# exist, assuming we do not use number_of_clusters larger than 1024
|
366 |
+
if num_centroids > 1024:
|
367 |
+
warnings.warn(f'No clustering performed on layer {layer.name}.\n'
|
368 |
+
f'Too many centroids to cluster.')
|
369 |
+
return result
|
370 |
+
# If not enough clusters, we do not preserve clustering
|
371 |
+
elif num_centroids <= 1:
|
372 |
+
warnings.warn(f'No clustering performed on layer {layer.name}.\n'
|
373 |
+
f'Perhaps too many clusters requested for this layer?')
|
374 |
+
return result
|
375 |
+
else:
|
376 |
+
clst_centroids_tf = layer.add_weight(
|
377 |
+
CLUSTER_CENTROIDS,
|
378 |
+
shape=centroids.shape,
|
379 |
+
initializer=keras.initializers.Constant(
|
380 |
+
value=K.batch_get_value([centroids])[0]
|
381 |
+
),
|
382 |
+
dtype=centroids.dtype,
|
383 |
+
trainable=True,
|
384 |
+
)
|
385 |
+
|
386 |
+
ori_weights_tf = layer.add_weight(
|
387 |
+
ORIGINAL_WEIGHTS,
|
388 |
+
shape=weights.shape,
|
389 |
+
initializer=keras.initializers.Constant(
|
390 |
+
value=K.batch_get_value([weights])[0]
|
391 |
+
),
|
392 |
+
dtype=weights.dtype,
|
393 |
+
trainable=True,
|
394 |
+
)
|
395 |
+
|
396 |
+
# Get clustering implementation according to layer type
|
397 |
+
clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
|
398 |
+
).get_clustering_impl(
|
399 |
+
layer.layer, name, cluster_per_channel=cluster_per_channel)
|
400 |
+
clustering_impl = clustering_impl_cls(
|
401 |
+
clst_centroids_tf, cluster_config.GradientAggregation.SUM,
|
402 |
+
data_format)
|
403 |
+
|
404 |
+
pulling_indices = tf.dtypes.cast(
|
405 |
+
clustering_impl.get_pulling_indices(ori_weights_tf),
|
406 |
+
lookup.dtype
|
407 |
+
)
|
408 |
+
|
409 |
+
pulling_indices_tf = layer.add_weight(
|
410 |
+
PULLING_INDICES,
|
411 |
+
shape=lookup.shape,
|
412 |
+
initializer=keras.initializers.Constant(
|
413 |
+
value=K.batch_get_value([pulling_indices])[0]
|
414 |
+
),
|
415 |
+
dtype=lookup.dtype,
|
416 |
+
trainable=False,
|
417 |
+
)
|
418 |
+
|
419 |
+
result_clst = {
|
420 |
+
CLUSTER_CENTROIDS: clst_centroids_tf,
|
421 |
+
PULLING_INDICES: pulling_indices_tf,
|
422 |
+
ORIGINAL_WEIGHTS: ori_weights_tf,
|
423 |
+
WEIGHT_NAME: name,
|
424 |
+
CLUSTERING_IMPL: clustering_impl,
|
425 |
+
CENTROIDS_MASK: centroids_mask,
|
426 |
+
}
|
427 |
+
result.update(result_clst)
|
428 |
+
return result
|
429 |
+
|
430 |
+
def build(self, tensor_shape, name, layer):
|
431 |
+
"""Build (P)CQAT wrapper.
|
432 |
+
|
433 |
+
When preserve_sparsity is true and the input is clustered.
|
434 |
+
|
435 |
+
Args:
|
436 |
+
tensor_shape: Shape of weights which needs to be quantized.
|
437 |
+
name: Name of weights in layer.
|
438 |
+
layer: Quantization wrapped keras layer.
|
439 |
+
Returns:
|
440 |
+
Dictionary of centroids, indices and
|
441 |
+
quantization params, the dictionary will be passed
|
442 |
+
to __call__ function.
|
443 |
+
"""
|
444 |
+
# To get all the initial values from pretrained clustered model
|
445 |
+
result = self._build_clusters(name, layer)
|
446 |
+
# Result can have clustering nodes, then this is CQAT
|
447 |
+
# Result can have both clustering nodes and sparsity mask, then
|
448 |
+
# this will be PCQAT
|
449 |
+
result.update(
|
450 |
+
super(ClusterPreserveDefaultWeightsQuantizer,
|
451 |
+
self).build(tensor_shape, name, layer))
|
452 |
+
|
453 |
+
return result
|
454 |
+
|
455 |
+
def __call__(self, inputs, training, weights, **kwargs):
|
456 |
+
"""Apply cluster preserved quantization to the input tensor.
|
457 |
+
|
458 |
+
Args:
|
459 |
+
inputs: Input tensor (layer's weights) to be quantized.
|
460 |
+
training: Whether the graph is currently training.
|
461 |
+
weights: Dictionary of weights (params) the quantizer can use to
|
462 |
+
quantize the tensor (layer's weights). This contains the weights
|
463 |
+
created in the `build` function.
|
464 |
+
**kwargs: Additional variables which may be passed to the quantizer.
|
465 |
+
Returns:
|
466 |
+
quantized tensor.
|
467 |
+
"""
|
468 |
+
if training:
|
469 |
+
if CLUSTER_CENTROIDS in weights:
|
470 |
+
if self.preserve_sparsity:
|
471 |
+
weights[ORIGINAL_WEIGHTS].assign(
|
472 |
+
tf.multiply(weights[ORIGINAL_WEIGHTS],
|
473 |
+
weights[SPARSITY_MASK]))
|
474 |
+
weights[CLUSTERING_IMPL].cluster_centroids.assign(
|
475 |
+
weights[CLUSTERING_IMPL].
|
476 |
+
cluster_centroids * weights[CENTROIDS_MASK]
|
477 |
+
)
|
478 |
+
weights[CLUSTER_CENTROIDS].assign(
|
479 |
+
weights[CLUSTERING_IMPL].cluster_centroids
|
480 |
+
)
|
481 |
+
# Insert clustering variables
|
482 |
+
weights[PULLING_INDICES].assign(tf.dtypes.cast(
|
483 |
+
weights[CLUSTERING_IMPL].get_pulling_indices(
|
484 |
+
weights[ORIGINAL_WEIGHTS]),
|
485 |
+
weights[PULLING_INDICES].dtype
|
486 |
+
))
|
487 |
+
|
488 |
+
output = weights[CLUSTERING_IMPL].get_clustered_weight(
|
489 |
+
weights[PULLING_INDICES], weights[ORIGINAL_WEIGHTS])
|
490 |
+
inputs.assign(output)
|
491 |
+
else:
|
492 |
+
if self.preserve_sparsity:
|
493 |
+
inputs = tf.multiply(inputs, weights[SPARSITY_MASK])
|
494 |
+
output = inputs
|
495 |
+
else:
|
496 |
+
output = inputs
|
497 |
+
|
498 |
+
return quant_ops.LastValueQuantize(
|
499 |
+
output,
|
500 |
+
weights['min_var'],
|
501 |
+
weights['max_var'],
|
502 |
+
is_training=training,
|
503 |
+
num_bits=self.num_bits,
|
504 |
+
per_channel=self.per_axis,
|
505 |
+
symmetric=self.symmetric,
|
506 |
+
narrow_range=self.narrow_range
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
class ClusterPreserveDefault8BitWeightsQuantizer(
|
511 |
+
ClusterPreserveDefaultWeightsQuantizer):
|
512 |
+
"""ClusterPreserveWeightsQuantizer for default 8bit weights."""
|
513 |
+
|
514 |
+
def __init__(self, preserve_sparsity):
|
515 |
+
super(ClusterPreserveDefault8BitWeightsQuantizer,
|
516 |
+
self).__init__(num_bits=8,
|
517 |
+
per_axis=False,
|
518 |
+
symmetric=True,
|
519 |
+
narrow_range=True,
|
520 |
+
preserve_sparsity=preserve_sparsity)
|
521 |
+
self.preserve_sparsity = preserve_sparsity
|
522 |
+
|
523 |
+
|
524 |
+
class ClusterPreserveDefault8BitConvWeightsQuantizer(
|
525 |
+
ClusterPreserveDefaultWeightsQuantizer,
|
526 |
+
default_8bit_quantizers.Default8BitConvWeightsQuantizer):
|
527 |
+
"""ClusterPreserveWeightsQuantizer for default 8bit Conv2D weights."""
|
528 |
+
|
529 |
+
def __init__(self, preserve_sparsity): # pylint: disable=super-init-not-called
|
530 |
+
default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
|
531 |
+
self.preserve_sparsity = preserve_sparsity
|
532 |
+
|
533 |
+
def build(self, tensor_shape, name, layer):
|
534 |
+
result = ClusterPreserveDefaultWeightsQuantizer._build_clusters(
|
535 |
+
self, name, layer)
|
536 |
+
result.update(
|
537 |
+
default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
|
538 |
+
self, tensor_shape, name, layer))
|
539 |
+
return result
|
cluster_preserve_quantize_registry_test.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the 'License');
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an 'AS IS' BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Tests for ClusterPreserveQuantizeRegistry."""
|
16 |
+
|
17 |
+
import tensorflow as tf
|
18 |
+
|
19 |
+
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
|
20 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
21 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_config
|
22 |
+
from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import cluster_preserve_quantize_registry
|
23 |
+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
|
24 |
+
|
25 |
+
|
26 |
+
QuantizeConfig = quantize_config.QuantizeConfig
|
27 |
+
layers = keras.layers
|
28 |
+
|
29 |
+
|
30 |
+
class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase):
|
31 |
+
|
32 |
+
def setUp(self):
|
33 |
+
super(ClusterPreserveQuantizeRegistryTest, self).setUp()
|
34 |
+
# Test CQAT by default
|
35 |
+
self.cluster_preserve_quantize_registry = (
|
36 |
+
cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
|
37 |
+
False)
|
38 |
+
)
|
39 |
+
# layers which are supported
|
40 |
+
# initial and build a Conv2D layer
|
41 |
+
self.layer_conv2d = layers.Conv2D(10, (2, 2))
|
42 |
+
self.layer_conv2d.build((2, 2))
|
43 |
+
# initial and build a Dense layer
|
44 |
+
self.layer_dense = layers.Dense(10)
|
45 |
+
self.layer_dense.build((2, 2))
|
46 |
+
# initial and build a ReLU layer
|
47 |
+
self.layer_relu = layers.ReLU()
|
48 |
+
self.layer_relu.build((2, 2))
|
49 |
+
|
50 |
+
# a layer which is not supported
|
51 |
+
# initial and build a Custom layer
|
52 |
+
self.layer_custom = self.CustomLayer()
|
53 |
+
self.layer_custom.build()
|
54 |
+
|
55 |
+
class CustomLayer(layers.Layer):
|
56 |
+
"""A simple custom layer with training weights."""
|
57 |
+
|
58 |
+
def build(self, input_shape=(2, 2)):
|
59 |
+
self.add_weight(shape=input_shape,
|
60 |
+
initializer='random_normal',
|
61 |
+
trainable=True)
|
62 |
+
|
63 |
+
class CustomQuantizeConfig(QuantizeConfig):
|
64 |
+
"""A dummy concrete class for testing unregistered configs."""
|
65 |
+
|
66 |
+
def get_weights_and_quantizers(self, layer):
|
67 |
+
return []
|
68 |
+
|
69 |
+
def get_activations_and_quantizers(self, layer):
|
70 |
+
return []
|
71 |
+
|
72 |
+
def set_quantize_weights(self, layer, quantize_weights):
|
73 |
+
pass
|
74 |
+
|
75 |
+
def set_quantize_activations(self, layer, quantize_activations):
|
76 |
+
pass
|
77 |
+
|
78 |
+
def get_output_quantizers(self, layer):
|
79 |
+
return []
|
80 |
+
|
81 |
+
def get_config(self):
|
82 |
+
return {}
|
83 |
+
|
84 |
+
def testSupportsKerasLayer(self):
|
85 |
+
# test registered layer
|
86 |
+
self.assertTrue(
|
87 |
+
self.cluster_preserve_quantize_registry.supports(self.layer_dense))
|
88 |
+
self.assertTrue(
|
89 |
+
self.cluster_preserve_quantize_registry.supports(self.layer_conv2d))
|
90 |
+
# test layer without training weights
|
91 |
+
self.assertTrue(
|
92 |
+
self.cluster_preserve_quantize_registry.supports(self.layer_relu))
|
93 |
+
|
94 |
+
def testDoesNotSupportCustomLayer(self):
|
95 |
+
self.assertFalse(
|
96 |
+
self.cluster_preserve_quantize_registry.supports(self.layer_custom))
|
97 |
+
|
98 |
+
def testApplyClusterPreserveWithQuantizeConfig(self):
|
99 |
+
(self.cluster_preserve_quantize_registry
|
100 |
+
.apply_cluster_preserve_quantize_config(
|
101 |
+
self.layer_conv2d,
|
102 |
+
default_8bit_quantize_registry.Default8BitConvQuantizeConfig(
|
103 |
+
['kernel'], ['activation'], False)))
|
104 |
+
|
105 |
+
def testRaisesErrorUnsupportedQuantizeConfigWithLayer(self):
|
106 |
+
with self.assertRaises(
|
107 |
+
ValueError, msg='Unregistered QuantizeConfigs should raise error.'):
|
108 |
+
(self.cluster_preserve_quantize_registry.
|
109 |
+
apply_cluster_preserve_quantize_config(
|
110 |
+
self.layer_conv2d, self.CustomQuantizeConfig))
|
111 |
+
|
112 |
+
with self.assertRaises(ValueError,
|
113 |
+
msg='Unregistered layers should raise error.'):
|
114 |
+
(self.cluster_preserve_quantize_registry.
|
115 |
+
apply_cluster_preserve_quantize_config(
|
116 |
+
self.layer_custom, self.CustomQuantizeConfig))
|
117 |
+
|
118 |
+
|
119 |
+
class ClusterPreserveDefault8bitQuantizeRegistryTest(tf.test.TestCase):
|
120 |
+
|
121 |
+
def setUp(self):
|
122 |
+
super(ClusterPreserveDefault8bitQuantizeRegistryTest, self).setUp()
|
123 |
+
self.default_8bit_quantize_registry = (
|
124 |
+
default_8bit_quantize_registry.Default8BitQuantizeRegistry())
|
125 |
+
self.cluster_registry = clustering_registry.ClusteringRegistry()
|
126 |
+
# Test CQAT by default
|
127 |
+
self.cluster_preserve_quantize_registry = (
|
128 |
+
cluster_preserve_quantize_registry.ClusterPreserveQuantizeRegistry(
|
129 |
+
False))
|
130 |
+
|
131 |
+
def testSupportsClusterDefault8bitQuantizeKerasLayers(self):
|
132 |
+
# ClusterPreserveQuantize supported layer, must be suppoted
|
133 |
+
# by both Cluster and Quantize
|
134 |
+
cqat_layers_config_map = (
|
135 |
+
self.cluster_preserve_quantize_registry._LAYERS_CONFIG_MAP)
|
136 |
+
for cqat_support_layer in cqat_layers_config_map:
|
137 |
+
if cqat_layers_config_map[cqat_support_layer].weight_attrs and (
|
138 |
+
cqat_layers_config_map[cqat_support_layer].quantize_config_attrs):
|
139 |
+
self.assertIn(
|
140 |
+
cqat_support_layer, self.cluster_registry._LAYERS_WEIGHTS_MAP,
|
141 |
+
msg='Clusteirng doesn\'t support {}'.format(cqat_support_layer))
|
142 |
+
self.assertIn(
|
143 |
+
cqat_support_layer,
|
144 |
+
self.default_8bit_quantize_registry._layer_quantize_map,
|
145 |
+
msg='Default 8bit QAT doesn\'t support {}'.format(
|
146 |
+
cqat_support_layer))
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == '__main__':
|
150 |
+
tf.test.main()
|
collaborative_optimization.png
ADDED
collaborative_optimization_dist.png
ADDED
cripto.jpg
ADDED
deep_crypto.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from predictors.btc_ltsm import BtcLtsm
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
parser = argparse.ArgumentParser(description='BTC Price Prediction')
|
6 |
+
parser.add_argument('--update', action='store_true', help='Update the dataset')
|
7 |
+
parser.add_argument('--train', action='store_true', help='Train the model')
|
8 |
+
parser.add_argument('--test', action='store_true', help='Test the model')
|
9 |
+
args = parser.parse_args()
|
10 |
+
|
11 |
+
btc_ltsm = BtcLtsm()
|
12 |
+
if args.update:
|
13 |
+
btc_ltsm.update_dataset()
|
14 |
+
if args.train:
|
15 |
+
btc_ltsm.train()
|
16 |
+
if args.test:
|
17 |
+
btc_ltsm.load()
|
18 |
+
btc_ltsm.test_model()
|
default_n_bit_transforms.py
ADDED
@@ -0,0 +1,825 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Default 8-bit transforms."""
|
16 |
+
|
17 |
+
import collections
|
18 |
+
import inspect
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import tensorflow as tf
|
22 |
+
|
23 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
24 |
+
from tensorflow_model_optimization.python.core.keras.compat import unique_object_name
|
25 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
|
26 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
|
27 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
|
28 |
+
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
|
29 |
+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_configs as configs
|
30 |
+
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
|
31 |
+
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
|
32 |
+
|
33 |
+
|
34 |
+
LayerNode = transforms.LayerNode
|
35 |
+
LayerPattern = transforms.LayerPattern
|
36 |
+
|
37 |
+
|
38 |
+
def _get_conv_bn_layers(bn_layer_node):
|
39 |
+
bn_layer = bn_layer_node.layer
|
40 |
+
conv_layer = bn_layer_node.input_layers[0].layer
|
41 |
+
return conv_layer, bn_layer
|
42 |
+
|
43 |
+
|
44 |
+
def _get_weights(bn_layer_node):
|
45 |
+
"""Returns weight values for fused layer, including copying original values in unfused version."""
|
46 |
+
|
47 |
+
return collections.OrderedDict(
|
48 |
+
list(bn_layer_node.input_layers[0].weights.items())
|
49 |
+
+ list(bn_layer_node.weights.items()))
|
50 |
+
|
51 |
+
|
52 |
+
def _get_params(conv_layer, bn_layer, relu_layer=None):
|
53 |
+
"""Retrieve conv_bn params within wrapped layers."""
|
54 |
+
if 'use_bias' in conv_layer['config']:
|
55 |
+
if conv_layer['config']['use_bias']:
|
56 |
+
raise ValueError(
|
57 |
+
'use_bias should not be set to True in a Conv layer when followed '
|
58 |
+
'by BatchNormalization. The bias in the Conv would be redundant '
|
59 |
+
'with the one in the BatchNormalization.')
|
60 |
+
|
61 |
+
del conv_layer['config']['use_bias']
|
62 |
+
|
63 |
+
if 'name' in bn_layer['config']:
|
64 |
+
del bn_layer['config']['name']
|
65 |
+
|
66 |
+
# TODO(pulkitb): remove key conflicts
|
67 |
+
params = dict(
|
68 |
+
list(conv_layer['config'].items()) + list(bn_layer['config'].items()))
|
69 |
+
|
70 |
+
if relu_layer is not None:
|
71 |
+
params['post_activation'] = quantize_utils.deserialize_layer(
|
72 |
+
relu_layer, use_legacy_format=True
|
73 |
+
)
|
74 |
+
|
75 |
+
return params
|
76 |
+
|
77 |
+
|
78 |
+
def _get_layer_node(fused_layer, weights):
|
79 |
+
layer_config = quantize_utils.serialize_layer(
|
80 |
+
fused_layer, use_legacy_format=True
|
81 |
+
)
|
82 |
+
layer_config['name'] = layer_config['config']['name']
|
83 |
+
# This config tracks which layers get quantized, and whether they have a
|
84 |
+
# custom QuantizeConfig.
|
85 |
+
layer_metadata = {'quantize_config': None}
|
86 |
+
|
87 |
+
return LayerNode(layer_config, weights, metadata=layer_metadata)
|
88 |
+
|
89 |
+
|
90 |
+
def _get_quantize_config(layer_node):
|
91 |
+
return layer_node.metadata.get('quantize_config')
|
92 |
+
|
93 |
+
|
94 |
+
def _has_custom_quantize_config(*layer_nodes):
|
95 |
+
for layer_node in layer_nodes:
|
96 |
+
if _get_quantize_config(layer_node) is not None:
|
97 |
+
return True
|
98 |
+
return False
|
99 |
+
|
100 |
+
|
101 |
+
def _normalize_tuple(value):
|
102 |
+
if isinstance(value, int):
|
103 |
+
return (value,)
|
104 |
+
else:
|
105 |
+
return tuple(value)
|
106 |
+
|
107 |
+
|
108 |
+
class Conv2DBatchNormQuantize(transforms.Transform):
|
109 |
+
"""Ensure FQ does not get placed between Conv and BatchNorm."""
|
110 |
+
|
111 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
112 |
+
self._num_bits_weight = num_bits_weight
|
113 |
+
self._num_bits_activation = num_bits_activation
|
114 |
+
|
115 |
+
def pattern(self):
|
116 |
+
return LayerPattern(
|
117 |
+
'BatchNormalization|SyncBatchNormalization',
|
118 |
+
inputs=[LayerPattern(
|
119 |
+
'Conv2D|DepthwiseConv2D', config={'activation': 'linear'})])
|
120 |
+
|
121 |
+
def _replace(self, bn_layer_node, conv_layer_node):
|
122 |
+
if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
|
123 |
+
return bn_layer_node
|
124 |
+
|
125 |
+
conv_layer_node.layer['config']['activation'] = (
|
126 |
+
quantize_utils.serialize_activation(
|
127 |
+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
|
128 |
+
)
|
129 |
+
)
|
130 |
+
bn_layer_node.metadata['quantize_config'] = (
|
131 |
+
configs.DefaultNBitOutputQuantizeConfig(
|
132 |
+
num_bits_weight=self._num_bits_weight,
|
133 |
+
num_bits_activation=self._num_bits_activation))
|
134 |
+
|
135 |
+
return bn_layer_node
|
136 |
+
|
137 |
+
def replacement(self, match_layer):
|
138 |
+
bn_layer_node = match_layer
|
139 |
+
conv_layer_node = match_layer.input_layers[0]
|
140 |
+
|
141 |
+
return self._replace(bn_layer_node, conv_layer_node)
|
142 |
+
|
143 |
+
def custom_objects(self):
|
144 |
+
return {
|
145 |
+
'NoOpQuantizeConfig':
|
146 |
+
configs.NoOpQuantizeConfig,
|
147 |
+
'NoOpActivation':
|
148 |
+
quantize_aware_activation.NoOpActivation
|
149 |
+
}
|
150 |
+
|
151 |
+
|
152 |
+
class Conv2DReshapeBatchNormQuantize(Conv2DBatchNormQuantize):
|
153 |
+
"""Ensure FQ does not get placed between Conv, Reshape and BatchNorm."""
|
154 |
+
|
155 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
156 |
+
super(Conv2DReshapeBatchNormQuantize, self).__init__(
|
157 |
+
num_bits_weight=num_bits_weight,
|
158 |
+
num_bits_activation=num_bits_activation)
|
159 |
+
self._num_bits_weight = num_bits_weight
|
160 |
+
self._num_bits_activation = num_bits_activation
|
161 |
+
|
162 |
+
def pattern(self):
|
163 |
+
return LayerPattern(
|
164 |
+
'BatchNormalization|SyncBatchNormalization',
|
165 |
+
inputs=[LayerPattern(
|
166 |
+
'Lambda', config={'name': 'sepconv1d_squeeze.*'},
|
167 |
+
inputs=[LayerPattern(
|
168 |
+
'Conv2D|DepthwiseConv2D',
|
169 |
+
config={'activation': 'linear'})])])
|
170 |
+
|
171 |
+
def replacement(self, match_layer):
|
172 |
+
bn_layer_node = match_layer
|
173 |
+
reshape_layer_node = bn_layer_node.input_layers[0]
|
174 |
+
conv_layer_node = reshape_layer_node.input_layers[0]
|
175 |
+
|
176 |
+
return self._replace(bn_layer_node, conv_layer_node)
|
177 |
+
|
178 |
+
|
179 |
+
class Conv2DBatchNormReLUQuantize(Conv2DBatchNormQuantize):
|
180 |
+
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
|
181 |
+
|
182 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
183 |
+
super(Conv2DBatchNormReLUQuantize, self).__init__(
|
184 |
+
num_bits_weight=num_bits_weight,
|
185 |
+
num_bits_activation=num_bits_activation)
|
186 |
+
self._num_bits_weight = num_bits_weight
|
187 |
+
self._num_bits_activation = num_bits_activation
|
188 |
+
|
189 |
+
def pattern(self):
|
190 |
+
return LayerPattern(
|
191 |
+
# TODO(pulkitb): Enhance match to only occur for relu, relu1 and relu6
|
192 |
+
'ReLU',
|
193 |
+
inputs=[super(Conv2DBatchNormReLUQuantize, self).pattern()])
|
194 |
+
|
195 |
+
def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
|
196 |
+
if _has_custom_quantize_config(
|
197 |
+
relu_layer_node, bn_layer_node, conv_layer_node):
|
198 |
+
return relu_layer_node
|
199 |
+
|
200 |
+
conv_layer_node.layer['config']['activation'] = (
|
201 |
+
quantize_utils.serialize_activation(
|
202 |
+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
|
203 |
+
)
|
204 |
+
)
|
205 |
+
bn_layer_node.metadata['quantize_config'] = (
|
206 |
+
configs.NoOpQuantizeConfig())
|
207 |
+
|
208 |
+
return relu_layer_node
|
209 |
+
|
210 |
+
def replacement(self, match_layer):
|
211 |
+
relu_layer_node = match_layer
|
212 |
+
bn_layer_node = relu_layer_node.input_layers[0]
|
213 |
+
conv_layer_node = bn_layer_node.input_layers[0]
|
214 |
+
|
215 |
+
return self._replace(relu_layer_node, bn_layer_node, conv_layer_node)
|
216 |
+
|
217 |
+
|
218 |
+
class Conv2DBatchNormActivationQuantize(Conv2DBatchNormReLUQuantize):
|
219 |
+
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
|
220 |
+
|
221 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
222 |
+
super(Conv2DBatchNormActivationQuantize, self).__init__(
|
223 |
+
num_bits_weight=num_bits_weight,
|
224 |
+
num_bits_activation=num_bits_activation)
|
225 |
+
self._num_bits_weight = num_bits_weight
|
226 |
+
self._num_bits_activation = num_bits_activation
|
227 |
+
|
228 |
+
def pattern(self):
|
229 |
+
return LayerPattern(
|
230 |
+
'Activation',
|
231 |
+
config={'activation': 'relu'},
|
232 |
+
inputs=[Conv2DBatchNormQuantize.pattern(self)])
|
233 |
+
|
234 |
+
|
235 |
+
class Conv2DReshapeBatchNormReLUQuantize(Conv2DBatchNormReLUQuantize):
|
236 |
+
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
|
237 |
+
|
238 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
239 |
+
super(Conv2DReshapeBatchNormReLUQuantize, self).__init__(
|
240 |
+
num_bits_weight=num_bits_weight,
|
241 |
+
num_bits_activation=num_bits_activation)
|
242 |
+
self._num_bits_weight = num_bits_weight
|
243 |
+
self._num_bits_activation = num_bits_activation
|
244 |
+
|
245 |
+
def pattern(self):
|
246 |
+
return LayerPattern(
|
247 |
+
'ReLU',
|
248 |
+
inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
|
249 |
+
|
250 |
+
def replacement(self, match_layer):
|
251 |
+
relu_layer_node = match_layer
|
252 |
+
bn_layer_node = relu_layer_node.input_layers[0]
|
253 |
+
squeeze_layer_node = bn_layer_node.input_layers[0]
|
254 |
+
conv_layer_node = squeeze_layer_node.input_layers[0]
|
255 |
+
|
256 |
+
return self._replace(relu_layer_node, bn_layer_node, conv_layer_node)
|
257 |
+
|
258 |
+
|
259 |
+
class Conv2DReshapeBatchNormActivationQuantize(
|
260 |
+
Conv2DReshapeBatchNormReLUQuantize):
|
261 |
+
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
|
262 |
+
|
263 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
264 |
+
super(Conv2DReshapeBatchNormActivationQuantize, self).__init__(
|
265 |
+
num_bits_weight=num_bits_weight,
|
266 |
+
num_bits_activation=num_bits_activation)
|
267 |
+
self._num_bits_weight = num_bits_weight
|
268 |
+
self._num_bits_activation = num_bits_activation
|
269 |
+
|
270 |
+
def pattern(self):
|
271 |
+
return LayerPattern(
|
272 |
+
'Activation',
|
273 |
+
config={'activation': 'relu'},
|
274 |
+
inputs=[Conv2DReshapeBatchNormQuantize.pattern(self)])
|
275 |
+
|
276 |
+
|
277 |
+
class DenseBatchNormQuantize(transforms.Transform):
|
278 |
+
"""Transform to be applied to "Dense"+ "BatchNorm" Graph.
|
279 |
+
|
280 |
+
This transform disables Quantization between Dense and BatchNorm
|
281 |
+
to ensure FQ does not get placed between them.
|
282 |
+
"""
|
283 |
+
|
284 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
285 |
+
self._num_bits_weight = num_bits_weight
|
286 |
+
self._num_bits_activation = num_bits_activation
|
287 |
+
|
288 |
+
def pattern(self):
|
289 |
+
return LayerPattern(
|
290 |
+
'BatchNormalization|SyncBatchNormalization',
|
291 |
+
inputs=[LayerPattern('Dense', config={'activation': 'linear'})])
|
292 |
+
|
293 |
+
def _replace(self, bn_layer_node, dense_layer_node):
|
294 |
+
if _has_custom_quantize_config(bn_layer_node, dense_layer_node):
|
295 |
+
return bn_layer_node
|
296 |
+
|
297 |
+
dense_layer_node.layer['config']['activation'] = (
|
298 |
+
quantize_utils.serialize_activation(
|
299 |
+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
|
300 |
+
)
|
301 |
+
)
|
302 |
+
bn_layer_node.metadata['quantize_config'] = (
|
303 |
+
configs.DefaultNBitOutputQuantizeConfig(
|
304 |
+
num_bits_weight=self._num_bits_weight,
|
305 |
+
num_bits_activation=self._num_bits_activation))
|
306 |
+
return bn_layer_node
|
307 |
+
|
308 |
+
def replacement(self, match_layer):
|
309 |
+
bn_layer_node = match_layer
|
310 |
+
dense_layer_node = match_layer.input_layers[0]
|
311 |
+
|
312 |
+
return self._replace(bn_layer_node, dense_layer_node)
|
313 |
+
|
314 |
+
def custom_objects(self):
|
315 |
+
return {
|
316 |
+
'DefaultNBitOutputQuantizeConfig':
|
317 |
+
configs.DefaultNBitOutputQuantizeConfig,
|
318 |
+
'NoOpQuantizeConfig':
|
319 |
+
configs.NoOpQuantizeConfig,
|
320 |
+
'NoOpActivation': quantize_aware_activation.NoOpActivation
|
321 |
+
}
|
322 |
+
|
323 |
+
|
324 |
+
class DenseBatchNormReLUQuantize(DenseBatchNormQuantize):
|
325 |
+
"""Transform to be applied to "Dense"+ "BatchNorm" + "ReLU" Graph.
|
326 |
+
|
327 |
+
This transform disables Quantization between Dense, BatchNorm and ReLU
|
328 |
+
to ensure FQ does not get placed between them.
|
329 |
+
"""
|
330 |
+
|
331 |
+
def pattern(self):
|
332 |
+
return LayerPattern(
|
333 |
+
'ReLU', inputs=[super(DenseBatchNormReLUQuantize, self).pattern()])
|
334 |
+
|
335 |
+
def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
|
336 |
+
if _has_custom_quantize_config(relu_layer_node, bn_layer_node,
|
337 |
+
dense_layer_node):
|
338 |
+
return relu_layer_node
|
339 |
+
|
340 |
+
dense_layer_node.layer['config']['activation'] = (
|
341 |
+
quantize_utils.serialize_activation(
|
342 |
+
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
|
343 |
+
)
|
344 |
+
)
|
345 |
+
bn_layer_node.metadata['quantize_config'] = (
|
346 |
+
configs.NoOpQuantizeConfig())
|
347 |
+
|
348 |
+
return relu_layer_node
|
349 |
+
|
350 |
+
def replacement(self, match_layer):
|
351 |
+
relu_layer_node = match_layer
|
352 |
+
bn_layer_node = relu_layer_node.input_layers[0]
|
353 |
+
dense_layer_node = bn_layer_node.input_layers[0]
|
354 |
+
|
355 |
+
return self._replace(relu_layer_node, bn_layer_node, dense_layer_node)
|
356 |
+
|
357 |
+
|
358 |
+
class DenseBatchNormActivationQuantize(DenseBatchNormReLUQuantize):
|
359 |
+
"""Transform to be applied to "Dense"+ "BatchNorm" + "ReLU" Graph.
|
360 |
+
|
361 |
+
This transform disables Quantization between Dense, BatchNorm and ReLU
|
362 |
+
to ensure FQ does not get placed between them.
|
363 |
+
"""
|
364 |
+
|
365 |
+
def pattern(self):
|
366 |
+
return LayerPattern(
|
367 |
+
'Activation',
|
368 |
+
config={'activation': 'relu'},
|
369 |
+
inputs=[DenseBatchNormQuantize.pattern(self)])
|
370 |
+
|
371 |
+
|
372 |
+
class SeparableConv1DQuantize(transforms.Transform):
|
373 |
+
"""Add QAT support for Keras SeparableConv1D layer.
|
374 |
+
|
375 |
+
Transforms SeparableConv1D into a SeparableConv2D invocation. The Keras
|
376 |
+
SeparableConv1D layer internally uses the same code as a SeparbaleConv2D
|
377 |
+
layer. It simple expands and squeezes the tensor dimensions before and after
|
378 |
+
the convolutions. Applying this transform ensures the QAT handling for
|
379 |
+
SeparableConv2D kicks in and handles the FQ placement properly.
|
380 |
+
|
381 |
+
Maps:
|
382 |
+
Input -> SeparableConv1D -> Output
|
383 |
+
to
|
384 |
+
Input -> Lambda(ExpandDims) -> SeparableConv2D -> Lambda(Squeeze) -> Output
|
385 |
+
|
386 |
+
Unlike SeparableConv2DQuantize, this does not break the layer into
|
387 |
+
DepthwiseConv and Conv separately, since no DepthwiseConv1D exists.
|
388 |
+
"""
|
389 |
+
|
390 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
391 |
+
self._num_bits_weight = num_bits_weight
|
392 |
+
self._num_bits_activation = num_bits_activation
|
393 |
+
|
394 |
+
def pattern(self):
|
395 |
+
return LayerPattern('SeparableConv1D')
|
396 |
+
|
397 |
+
def _get_name(self, prefix):
|
398 |
+
# TODO(pulkitb): Move away from `unique_object_name` since it isn't
|
399 |
+
# exposed as externally usable.
|
400 |
+
return unique_object_name(prefix)
|
401 |
+
|
402 |
+
def replacement(self, match_layer):
|
403 |
+
if _has_custom_quantize_config(match_layer):
|
404 |
+
return match_layer
|
405 |
+
|
406 |
+
sepconv1d_layer = match_layer.layer
|
407 |
+
sepconv1d_config = sepconv1d_layer['config']
|
408 |
+
sepconv1d_weights = list(match_layer.weights.values())
|
409 |
+
|
410 |
+
padding = sepconv1d_config['padding']
|
411 |
+
# SepConv2D does not accept causal padding, and SepConv1D has some special
|
412 |
+
# handling for it.
|
413 |
+
# TODO(pulkitb): Add support for causal padding.
|
414 |
+
if padding == 'causal':
|
415 |
+
raise ValueError('SeparableConv1D with causal padding is not supported.')
|
416 |
+
|
417 |
+
# TODO(pulkitb): Handle other base_layer args such as dtype, input_dim etc.
|
418 |
+
|
419 |
+
sepconv2d_layer = keras.layers.SeparableConv2D(
|
420 |
+
filters=sepconv1d_config['filters'],
|
421 |
+
kernel_size=(1,) + _normalize_tuple(sepconv1d_config['kernel_size']),
|
422 |
+
strides=_normalize_tuple(sepconv1d_config['strides']) * 2,
|
423 |
+
padding=padding,
|
424 |
+
data_format=sepconv1d_config['data_format'],
|
425 |
+
dilation_rate=(1,)
|
426 |
+
+ _normalize_tuple(sepconv1d_config['dilation_rate']),
|
427 |
+
depth_multiplier=sepconv1d_config['depth_multiplier'],
|
428 |
+
activation=sepconv1d_config['activation'],
|
429 |
+
use_bias=sepconv1d_config['use_bias'],
|
430 |
+
depthwise_initializer=sepconv1d_config['depthwise_initializer'],
|
431 |
+
pointwise_initializer=sepconv1d_config['pointwise_initializer'],
|
432 |
+
bias_initializer=sepconv1d_config['bias_initializer'],
|
433 |
+
depthwise_regularizer=sepconv1d_config['depthwise_regularizer'],
|
434 |
+
pointwise_regularizer=sepconv1d_config['pointwise_regularizer'],
|
435 |
+
bias_regularizer=sepconv1d_config['bias_regularizer'],
|
436 |
+
activity_regularizer=sepconv1d_config['activity_regularizer'],
|
437 |
+
depthwise_constraint=sepconv1d_config['depthwise_constraint'],
|
438 |
+
pointwise_constraint=sepconv1d_config['pointwise_constraint'],
|
439 |
+
bias_constraint=sepconv1d_config['bias_constraint'],
|
440 |
+
# TODO(pulkitb): Rethink what to do for name. Using the same name leads
|
441 |
+
# to confusion, since it's typically separable_conv1d
|
442 |
+
name=sepconv1d_config['name'] + '_QAT_SepConv2D',
|
443 |
+
trainable=sepconv1d_config['trainable'],
|
444 |
+
)
|
445 |
+
|
446 |
+
sepconv2d_weights = collections.OrderedDict()
|
447 |
+
sepconv2d_weights['depthwise_kernel:0'] = np.expand_dims(
|
448 |
+
sepconv1d_weights[0], 0)
|
449 |
+
sepconv2d_weights['pointwise_kernel:0'] = np.expand_dims(
|
450 |
+
sepconv1d_weights[1], 0)
|
451 |
+
if sepconv1d_config['use_bias']:
|
452 |
+
sepconv2d_weights['bias:0'] = sepconv1d_weights[2]
|
453 |
+
|
454 |
+
if sepconv1d_config['data_format'] == 'channels_last':
|
455 |
+
spatial_dim = 1
|
456 |
+
else:
|
457 |
+
spatial_dim = 2
|
458 |
+
|
459 |
+
sepconv2d_layer_config = quantize_utils.serialize_layer(
|
460 |
+
sepconv2d_layer, use_legacy_format=True
|
461 |
+
)
|
462 |
+
sepconv2d_layer_config['name'] = sepconv2d_layer.name
|
463 |
+
|
464 |
+
# Needed to ensure these new layers are considered for quantization.
|
465 |
+
sepconv2d_metadata = {'quantize_config': None}
|
466 |
+
|
467 |
+
# TODO(pulkitb): Consider moving from Lambda to custom ExpandDims/Squeeze.
|
468 |
+
|
469 |
+
# Layer before SeparableConv2D which expands input tensors to match 2D.
|
470 |
+
expand_layer = keras.layers.Lambda(
|
471 |
+
lambda x: tf.expand_dims(x, spatial_dim),
|
472 |
+
name=self._get_name('sepconv1d_expand'),
|
473 |
+
)
|
474 |
+
expand_layer_config = quantize_utils.serialize_layer(
|
475 |
+
expand_layer, use_legacy_format=True
|
476 |
+
)
|
477 |
+
expand_layer_config['name'] = expand_layer.name
|
478 |
+
expand_layer_metadata = {
|
479 |
+
'quantize_config':
|
480 |
+
configs.NoOpQuantizeConfig()}
|
481 |
+
|
482 |
+
squeeze_layer = keras.layers.Lambda(
|
483 |
+
lambda x: tf.squeeze(x, [spatial_dim]),
|
484 |
+
name=self._get_name('sepconv1d_squeeze'),
|
485 |
+
)
|
486 |
+
squeeze_layer_config = quantize_utils.serialize_layer(
|
487 |
+
squeeze_layer, use_legacy_format=True
|
488 |
+
)
|
489 |
+
squeeze_layer_config['name'] = squeeze_layer.name
|
490 |
+
squeeze_layer_metadata = {
|
491 |
+
'quantize_config':
|
492 |
+
configs.NoOpQuantizeConfig()}
|
493 |
+
|
494 |
+
return LayerNode(
|
495 |
+
squeeze_layer_config,
|
496 |
+
metadata=squeeze_layer_metadata,
|
497 |
+
input_layers=[LayerNode(
|
498 |
+
sepconv2d_layer_config,
|
499 |
+
weights=sepconv2d_weights,
|
500 |
+
metadata=sepconv2d_metadata,
|
501 |
+
input_layers=[LayerNode(
|
502 |
+
expand_layer_config, metadata=expand_layer_metadata)]
|
503 |
+
)])
|
504 |
+
|
505 |
+
|
506 |
+
class SeparableConvQuantize(transforms.Transform):
|
507 |
+
"""Break SeparableConv into a DepthwiseConv and Conv layer.
|
508 |
+
|
509 |
+
SeparableConv is a composition of a DepthwiseConv and a Conv layer. For the
|
510 |
+
purpose of quantization, a FQ operation needs to be placed between the output
|
511 |
+
of DepthwiseConv and the following Conv.
|
512 |
+
|
513 |
+
This is needed since there is a dynamic tensor in between the two layers, and
|
514 |
+
it's range information needs to be captured by the FakeQuant op to ensure
|
515 |
+
full int8 quantization of the layers is possible.
|
516 |
+
|
517 |
+
Splitting the layer into 2 ensures that each individual layer is handled
|
518 |
+
correctly with respect to quantization.
|
519 |
+
"""
|
520 |
+
|
521 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
522 |
+
self._num_bits_weight = num_bits_weight
|
523 |
+
self._num_bits_activation = num_bits_activation
|
524 |
+
|
525 |
+
def pattern(self):
|
526 |
+
return LayerPattern('SeparableConv2D')
|
527 |
+
|
528 |
+
def replacement(self, match_layer):
|
529 |
+
if _has_custom_quantize_config(match_layer):
|
530 |
+
return match_layer
|
531 |
+
|
532 |
+
sepconv_layer = match_layer.layer
|
533 |
+
sepconv_weights = list(match_layer.weights.values())
|
534 |
+
|
535 |
+
# TODO(pulkitb): SeparableConv has kwargs other than constructor args which
|
536 |
+
# need to be handled.
|
537 |
+
# Applicable to both layers: trainable, dtype, name
|
538 |
+
# Applicable to dconv: input_dim, input_shape, batch_input_shape, batch_size
|
539 |
+
# Needs special handling: weights
|
540 |
+
# Unknown: dynamic, autocast
|
541 |
+
|
542 |
+
dconv_layer = keras.layers.DepthwiseConv2D(
|
543 |
+
kernel_size=sepconv_layer['config']['kernel_size'],
|
544 |
+
strides=sepconv_layer['config']['strides'],
|
545 |
+
padding=sepconv_layer['config']['padding'],
|
546 |
+
depth_multiplier=sepconv_layer['config']['depth_multiplier'],
|
547 |
+
data_format=sepconv_layer['config']['data_format'],
|
548 |
+
dilation_rate=sepconv_layer['config']['dilation_rate'],
|
549 |
+
activation=None,
|
550 |
+
use_bias=False,
|
551 |
+
depthwise_initializer=sepconv_layer['config']['depthwise_initializer'],
|
552 |
+
depthwise_regularizer=sepconv_layer['config']['depthwise_regularizer'],
|
553 |
+
depthwise_constraint=sepconv_layer['config']['depthwise_constraint'],
|
554 |
+
trainable=sepconv_layer['config']['trainable'],
|
555 |
+
)
|
556 |
+
dconv_weights = collections.OrderedDict()
|
557 |
+
dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
|
558 |
+
dconv_layer_config = quantize_utils.serialize_layer(
|
559 |
+
dconv_layer, use_legacy_format=True
|
560 |
+
)
|
561 |
+
dconv_layer_config['name'] = dconv_layer.name
|
562 |
+
# Needed to ensure these new layers are considered for quantization.
|
563 |
+
dconv_metadata = {'quantize_config': None}
|
564 |
+
|
565 |
+
conv_layer = keras.layers.Conv2D(
|
566 |
+
filters=sepconv_layer['config']['filters'],
|
567 |
+
kernel_size=(1, 1), # (1,) * rank
|
568 |
+
strides=(1, 1),
|
569 |
+
padding='valid',
|
570 |
+
data_format=sepconv_layer['config']['data_format'],
|
571 |
+
dilation_rate=sepconv_layer['config']['dilation_rate'],
|
572 |
+
groups=1,
|
573 |
+
activation=sepconv_layer['config']['activation'],
|
574 |
+
use_bias=sepconv_layer['config']['use_bias'],
|
575 |
+
kernel_initializer=sepconv_layer['config']['pointwise_initializer'],
|
576 |
+
bias_initializer=sepconv_layer['config']['bias_initializer'],
|
577 |
+
kernel_regularizer=sepconv_layer['config']['pointwise_regularizer'],
|
578 |
+
bias_regularizer=sepconv_layer['config']['bias_regularizer'],
|
579 |
+
activity_regularizer=sepconv_layer['config']['activity_regularizer'],
|
580 |
+
kernel_constraint=sepconv_layer['config']['pointwise_constraint'],
|
581 |
+
bias_constraint=sepconv_layer['config']['bias_constraint'],
|
582 |
+
trainable=sepconv_layer['config']['trainable'],
|
583 |
+
)
|
584 |
+
conv_weights = collections.OrderedDict()
|
585 |
+
conv_weights['kernel:0'] = sepconv_weights[1]
|
586 |
+
if sepconv_layer['config']['use_bias']:
|
587 |
+
conv_weights['bias:0'] = sepconv_weights[2]
|
588 |
+
conv_layer_config = quantize_utils.serialize_layer(
|
589 |
+
conv_layer, use_legacy_format=True
|
590 |
+
)
|
591 |
+
conv_layer_config['name'] = conv_layer.name
|
592 |
+
# Needed to ensure these new layers are considered for quantization.
|
593 |
+
conv_metadata = {'quantize_config': None}
|
594 |
+
|
595 |
+
dconv_layer_node = LayerNode(
|
596 |
+
dconv_layer_config, weights=dconv_weights, metadata=dconv_metadata)
|
597 |
+
return LayerNode(
|
598 |
+
conv_layer_config,
|
599 |
+
weights=conv_weights,
|
600 |
+
input_layers=[dconv_layer_node],
|
601 |
+
metadata=conv_metadata)
|
602 |
+
|
603 |
+
|
604 |
+
class LayerReLUQuantize(transforms.Transform):
|
605 |
+
"""Ensure FQ does not get placed between Add and ReLU."""
|
606 |
+
|
607 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
608 |
+
self._num_bits_weight = num_bits_weight
|
609 |
+
self._num_bits_activation = num_bits_activation
|
610 |
+
|
611 |
+
def pattern(self):
|
612 |
+
return LayerPattern(
|
613 |
+
'ReLU', inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
|
614 |
+
|
615 |
+
def replacement(self, match_layer):
|
616 |
+
relu_layer_node = match_layer
|
617 |
+
add_layer_node = relu_layer_node.input_layers[0]
|
618 |
+
|
619 |
+
add_layer_node.metadata['quantize_config'] = (
|
620 |
+
configs.NoOpQuantizeConfig())
|
621 |
+
|
622 |
+
return match_layer
|
623 |
+
|
624 |
+
def custom_objects(self):
|
625 |
+
return {
|
626 |
+
'NoOpQuantizeConfig':
|
627 |
+
configs.NoOpQuantizeConfig,
|
628 |
+
}
|
629 |
+
|
630 |
+
|
631 |
+
class LayerReluActivationQuantize(LayerReLUQuantize):
|
632 |
+
"""Ensure FQ does not get placed between Add and ReLU."""
|
633 |
+
|
634 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
635 |
+
super(LayerReluActivationQuantize, self).__init__(
|
636 |
+
num_bits_weight=num_bits_weight,
|
637 |
+
num_bits_activation=num_bits_activation)
|
638 |
+
self._num_bits_weight = num_bits_weight
|
639 |
+
self._num_bits_activation = num_bits_activation
|
640 |
+
|
641 |
+
def pattern(self):
|
642 |
+
return LayerPattern(
|
643 |
+
'Activation',
|
644 |
+
config={'activation': 'relu'},
|
645 |
+
inputs=[LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')])
|
646 |
+
|
647 |
+
|
648 |
+
class InputLayerQuantize(transforms.Transform):
|
649 |
+
"""Quantizes InputLayer, by adding QuantizeLayer after it.
|
650 |
+
|
651 |
+
InputLayer => InputLayer -> QuantizeLayer
|
652 |
+
"""
|
653 |
+
|
654 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
655 |
+
self._num_bits_weight = num_bits_weight
|
656 |
+
self._num_bits_activation = num_bits_activation
|
657 |
+
|
658 |
+
def pattern(self):
|
659 |
+
return LayerPattern('InputLayer')
|
660 |
+
|
661 |
+
def replacement(self, match_layer):
|
662 |
+
quant_layer = quantize_layer.QuantizeLayer(
|
663 |
+
quantizers.AllValuesQuantizer(
|
664 |
+
num_bits=self._num_bits_activation, per_axis=False,
|
665 |
+
symmetric=False, narrow_range=False)) # activation/output
|
666 |
+
layer_config = quantize_utils.serialize_layer(
|
667 |
+
quant_layer, use_legacy_format=True
|
668 |
+
)
|
669 |
+
layer_config['name'] = quant_layer.name
|
670 |
+
|
671 |
+
quant_layer_node = LayerNode(
|
672 |
+
layer_config,
|
673 |
+
input_layers=[match_layer])
|
674 |
+
|
675 |
+
return quant_layer_node
|
676 |
+
|
677 |
+
def custom_objects(self):
|
678 |
+
return {
|
679 |
+
'QuantizeLayer': quantize_layer.QuantizeLayer,
|
680 |
+
'MovingAverageQuantizer': quantizers.MovingAverageQuantizer,
|
681 |
+
'AllValuesQuantizer': quantizers.AllValuesQuantizer
|
682 |
+
}
|
683 |
+
|
684 |
+
|
685 |
+
class ConcatTransform(transforms.Transform):
|
686 |
+
"""Transform for Concatenate. Quantize only after concatenation."""
|
687 |
+
|
688 |
+
# pylint:disable=protected-access
|
689 |
+
|
690 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
691 |
+
self._num_bits_weight = num_bits_weight
|
692 |
+
self._num_bits_activation = num_bits_activation
|
693 |
+
|
694 |
+
def pattern(self):
|
695 |
+
# TODO(pulkitb): Write a clean way to handle length patterns.
|
696 |
+
return LayerPattern(
|
697 |
+
'Concatenate', inputs=[LayerPattern('.*'), LayerPattern('.*')])
|
698 |
+
|
699 |
+
def _get_layer_type(self, layer_class_name):
|
700 |
+
keras_layers = inspect.getmembers(keras.layers, inspect.isclass)
|
701 |
+
for layer_name, layer_type in keras_layers:
|
702 |
+
if layer_name == layer_class_name:
|
703 |
+
return layer_type
|
704 |
+
return None
|
705 |
+
|
706 |
+
def _disable_output_quantize(self, quantize_config):
|
707 |
+
# TODO(pulkitb): Disabling quantize_config may also require handling
|
708 |
+
# activation quantizers. Handle that properly.
|
709 |
+
quantize_config.get_output_quantizers = lambda layer: []
|
710 |
+
|
711 |
+
def replacement(self, match_layer):
|
712 |
+
concat_layer_node = match_layer
|
713 |
+
feeding_layer_nodes = match_layer.input_layers
|
714 |
+
|
715 |
+
default_registry = (
|
716 |
+
default_n_bit_quantize_registry.DefaultNBitQuantizeRegistry(
|
717 |
+
num_bits_weight=self._num_bits_weight,
|
718 |
+
num_bits_activation=self._num_bits_activation))
|
719 |
+
|
720 |
+
feed_quantize_configs = []
|
721 |
+
for feed_layer_node in feeding_layer_nodes:
|
722 |
+
quantize_config = feed_layer_node.metadata.get('quantize_config')
|
723 |
+
if not quantize_config:
|
724 |
+
layer_class = self._get_layer_type(feed_layer_node.layer['class_name'])
|
725 |
+
if layer_class is None:
|
726 |
+
# Concat has an input layer we don't recognize. Return.
|
727 |
+
return match_layer
|
728 |
+
|
729 |
+
if layer_class == keras.layers.Concatenate:
|
730 |
+
# Input layer to Concat is also Concat. Don't quantize it.
|
731 |
+
feed_layer_node.metadata['quantize_config'] = (
|
732 |
+
configs.NoOpQuantizeConfig())
|
733 |
+
continue
|
734 |
+
|
735 |
+
if not default_registry._is_supported_layer(layer_class):
|
736 |
+
# Feeding layer is not supported by registry
|
737 |
+
return match_layer
|
738 |
+
|
739 |
+
quantize_config = default_registry._get_quantize_config(layer_class)
|
740 |
+
feed_layer_node.metadata['quantize_config'] = quantize_config
|
741 |
+
|
742 |
+
feed_quantize_configs.append(quantize_config)
|
743 |
+
|
744 |
+
# TODO(pulkitb): this currently only disables output quantize config, but
|
745 |
+
# cannot properly handle if the FQ was added to the activation. Hand this
|
746 |
+
# properly.
|
747 |
+
for quantize_config in feed_quantize_configs:
|
748 |
+
self._disable_output_quantize(quantize_config)
|
749 |
+
|
750 |
+
if not concat_layer_node.metadata.get('quantize_config'):
|
751 |
+
concat_layer_node.metadata['quantize_config'] = (
|
752 |
+
configs.DefaultNBitOutputQuantizeConfig(
|
753 |
+
num_bits_weight=self._num_bits_weight,
|
754 |
+
num_bits_activation=self._num_bits_activation))
|
755 |
+
|
756 |
+
return concat_layer_node
|
757 |
+
|
758 |
+
# pylint:enable=protected-access
|
759 |
+
|
760 |
+
|
761 |
+
class ConcatTransform3Inputs(ConcatTransform):
|
762 |
+
"""Transform for 3 inputs Concatenate."""
|
763 |
+
|
764 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
765 |
+
super(ConcatTransform3Inputs, self).__init__(
|
766 |
+
num_bits_weight=num_bits_weight,
|
767 |
+
num_bits_activation=num_bits_activation)
|
768 |
+
self._num_bits_weight = num_bits_weight
|
769 |
+
self._num_bits_activation = num_bits_activation
|
770 |
+
|
771 |
+
def pattern(self):
|
772 |
+
return LayerPattern(
|
773 |
+
'Concatenate',
|
774 |
+
inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*')])
|
775 |
+
|
776 |
+
|
777 |
+
class ConcatTransform4Inputs(ConcatTransform):
|
778 |
+
"""Transform for 4 inputs Concatenate."""
|
779 |
+
|
780 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
781 |
+
super(ConcatTransform4Inputs, self).__init__(
|
782 |
+
num_bits_weight=num_bits_weight,
|
783 |
+
num_bits_activation=num_bits_activation)
|
784 |
+
self._num_bits_weight = num_bits_weight
|
785 |
+
self._num_bits_activation = num_bits_activation
|
786 |
+
|
787 |
+
def pattern(self):
|
788 |
+
return LayerPattern(
|
789 |
+
'Concatenate',
|
790 |
+
inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*'),
|
791 |
+
LayerPattern('.*')])
|
792 |
+
|
793 |
+
|
794 |
+
class ConcatTransform5Inputs(ConcatTransform):
|
795 |
+
"""Transform for 5 inputs Concatenate."""
|
796 |
+
|
797 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
798 |
+
super(ConcatTransform5Inputs, self).__init__(
|
799 |
+
num_bits_weight=num_bits_weight,
|
800 |
+
num_bits_activation=num_bits_activation)
|
801 |
+
self._num_bits_weight = num_bits_weight
|
802 |
+
self._num_bits_activation = num_bits_activation
|
803 |
+
|
804 |
+
def pattern(self):
|
805 |
+
return LayerPattern(
|
806 |
+
'Concatenate',
|
807 |
+
inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*'),
|
808 |
+
LayerPattern('.*'), LayerPattern('.*')])
|
809 |
+
|
810 |
+
|
811 |
+
class ConcatTransform6Inputs(ConcatTransform):
|
812 |
+
"""Transform for 6 inputs Concatenate."""
|
813 |
+
|
814 |
+
def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
|
815 |
+
super(ConcatTransform6Inputs, self).__init__(
|
816 |
+
num_bits_weight=num_bits_weight,
|
817 |
+
num_bits_activation=num_bits_activation)
|
818 |
+
self._num_bits_weight = num_bits_weight
|
819 |
+
self._num_bits_activation = num_bits_activation
|
820 |
+
|
821 |
+
def pattern(self):
|
822 |
+
return LayerPattern(
|
823 |
+
'Concatenate',
|
824 |
+
inputs=[LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*'),
|
825 |
+
LayerPattern('.*'), LayerPattern('.*'), LayerPattern('.*')])
|
main.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python main. py
|
2 |
+
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
|
3 |
+
model. safetensors.index.json: 100%|
|
4 |
+
| 13.5k/13.5k [00:00‹?, PB/s]
|
5 |
+
model-00001-of-00002. safetensors: 100%
|
6 |
+
| 4.95G/4.95G [07:27<00:00, 11. 1MB/s]
|
7 |
+
model-00002-of-00002. safetensors: 100%
|
8 |
+
67. 1M/67.1M [00:05<00:00, 11.5MB/s]
|
9 |
+
Downloading shards: 100% ||
|
10 |
+
| 2/2 [07:35‹00:00, 227.61s/it]
|
11 |
+
Gemma's activation function should be approximate GeLU and not exact GeLU. Changing the activation function to 'gelu_pytorch_tanh.if you want to use the legacy "gelu', edit the "model.config to
|
12 |
+
set hidden_activation=gelu*
|
13 |
+
instead of todden act
|
14 |
+
instead of hidden_act. See https://github.com/huggingface/transformers/pull/29402 for
|
15 |
+
more details.
|
16 |
+
Loading checkpoint shards: 100%|
|
17 |
+
| 2/2 [00:03<00:00, 1.87s/itl
|
18 |
+
generation_config json: 100%||
|
19 |
+
137/137[00:00<?」3B/s]
|
20 |
+
nexa model result:
|
21 |
+
a pouto using the specified caea and resolutiou stones iption: rame rs a photo (cama a):)
|
22 |
+
Captures
|
23 |
+
- camera (str): Specifies the camera
|
24 |
+
to use. Can be \'front\' or \'back\'. The default is \'back\'. \n\n
|
25 |
+
Returns: \n
|
26 |
+
- str: The string contains the file
|
27 |
+
2624 t 12 4a.
|
28 |
+
Photo if nees at ay 96 83662387968t, ample: /storage/emulated/o/Pictures/NAPP/3N
|
29 |
+
123456.Jpg\'\n latency: 367.85967230796814
|
misc.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Misc."""
|
15 |
+
|
16 |
+
from __future__ import absolute_import
|
17 |
+
from __future__ import division
|
18 |
+
from __future__ import print_function
|
19 |
+
|
20 |
+
import collections
|
21 |
+
import tensorflow as tf
|
22 |
+
|
23 |
+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
|
24 |
+
|
25 |
+
|
26 |
+
@encoding_stage.tf_style_encoding_stage
|
27 |
+
class SplitBySmallValueEncodingStage(encoding_stage.EncodingStageInterface):
|
28 |
+
"""Encoding stage splitting the input by small values.
|
29 |
+
|
30 |
+
This encoding stage will split the input into two outputs: the value and the
|
31 |
+
indices of the elements whose absolute value is larger than a certain
|
32 |
+
threshold. The elements smaller than the threshold is then decoded to zero.
|
33 |
+
"""
|
34 |
+
|
35 |
+
ENCODED_INDICES_KEY = 'indices'
|
36 |
+
ENCODED_VALUES_KEY = 'non_zero_floats'
|
37 |
+
THRESHOLD_PARAMS_KEY = 'threshold'
|
38 |
+
|
39 |
+
def __init__(self, threshold=1e-8):
|
40 |
+
"""Initializer for the SplitBySmallValueEncodingStage.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
threshold: The threshold of the small weights to be set to zero.
|
44 |
+
"""
|
45 |
+
self._threshold = threshold
|
46 |
+
|
47 |
+
@property
|
48 |
+
def name(self):
|
49 |
+
"""See base class."""
|
50 |
+
return 'split_by_small_value'
|
51 |
+
|
52 |
+
@property
|
53 |
+
def compressible_tensors_keys(self):
|
54 |
+
"""See base class."""
|
55 |
+
return [
|
56 |
+
self.ENCODED_VALUES_KEY,
|
57 |
+
self.ENCODED_INDICES_KEY,
|
58 |
+
]
|
59 |
+
|
60 |
+
@property
|
61 |
+
def commutes_with_sum(self):
|
62 |
+
"""See base class."""
|
63 |
+
return False
|
64 |
+
|
65 |
+
@property
|
66 |
+
def decode_needs_input_shape(self):
|
67 |
+
"""See base class."""
|
68 |
+
return True
|
69 |
+
|
70 |
+
def get_params(self):
|
71 |
+
"""See base class."""
|
72 |
+
encode_params = collections.OrderedDict([(self.THRESHOLD_PARAMS_KEY,
|
73 |
+
self._threshold)])
|
74 |
+
decode_params = collections.OrderedDict()
|
75 |
+
return encode_params, decode_params
|
76 |
+
|
77 |
+
def encode(self, x, encode_params):
|
78 |
+
"""See base class."""
|
79 |
+
|
80 |
+
threshold = tf.cast(encode_params[self.THRESHOLD_PARAMS_KEY], x.dtype)
|
81 |
+
indices = tf.cast(tf.compat.v2.where(tf.abs(x) > threshold), tf.int32)
|
82 |
+
non_zero_x = tf.gather_nd(x, indices)
|
83 |
+
indices = tf.squeeze(indices, axis=1)
|
84 |
+
return collections.OrderedDict([
|
85 |
+
(self.ENCODED_INDICES_KEY, indices),
|
86 |
+
(self.ENCODED_VALUES_KEY, non_zero_x),
|
87 |
+
])
|
88 |
+
|
89 |
+
def decode(self,
|
90 |
+
encoded_tensors,
|
91 |
+
decode_params,
|
92 |
+
num_summands=None,
|
93 |
+
shape=None):
|
94 |
+
"""See base class."""
|
95 |
+
del decode_params, num_summands # Unused.
|
96 |
+
|
97 |
+
indices = encoded_tensors[self.ENCODED_INDICES_KEY]
|
98 |
+
non_zero_x = encoded_tensors[self.ENCODED_VALUES_KEY]
|
99 |
+
|
100 |
+
indices = tf.expand_dims(indices, 1)
|
101 |
+
|
102 |
+
indices = tf.cast(indices, tf.int64)
|
103 |
+
shape = tf.cast(shape, tf.int64)
|
104 |
+
sparse_tensor = tf.SparseTensor(indices=indices, values=non_zero_x,
|
105 |
+
dense_shape=shape)
|
106 |
+
decoded_x = tf.sparse.to_dense(sparse_tensor)
|
107 |
+
|
108 |
+
return decoded_x
|
109 |
+
|
110 |
+
|
111 |
+
@encoding_stage.tf_style_encoding_stage
|
112 |
+
class DifferenceBetweenIntegersEncodingStage(
|
113 |
+
encoding_stage.EncodingStageInterface):
|
114 |
+
"""Encoding stage taking the difference between a sequence of integers.
|
115 |
+
|
116 |
+
This encoding stage can be useful when the original integers can be large, but
|
117 |
+
the difference of the integers are much smaller values and have a more compact
|
118 |
+
representation. For example, it can be combined with the
|
119 |
+
`SplitBySmallValueEncodingStage` to further compress the increasing sequence
|
120 |
+
of indices.
|
121 |
+
|
122 |
+
The encode method expects a tensor with 1 dimension and with integer dtype.
|
123 |
+
"""
|
124 |
+
|
125 |
+
ENCODED_VALUES_KEY = 'difference_between_integers'
|
126 |
+
|
127 |
+
@property
|
128 |
+
def name(self):
|
129 |
+
"""See base class."""
|
130 |
+
return 'difference_between_integers'
|
131 |
+
|
132 |
+
@property
|
133 |
+
def compressible_tensors_keys(self):
|
134 |
+
"""See base class."""
|
135 |
+
return [
|
136 |
+
self.ENCODED_VALUES_KEY,
|
137 |
+
]
|
138 |
+
|
139 |
+
@property
|
140 |
+
def commutes_with_sum(self):
|
141 |
+
"""See base class."""
|
142 |
+
return False
|
143 |
+
|
144 |
+
@property
|
145 |
+
def decode_needs_input_shape(self):
|
146 |
+
"""See base class."""
|
147 |
+
return False
|
148 |
+
|
149 |
+
def get_params(self):
|
150 |
+
"""See base class."""
|
151 |
+
return collections.OrderedDict(), collections.OrderedDict()
|
152 |
+
|
153 |
+
def encode(self, x, encode_params):
|
154 |
+
"""See base class."""
|
155 |
+
del encode_params # Unused.
|
156 |
+
if x.shape.ndims != 1:
|
157 |
+
raise ValueError('Number of dimensions must be 1. Shape of x: %s' %
|
158 |
+
x.shape)
|
159 |
+
if not x.dtype.is_integer:
|
160 |
+
raise TypeError(
|
161 |
+
'Unsupported input type: %s. Support only integer types.' % x.dtype)
|
162 |
+
|
163 |
+
diff_x = x - tf.concat([[0], x[:-1]], 0)
|
164 |
+
return collections.OrderedDict([(self.ENCODED_VALUES_KEY, diff_x)])
|
165 |
+
|
166 |
+
def decode(self,
|
167 |
+
encoded_tensors,
|
168 |
+
decode_params,
|
169 |
+
num_summands=None,
|
170 |
+
shape=None):
|
171 |
+
"""See base class."""
|
172 |
+
del decode_params, num_summands, shape # Unused
|
173 |
+
return tf.cumsum(encoded_tensors[self.ENCODED_VALUES_KEY])
|
misc_test.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import absolute_import
|
16 |
+
from __future__ import division
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
import itertools
|
20 |
+
|
21 |
+
from absl.testing import parameterized
|
22 |
+
import numpy as np
|
23 |
+
import tensorflow as tf
|
24 |
+
|
25 |
+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc
|
26 |
+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
|
27 |
+
|
28 |
+
|
29 |
+
if tf.executing_eagerly():
|
30 |
+
tf.compat.v1.disable_eager_execution()
|
31 |
+
|
32 |
+
|
33 |
+
class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest):
|
34 |
+
|
35 |
+
def default_encoding_stage(self):
|
36 |
+
"""See base class."""
|
37 |
+
return misc.SplitBySmallValueEncodingStage()
|
38 |
+
|
39 |
+
def default_input(self):
|
40 |
+
"""See base class."""
|
41 |
+
return tf.random.uniform([50], minval=-1.0, maxval=1.0)
|
42 |
+
|
43 |
+
@property
|
44 |
+
def is_lossless(self):
|
45 |
+
"""See base class."""
|
46 |
+
return False
|
47 |
+
|
48 |
+
def common_asserts_for_test_data(self, data):
|
49 |
+
"""See base class."""
|
50 |
+
self._assert_is_integer(
|
51 |
+
data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
|
52 |
+
|
53 |
+
def _assert_is_integer(self, indices):
|
54 |
+
"""Asserts that indices values are integers."""
|
55 |
+
assert indices.dtype == np.int32
|
56 |
+
|
57 |
+
@parameterized.parameters([tf.float32, tf.float64])
|
58 |
+
def test_input_types(self, x_dtype):
|
59 |
+
# Tests different input dtypes.
|
60 |
+
x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
|
61 |
+
threshold = 0.05
|
62 |
+
stage = misc.SplitBySmallValueEncodingStage(threshold=threshold)
|
63 |
+
encode_params, decode_params = stage.get_params()
|
64 |
+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
|
65 |
+
decode_params)
|
66 |
+
test_data = test_utils.TestData(x, encoded_x, decoded_x)
|
67 |
+
test_data = self.evaluate_test_data(test_data)
|
68 |
+
|
69 |
+
self._assert_is_integer(test_data.encoded_x[
|
70 |
+
misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
|
71 |
+
|
72 |
+
# The numpy arrays must have the same dtype as the arrays from test_data.
|
73 |
+
expected_encoded_values = np.array([1.0, 0.1], dtype=x.dtype.as_numpy_dtype)
|
74 |
+
expected_encoded_indices = np.array([0, 1], dtype=np.int32)
|
75 |
+
expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.],
|
76 |
+
dtype=x_dtype.as_numpy_dtype)
|
77 |
+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY],
|
78 |
+
expected_encoded_values)
|
79 |
+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
|
80 |
+
expected_encoded_indices)
|
81 |
+
self.assertAllEqual(test_data.decoded_x, expected_decoded_x)
|
82 |
+
|
83 |
+
def test_all_zero_input_works(self):
|
84 |
+
# Tests that encoding does not blow up with all-zero input. With all-zero
|
85 |
+
# input, both of the encoded values will be empty arrays.
|
86 |
+
stage = misc.SplitBySmallValueEncodingStage()
|
87 |
+
test_data = self.run_one_to_many_encode_decode(stage,
|
88 |
+
lambda: tf.zeros([50]))
|
89 |
+
|
90 |
+
self.assertAllEqual(np.zeros((50)).astype(np.float32), test_data.decoded_x)
|
91 |
+
|
92 |
+
def test_all_below_threshold_works(self):
|
93 |
+
# Tests that encoding does not blow up with all-below-threshold input. In
|
94 |
+
# this case, both of the encoded values will be empty arrays.
|
95 |
+
stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
|
96 |
+
x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
|
97 |
+
encode_params, decode_params = stage.get_params()
|
98 |
+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
|
99 |
+
decode_params)
|
100 |
+
test_data = test_utils.TestData(x, encoded_x, decoded_x)
|
101 |
+
test_data = self.evaluate_test_data(test_data)
|
102 |
+
|
103 |
+
expected_encoded_indices = np.array([], dtype=np.int32).reshape([0])
|
104 |
+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
|
105 |
+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
|
106 |
+
expected_encoded_indices)
|
107 |
+
self.assertAllEqual(test_data.decoded_x,
|
108 |
+
np.zeros([50], dtype=x.dtype.as_numpy_dtype))
|
109 |
+
|
110 |
+
|
111 |
+
class DifferenceBetweenIntegersEncodingStageTest(
|
112 |
+
test_utils.BaseEncodingStageTest):
|
113 |
+
|
114 |
+
def default_encoding_stage(self):
|
115 |
+
"""See base class."""
|
116 |
+
return misc.DifferenceBetweenIntegersEncodingStage()
|
117 |
+
|
118 |
+
def default_input(self):
|
119 |
+
"""See base class."""
|
120 |
+
return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def is_lossless(self):
|
124 |
+
"""See base class."""
|
125 |
+
return True
|
126 |
+
|
127 |
+
def common_asserts_for_test_data(self, data):
|
128 |
+
"""See base class."""
|
129 |
+
self.assertAllEqual(data.x, data.decoded_x)
|
130 |
+
|
131 |
+
@parameterized.parameters(
|
132 |
+
itertools.product([[1,], [2,], [10,]], [tf.int32, tf.int64]))
|
133 |
+
def test_with_multiple_input_shapes(self, input_dims, dtype):
|
134 |
+
|
135 |
+
def x_fn():
|
136 |
+
return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype)
|
137 |
+
|
138 |
+
test_data = self.run_one_to_many_encode_decode(
|
139 |
+
self.default_encoding_stage(), x_fn)
|
140 |
+
self.common_asserts_for_test_data(test_data)
|
141 |
+
|
142 |
+
def test_empty_input_static(self):
|
143 |
+
# Tests that the encoding works when the input shape is [0].
|
144 |
+
x = []
|
145 |
+
x = tf.convert_to_tensor(x, dtype=tf.int32)
|
146 |
+
assert x.shape.as_list() == [0]
|
147 |
+
|
148 |
+
stage = self.default_encoding_stage()
|
149 |
+
encode_params, decode_params = stage.get_params()
|
150 |
+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
|
151 |
+
decode_params)
|
152 |
+
|
153 |
+
test_data = self.evaluate_test_data(
|
154 |
+
test_utils.TestData(x, encoded_x, decoded_x))
|
155 |
+
self.common_asserts_for_test_data(test_data)
|
156 |
+
|
157 |
+
def test_empty_input_dynamic(self):
|
158 |
+
# Tests that the encoding works when the input shape is [0], but not
|
159 |
+
# statically known.
|
160 |
+
y = tf.zeros((10,))
|
161 |
+
indices = tf.compat.v2.where(tf.abs(y) > 1e-8)
|
162 |
+
x = tf.gather_nd(y, indices)
|
163 |
+
x = tf.cast(x, tf.int32) # Empty tensor.
|
164 |
+
assert x.shape.as_list() == [None]
|
165 |
+
stage = self.default_encoding_stage()
|
166 |
+
encode_params, decode_params = stage.get_params()
|
167 |
+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
|
168 |
+
decode_params)
|
169 |
+
|
170 |
+
test_data = self.evaluate_test_data(
|
171 |
+
test_utils.TestData(x, encoded_x, decoded_x))
|
172 |
+
assert test_data.x.shape == (0,)
|
173 |
+
assert test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape == (0,)
|
174 |
+
assert test_data.decoded_x.shape == (0,)
|
175 |
+
|
176 |
+
@parameterized.parameters([tf.bool, tf.float32])
|
177 |
+
def test_encode_unsupported_type_raises(self, dtype):
|
178 |
+
stage = self.default_encoding_stage()
|
179 |
+
with self.assertRaisesRegexp(TypeError, 'Unsupported input type'):
|
180 |
+
self.run_one_to_many_encode_decode(
|
181 |
+
stage, lambda: tf.cast(self.default_input(), dtype))
|
182 |
+
|
183 |
+
def test_encode_unsupported_input_shape_raises(self):
|
184 |
+
x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32)
|
185 |
+
stage = self.default_encoding_stage()
|
186 |
+
params, _ = stage.get_params()
|
187 |
+
with self.assertRaisesRegexp(ValueError, 'Number of dimensions must be 1'):
|
188 |
+
stage.encode(x, params)
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == '__main__':
|
192 |
+
tf.test.main()
|
mnist_cnn.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
# pylint: disable=missing-docstring
|
16 |
+
"""Train a simple convnet on the MNIST dataset."""
|
17 |
+
from __future__ import print_function
|
18 |
+
|
19 |
+
from absl import app as absl_app
|
20 |
+
from absl import flags
|
21 |
+
import tensorflow as tf
|
22 |
+
|
23 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
24 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import prune
|
25 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
|
26 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
|
27 |
+
|
28 |
+
|
29 |
+
PolynomialDecay = pruning_schedule.PolynomialDecay
|
30 |
+
l = keras.layers
|
31 |
+
|
32 |
+
FLAGS = flags.FLAGS
|
33 |
+
|
34 |
+
batch_size = 128
|
35 |
+
num_classes = 10
|
36 |
+
epochs = 12
|
37 |
+
|
38 |
+
flags.DEFINE_string('output_dir', '/tmp/mnist_train/',
|
39 |
+
'Output directory to hold tensorboard events')
|
40 |
+
|
41 |
+
|
42 |
+
def build_sequential_model(input_shape):
|
43 |
+
return keras.Sequential([
|
44 |
+
l.Conv2D(
|
45 |
+
32, 5, padding='same', activation='relu', input_shape=input_shape
|
46 |
+
),
|
47 |
+
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
|
48 |
+
l.BatchNormalization(),
|
49 |
+
l.Conv2D(64, 5, padding='same', activation='relu'),
|
50 |
+
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
|
51 |
+
l.Flatten(),
|
52 |
+
l.Dense(1024, activation='relu'),
|
53 |
+
l.Dropout(0.4),
|
54 |
+
l.Dense(num_classes, activation='softmax'),
|
55 |
+
])
|
56 |
+
|
57 |
+
|
58 |
+
def build_functional_model(input_shape):
|
59 |
+
inp = keras.Input(shape=input_shape)
|
60 |
+
x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
|
61 |
+
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
|
62 |
+
x = l.BatchNormalization()(x)
|
63 |
+
x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
|
64 |
+
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
|
65 |
+
x = l.Flatten()(x)
|
66 |
+
x = l.Dense(1024, activation='relu')(x)
|
67 |
+
x = l.Dropout(0.4)(x)
|
68 |
+
out = l.Dense(num_classes, activation='softmax')(x)
|
69 |
+
|
70 |
+
return keras.models.Model([inp], [out])
|
71 |
+
|
72 |
+
|
73 |
+
def build_layerwise_model(input_shape, **pruning_params):
|
74 |
+
return keras.Sequential([
|
75 |
+
prune.prune_low_magnitude(
|
76 |
+
l.Conv2D(32, 5, padding='same', activation='relu'),
|
77 |
+
input_shape=input_shape,
|
78 |
+
**pruning_params
|
79 |
+
),
|
80 |
+
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
|
81 |
+
l.BatchNormalization(),
|
82 |
+
prune.prune_low_magnitude(
|
83 |
+
l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params
|
84 |
+
),
|
85 |
+
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
|
86 |
+
l.Flatten(),
|
87 |
+
prune.prune_low_magnitude(
|
88 |
+
l.Dense(1024, activation='relu'), **pruning_params
|
89 |
+
),
|
90 |
+
l.Dropout(0.4),
|
91 |
+
prune.prune_low_magnitude(
|
92 |
+
l.Dense(num_classes, activation='softmax'), **pruning_params
|
93 |
+
),
|
94 |
+
])
|
95 |
+
|
96 |
+
|
97 |
+
def train_and_save(models, x_train, y_train, x_test, y_test):
|
98 |
+
for model in models:
|
99 |
+
model.compile(
|
100 |
+
loss=keras.losses.categorical_crossentropy,
|
101 |
+
optimizer='adam',
|
102 |
+
metrics=['accuracy'],
|
103 |
+
)
|
104 |
+
|
105 |
+
# Print the model summary.
|
106 |
+
model.summary()
|
107 |
+
|
108 |
+
# Add a pruning step callback to peg the pruning step to the optimizer's
|
109 |
+
# step. Also add a callback to add pruning summaries to tensorboard
|
110 |
+
callbacks = [
|
111 |
+
pruning_callbacks.UpdatePruningStep(),
|
112 |
+
pruning_callbacks.PruningSummaries(log_dir=FLAGS.output_dir)
|
113 |
+
]
|
114 |
+
|
115 |
+
model.fit(
|
116 |
+
x_train,
|
117 |
+
y_train,
|
118 |
+
batch_size=batch_size,
|
119 |
+
epochs=epochs,
|
120 |
+
verbose=1,
|
121 |
+
callbacks=callbacks,
|
122 |
+
validation_data=(x_test, y_test))
|
123 |
+
score = model.evaluate(x_test, y_test, verbose=0)
|
124 |
+
print('Test loss:', score[0])
|
125 |
+
print('Test accuracy:', score[1])
|
126 |
+
|
127 |
+
# Export and import the model. Check that accuracy persists.
|
128 |
+
saved_model_dir = '/tmp/saved_model'
|
129 |
+
print('Saving model to: ', saved_model_dir)
|
130 |
+
keras.models.save_model(model, saved_model_dir, save_format='tf')
|
131 |
+
print('Loading model from: ', saved_model_dir)
|
132 |
+
loaded_model = keras.models.load_model(saved_model_dir)
|
133 |
+
|
134 |
+
score = loaded_model.evaluate(x_test, y_test, verbose=0)
|
135 |
+
print('Test loss:', score[0])
|
136 |
+
print('Test accuracy:', score[1])
|
137 |
+
|
138 |
+
|
139 |
+
def main(unused_argv):
|
140 |
+
# input image dimensions
|
141 |
+
img_rows, img_cols = 28, 28
|
142 |
+
|
143 |
+
# the data, shuffled and split between train and test sets
|
144 |
+
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
145 |
+
|
146 |
+
if keras.backend.image_data_format() == 'channels_first':
|
147 |
+
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
|
148 |
+
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
|
149 |
+
input_shape = (1, img_rows, img_cols)
|
150 |
+
else:
|
151 |
+
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
|
152 |
+
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
|
153 |
+
input_shape = (img_rows, img_cols, 1)
|
154 |
+
|
155 |
+
x_train = x_train.astype('float32')
|
156 |
+
x_test = x_test.astype('float32')
|
157 |
+
x_train /= 255
|
158 |
+
x_test /= 255
|
159 |
+
print('x_train shape:', x_train.shape)
|
160 |
+
print(x_train.shape[0], 'train samples')
|
161 |
+
print(x_test.shape[0], 'test samples')
|
162 |
+
|
163 |
+
# convert class vectors to binary class matrices
|
164 |
+
y_train = keras.utils.to_categorical(y_train, num_classes)
|
165 |
+
y_test = keras.utils.to_categorical(y_test, num_classes)
|
166 |
+
|
167 |
+
pruning_params = {
|
168 |
+
'pruning_schedule':
|
169 |
+
PolynomialDecay(
|
170 |
+
initial_sparsity=0.1,
|
171 |
+
final_sparsity=0.75,
|
172 |
+
begin_step=1000,
|
173 |
+
end_step=5000,
|
174 |
+
frequency=100)
|
175 |
+
}
|
176 |
+
|
177 |
+
layerwise_model = build_layerwise_model(input_shape, **pruning_params)
|
178 |
+
sequential_model = build_sequential_model(input_shape)
|
179 |
+
sequential_model = prune.prune_low_magnitude(
|
180 |
+
sequential_model, **pruning_params)
|
181 |
+
functional_model = build_functional_model(input_shape)
|
182 |
+
functional_model = prune.prune_low_magnitude(
|
183 |
+
functional_model, **pruning_params)
|
184 |
+
|
185 |
+
models = [layerwise_model, sequential_model, functional_model]
|
186 |
+
train_and_save(models, x_train, y_train, x_test, y_test)
|
187 |
+
|
188 |
+
|
189 |
+
if __name__ == '__main__':
|
190 |
+
absl_app.run(main)
|
mnist_e2e_sparsity2x4.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
# pylint: disable=missing-docstring,protected-access
|
16 |
+
"""Train a simple convnet on the MNIST dataset with sparsity 2x4.
|
17 |
+
|
18 |
+
It is based on mnist_e2e.py
|
19 |
+
"""
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
from absl import app as absl_app
|
23 |
+
import tensorflow as tf
|
24 |
+
|
25 |
+
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
|
26 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
27 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import prune
|
28 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
|
29 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
|
30 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
|
31 |
+
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
|
32 |
+
|
33 |
+
|
34 |
+
ConstantSparsity = pruning_schedule.ConstantSparsity
|
35 |
+
l = keras.layers
|
36 |
+
|
37 |
+
tf.random.set_seed(42)
|
38 |
+
|
39 |
+
batch_size = 128
|
40 |
+
num_classes = 10
|
41 |
+
epochs = 1
|
42 |
+
|
43 |
+
PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)
|
44 |
+
|
45 |
+
|
46 |
+
def check_model_sparsity_2x4(model):
|
47 |
+
for layer in model.layers:
|
48 |
+
if isinstance(layer, pruning_wrapper.PruneLowMagnitude) and isinstance(
|
49 |
+
layer.layer, PRUNABLE_2x4_LAYERS):
|
50 |
+
for weight in layer.layer.get_prunable_weights():
|
51 |
+
if not pruning_utils.is_pruned_m_by_n(weight):
|
52 |
+
return False
|
53 |
+
return True
|
54 |
+
|
55 |
+
|
56 |
+
def build_layerwise_model(input_shape, **pruning_params):
|
57 |
+
return keras.Sequential([
|
58 |
+
prune.prune_low_magnitude(
|
59 |
+
l.Conv2D(
|
60 |
+
32, 5, padding='same', activation='relu', input_shape=input_shape
|
61 |
+
),
|
62 |
+
**pruning_params
|
63 |
+
),
|
64 |
+
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
|
65 |
+
prune.prune_low_magnitude(
|
66 |
+
l.Conv2D(64, 5, padding='same'), **pruning_params
|
67 |
+
),
|
68 |
+
l.BatchNormalization(),
|
69 |
+
l.ReLU(),
|
70 |
+
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
|
71 |
+
l.Flatten(),
|
72 |
+
prune.prune_low_magnitude(
|
73 |
+
l.Dense(1024, activation='relu'), **pruning_params
|
74 |
+
),
|
75 |
+
l.Dropout(0.4),
|
76 |
+
l.Dense(num_classes, activation='softmax'),
|
77 |
+
])
|
78 |
+
|
79 |
+
|
80 |
+
def train(model, x_train, y_train, x_test, y_test):
|
81 |
+
model.compile(
|
82 |
+
loss=keras.losses.categorical_crossentropy,
|
83 |
+
optimizer='adam',
|
84 |
+
metrics=['accuracy'],
|
85 |
+
)
|
86 |
+
model.run_eagerly = True
|
87 |
+
|
88 |
+
# Print the model summary.
|
89 |
+
model.summary()
|
90 |
+
|
91 |
+
# Add a pruning step callback to peg the pruning step to the optimizer's
|
92 |
+
# step. Also add a callback to add pruning summaries to tensorboard
|
93 |
+
callbacks = [
|
94 |
+
pruning_callbacks.UpdatePruningStep(),
|
95 |
+
pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')
|
96 |
+
]
|
97 |
+
|
98 |
+
model.fit(
|
99 |
+
x_train,
|
100 |
+
y_train,
|
101 |
+
batch_size=batch_size,
|
102 |
+
epochs=epochs,
|
103 |
+
verbose=1,
|
104 |
+
callbacks=callbacks,
|
105 |
+
validation_data=(x_test, y_test))
|
106 |
+
score = model.evaluate(x_test, y_test, verbose=0)
|
107 |
+
print('Test loss:', score[0])
|
108 |
+
print('Test accuracy:', score[1])
|
109 |
+
|
110 |
+
# Check sparsity 2x4 type before stripping pruning
|
111 |
+
is_pruned_2x4 = check_model_sparsity_2x4(model)
|
112 |
+
print('Pass the check for sparsity 2x4: ', is_pruned_2x4)
|
113 |
+
|
114 |
+
model = prune.strip_pruning(model)
|
115 |
+
return model
|
116 |
+
|
117 |
+
|
118 |
+
def main(unused_argv):
|
119 |
+
##############################################################################
|
120 |
+
# Prepare training and testing data
|
121 |
+
##############################################################################
|
122 |
+
(x_train, y_train), (
|
123 |
+
x_test,
|
124 |
+
y_test), input_shape = keras_test_utils.get_preprocessed_mnist_data()
|
125 |
+
|
126 |
+
##############################################################################
|
127 |
+
# Train a model with sparsity 2x4.
|
128 |
+
##############################################################################
|
129 |
+
pruning_params = {
|
130 |
+
'pruning_schedule': ConstantSparsity(0.5, begin_step=0, frequency=100),
|
131 |
+
'sparsity_m_by_n': (2, 4),
|
132 |
+
}
|
133 |
+
|
134 |
+
model = build_layerwise_model(input_shape, **pruning_params)
|
135 |
+
pruned_model = train(model, x_train, y_train, x_test, y_test)
|
136 |
+
|
137 |
+
# Write a model that has been pruned with 2x4 sparsity.
|
138 |
+
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
|
139 |
+
tflite_model = converter.convert()
|
140 |
+
|
141 |
+
tflite_model_path = '/tmp/mnist_2x4.tflite'
|
142 |
+
print('model is saved to {}'.format(tflite_model_path))
|
143 |
+
with open(tflite_model_path, 'wb') as f:
|
144 |
+
f.write(tflite_model)
|
145 |
+
|
146 |
+
print('evaluate pruned model: ')
|
147 |
+
print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
|
148 |
+
# the accuracy of 2:4 pruning model is 0.9866
|
149 |
+
# the accuracy of unstructured model with 50% is 0.9863
|
150 |
+
|
151 |
+
|
152 |
+
if __name__ == '__main__':
|
153 |
+
absl_app.run(main)
|
periodical_update_and_scheduling_test.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Tests for when the training and inference graphs are the same."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import tempfile
|
19 |
+
|
20 |
+
import tensorflow as tf
|
21 |
+
|
22 |
+
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import periodical_update_and_scheduling as svd
|
23 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
24 |
+
from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
|
25 |
+
|
26 |
+
|
27 |
+
def _build_model():
|
28 |
+
i = keras.layers.Input(shape=(28, 28), name='input')
|
29 |
+
x = keras.layers.Reshape((28, 28, 1))(i)
|
30 |
+
x = keras.layers.Conv2D(
|
31 |
+
20, 5, activation='relu', padding='valid', name='conv1'
|
32 |
+
)(x)
|
33 |
+
x = keras.layers.MaxPool2D(2, 2)(x)
|
34 |
+
x = keras.layers.Conv2D(
|
35 |
+
50, 5, activation='relu', padding='valid', name='conv2'
|
36 |
+
)(x)
|
37 |
+
x = keras.layers.MaxPool2D(2, 2)(x)
|
38 |
+
x = keras.layers.Flatten()(x)
|
39 |
+
x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
|
40 |
+
output = keras.layers.Dense(10, name='fc2')(x)
|
41 |
+
|
42 |
+
model = keras.Model(inputs=[i], outputs=[output])
|
43 |
+
return model
|
44 |
+
|
45 |
+
|
46 |
+
def _get_dataset():
|
47 |
+
mnist = keras.datasets.mnist
|
48 |
+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
49 |
+
x_train, x_test = x_train / 255.0, x_test / 255.0
|
50 |
+
# Use subset of 60000 examples to keep unit test speed fast.
|
51 |
+
x_train = x_train[0:1000]
|
52 |
+
y_train = y_train[0:1000]
|
53 |
+
return (x_train, y_train), (x_test, y_test)
|
54 |
+
|
55 |
+
|
56 |
+
def _train_model(model):
|
57 |
+
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
58 |
+
|
59 |
+
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
|
60 |
+
|
61 |
+
(x_train, y_train), _ = _get_dataset()
|
62 |
+
|
63 |
+
model.fit(x_train, y_train, epochs=1)
|
64 |
+
|
65 |
+
|
66 |
+
def _save_as_saved_model(model):
|
67 |
+
saved_model_dir = tempfile.mkdtemp()
|
68 |
+
model.save(saved_model_dir)
|
69 |
+
return saved_model_dir
|
70 |
+
|
71 |
+
|
72 |
+
# TODO(tfmot): reuse existing test utilities.
|
73 |
+
def _convert_to_tflite(saved_model_dir):
|
74 |
+
_, tflite_file = tempfile.mkstemp()
|
75 |
+
|
76 |
+
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
77 |
+
tflite_model = converter.convert()
|
78 |
+
|
79 |
+
with open(tflite_file, 'wb') as f:
|
80 |
+
f.write(tflite_model)
|
81 |
+
|
82 |
+
return tflite_file
|
83 |
+
|
84 |
+
|
85 |
+
def _get_directory_size_in_bytes(directory):
|
86 |
+
total = 0
|
87 |
+
try:
|
88 |
+
for entry in os.scandir(directory):
|
89 |
+
if entry.is_file():
|
90 |
+
# if it's a file, use stat() function
|
91 |
+
total += entry.stat().st_size
|
92 |
+
elif entry.is_dir():
|
93 |
+
# if it's a directory, recursively call this function
|
94 |
+
total += _get_directory_size_in_bytes(entry.path)
|
95 |
+
except NotADirectoryError:
|
96 |
+
# if `directory` isn't a directory, get the file size then
|
97 |
+
return os.path.getsize(directory)
|
98 |
+
except PermissionError:
|
99 |
+
# if for whatever reason we can't open the folder, return 0
|
100 |
+
return 0
|
101 |
+
return total
|
102 |
+
|
103 |
+
|
104 |
+
class FunctionalTest(tf.test.TestCase):
|
105 |
+
|
106 |
+
# TODO(tfmot): can simplify to single layer test that checks exact
|
107 |
+
# dimensions of weights.
|
108 |
+
def testSVD_ReducesSavedModelSize(self):
|
109 |
+
model = _build_model()
|
110 |
+
|
111 |
+
original_saved_model_dir = _save_as_saved_model(model)
|
112 |
+
|
113 |
+
algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
|
114 |
+
training_model = algorithm.optimize_model(model)
|
115 |
+
compressed_model = algorithm.compress_model(training_model)
|
116 |
+
|
117 |
+
saved_model_dir = _save_as_saved_model(compressed_model)
|
118 |
+
|
119 |
+
original_size = _get_directory_size_in_bytes(original_saved_model_dir)
|
120 |
+
compressed_size = _get_directory_size_in_bytes(saved_model_dir)
|
121 |
+
|
122 |
+
self.assertLess(compressed_size, original_size / 3)
|
123 |
+
|
124 |
+
def testSVD_HasReasonableAccuracy_TF(self):
|
125 |
+
model = _build_model()
|
126 |
+
|
127 |
+
algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
|
128 |
+
training_model = algorithm.optimize_model(model)
|
129 |
+
|
130 |
+
_train_model(training_model)
|
131 |
+
|
132 |
+
compressed_model = algorithm.compress_model(training_model)
|
133 |
+
|
134 |
+
_, (x_test, y_test) = _get_dataset()
|
135 |
+
|
136 |
+
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
137 |
+
|
138 |
+
compressed_model.compile(
|
139 |
+
optimizer='adam', loss=loss_fn, metrics=['accuracy'])
|
140 |
+
|
141 |
+
results = compressed_model.evaluate(x_test, y_test)
|
142 |
+
|
143 |
+
self.assertGreater(results[1], 0.60)
|
144 |
+
|
145 |
+
def testSVD_ReducesTFLiteModelSize(self):
|
146 |
+
model = _build_model()
|
147 |
+
|
148 |
+
original_saved_model_dir = _save_as_saved_model(model)
|
149 |
+
original_tflite_file = _convert_to_tflite(original_saved_model_dir)
|
150 |
+
|
151 |
+
algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
|
152 |
+
training_model = algorithm.optimize_model(model)
|
153 |
+
compressed_model = algorithm.compress_model(training_model)
|
154 |
+
|
155 |
+
saved_model_dir = _save_as_saved_model(compressed_model)
|
156 |
+
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
|
157 |
+
|
158 |
+
original_size = os.path.getsize(original_tflite_file)
|
159 |
+
compressed_size = os.path.getsize(compressed_tflite_file)
|
160 |
+
|
161 |
+
self.assertLess(compressed_size, original_size / 6)
|
162 |
+
|
163 |
+
def testSVD_HasReasonableAccuracy_TFLite(self):
|
164 |
+
model = _build_model()
|
165 |
+
|
166 |
+
algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
|
167 |
+
training_model = algorithm.optimize_model(model)
|
168 |
+
|
169 |
+
_train_model(training_model)
|
170 |
+
|
171 |
+
compressed_model = algorithm.compress_model(training_model)
|
172 |
+
|
173 |
+
saved_model_dir = _save_as_saved_model(compressed_model)
|
174 |
+
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
|
175 |
+
|
176 |
+
accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
|
177 |
+
|
178 |
+
self.assertGreater(accuracy, 0.60)
|
179 |
+
|
180 |
+
# TODO(tfmot): can simplify to single layer test.
|
181 |
+
def testSVD_BreaksDownLayerWeights(self):
|
182 |
+
model = _build_model()
|
183 |
+
|
184 |
+
first_conv_layer = model.layers[2]
|
185 |
+
self.assertLen(first_conv_layer.weights, 2)
|
186 |
+
|
187 |
+
algorithm = svd.SVD(rank=16, update_freq=1, warmup_step=10)
|
188 |
+
training_model = algorithm.optimize_model(model)
|
189 |
+
compressed_model = algorithm.compress_model(training_model)
|
190 |
+
|
191 |
+
first_conv_layer = compressed_model.layers[2]
|
192 |
+
|
193 |
+
self.assertLen(first_conv_layer.weights, 3)
|
194 |
+
|
195 |
+
# TODO(tfmot): can simplify to single layer test.
|
196 |
+
def testSVD_PreservesPretrainedWeights(self):
|
197 |
+
i = keras.layers.Input(shape=(2), name='input')
|
198 |
+
output = keras.layers.Dense(3, name='fc1')(i)
|
199 |
+
model = keras.Model(inputs=[i], outputs=[output])
|
200 |
+
|
201 |
+
dense_layer_weights = model.layers[1].get_weights()
|
202 |
+
|
203 |
+
algorithm = svd.SVD(rank=1, update_freq=1, warmup_step=10)
|
204 |
+
training_model = algorithm.optimize_model(model)
|
205 |
+
|
206 |
+
dense_layer_training_weights = training_model.layers[1].get_weights()
|
207 |
+
|
208 |
+
# kernel
|
209 |
+
algorithm.weight_reprs = []
|
210 |
+
algorithm.init_training_weights(dense_layer_weights[0])
|
211 |
+
w1_repr, w2_repr = algorithm.weight_reprs
|
212 |
+
assert (w1_repr.kwargs['initializer'](None) == \
|
213 |
+
dense_layer_training_weights[0]).numpy().all()
|
214 |
+
assert (w2_repr.kwargs['initializer'](None) == \
|
215 |
+
dense_layer_training_weights[1]).numpy().all()
|
216 |
+
|
217 |
+
# bias
|
218 |
+
assert (dense_layer_weights[1] == dense_layer_training_weights[2]).all()
|
219 |
+
|
220 |
+
|
221 |
+
if __name__ == '__main__':
|
222 |
+
tf.test.main()
|
prune_preserve_quantize_registry.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Registry responsible for built-in keras classes."""
|
16 |
+
|
17 |
+
import tensorflow as tf
|
18 |
+
|
19 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
20 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
|
21 |
+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
|
22 |
+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
|
23 |
+
default_8bit_quantize_registry,)
|
24 |
+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
|
25 |
+
default_8bit_quantizers,)
|
26 |
+
|
27 |
+
|
28 |
+
layers = keras.layers
|
29 |
+
|
30 |
+
|
31 |
+
class _PrunePreserveInfo(object):
|
32 |
+
"""PrunePreserveInfo."""
|
33 |
+
|
34 |
+
def __init__(self, weight_attrs, quantize_config_attrs):
|
35 |
+
"""Initializes PrunePreserveInfo.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
weight_attrs: list of sparsity preservable weight attributes of layer.
|
39 |
+
quantize_config_attrs: list of quantization configuration class name.
|
40 |
+
"""
|
41 |
+
self.weight_attrs = weight_attrs
|
42 |
+
self.quantize_config_attrs = quantize_config_attrs
|
43 |
+
|
44 |
+
|
45 |
+
class PrunePreserveQuantizeRegistry():
|
46 |
+
"""PrunePreserveQuantizeRegistry responsible for built-in keras layers."""
|
47 |
+
|
48 |
+
# The keys represent built-in keras layers; the first values represent the
|
49 |
+
# the variables within the layers which hold the kernel weights, second
|
50 |
+
# values represent the class name of quantization configuration for layers.
|
51 |
+
# This decide the weights of layers with quantization configurations are
|
52 |
+
# sparsity preservable.
|
53 |
+
_LAYERS_CONFIG_MAP = {
|
54 |
+
layers.Conv2D:
|
55 |
+
_PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
|
56 |
+
layers.Dense:
|
57 |
+
_PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
|
58 |
+
|
59 |
+
# DepthwiseConv2D is supported with 8bit qat, but not with prune,
|
60 |
+
# thus for DepthwiseConv2D PQAT, weights sparsity preserve is disabled.
|
61 |
+
layers.DepthwiseConv2D:
|
62 |
+
_PrunePreserveInfo(['depthwise_kernel'], ['Default8BitQuantizeConfig']),
|
63 |
+
|
64 |
+
# layers that supported with prune, but not yet with QAT
|
65 |
+
# layers.Conv1D:
|
66 |
+
# _PrunePreserveInfo(['kernel'], []),
|
67 |
+
# layers.Conv2DTranspose:
|
68 |
+
# _PrunePreserveInfo(['kernel'], []),
|
69 |
+
# layers.Conv3D:
|
70 |
+
# _PrunePreserveInfo(['kernel'], []),
|
71 |
+
# layers.Conv3DTranspose:
|
72 |
+
# _PrunePreserveInfo(['kernel'], []),
|
73 |
+
# layers.LocallyConnected1D:
|
74 |
+
# _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
|
75 |
+
# layers.LocallyConnected2D:
|
76 |
+
# _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
|
77 |
+
|
78 |
+
# SeparableConv need verify from 8bit qat
|
79 |
+
# layers.SeparableConv1D:
|
80 |
+
# _PrunePreserveInfo(['pointwise_kernel'], \
|
81 |
+
# ['Default8BitConvQuantizeConfig']),
|
82 |
+
# layers.SeparableConv2D:
|
83 |
+
# _PrunePreserveInfo(['pointwise_kernel'], \
|
84 |
+
# ['Default8BitConvQuantizeConfig']),
|
85 |
+
|
86 |
+
# Embedding need verify from 8bit qat
|
87 |
+
# layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
|
88 |
+
}
|
89 |
+
|
90 |
+
_DISABLE_PRUNE_PRESERVE = frozenset({
|
91 |
+
layers.DepthwiseConv2D,
|
92 |
+
})
|
93 |
+
|
94 |
+
def __init__(self):
|
95 |
+
|
96 |
+
self._config_quantizer_map = {
|
97 |
+
'Default8BitQuantizeConfig':
|
98 |
+
PrunePreserveDefault8BitWeightsQuantizer(),
|
99 |
+
'Default8BitConvQuantizeConfig':
|
100 |
+
PrunePreserveDefault8BitConvWeightsQuantizer(),
|
101 |
+
}
|
102 |
+
|
103 |
+
@classmethod
|
104 |
+
def _no_trainable_weights(cls, layer):
|
105 |
+
"""Returns whether this layer has trainable weights.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
layer: The layer to check for trainable weights.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
True/False whether the layer has trainable weights.
|
112 |
+
"""
|
113 |
+
return not layer.trainable_weights
|
114 |
+
|
115 |
+
@classmethod
|
116 |
+
def _disable_prune_preserve(cls, layer):
|
117 |
+
"""Returns whether disable this layer for prune preserve.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
layer: The layer to check for disable.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
True/False whether disable this layer for prune preserve.
|
124 |
+
"""
|
125 |
+
|
126 |
+
return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE
|
127 |
+
|
128 |
+
@classmethod
|
129 |
+
def supports(cls, layer):
|
130 |
+
"""Returns whether the registry supports this layer type.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
layer: The layer to check for support.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
True/False whether the layer type is supported.
|
137 |
+
"""
|
138 |
+
|
139 |
+
# layers without trainable weights are considered supported,
|
140 |
+
# e.g., ReLU, Softmax, and AveragePooling2D.
|
141 |
+
if cls._no_trainable_weights(layer):
|
142 |
+
return True
|
143 |
+
|
144 |
+
if layer.__class__ in cls._LAYERS_CONFIG_MAP:
|
145 |
+
return True
|
146 |
+
|
147 |
+
return False
|
148 |
+
|
149 |
+
@classmethod
|
150 |
+
def _weight_names(cls, layer):
|
151 |
+
"""Gets the weight names."""
|
152 |
+
if cls._no_trainable_weights(layer):
|
153 |
+
return []
|
154 |
+
|
155 |
+
return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs
|
156 |
+
|
157 |
+
@classmethod
|
158 |
+
def get_sparsity_preservable_weights(cls, layer):
|
159 |
+
"""Gets sparsity preservable weights from keras layer.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
layer: instance of keras layer
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
List of sparsity preservable weights
|
166 |
+
"""
|
167 |
+
return [getattr(layer, weight) for weight in cls._weight_names(layer)]
|
168 |
+
|
169 |
+
@classmethod
|
170 |
+
def get_suppport_quantize_config_names(cls, layer):
|
171 |
+
"""Gets class name of supported quantize config for layer.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
layer: instance of keras layer
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
List of supported quantize config class name.
|
178 |
+
"""
|
179 |
+
|
180 |
+
# layers without trainable weights don't need quantize_config for pqat
|
181 |
+
if cls._no_trainable_weights(layer):
|
182 |
+
return []
|
183 |
+
|
184 |
+
return cls._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs
|
185 |
+
|
186 |
+
def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
|
187 |
+
"""Applies weights sparsity preservation.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
layer: The layer to check for support.
|
191 |
+
quantize_config: quantization config to check for support,
|
192 |
+
apply sparsity preservation to pruned weights
|
193 |
+
Raises:
|
194 |
+
ValueError when layer is supported does not have quantization config.
|
195 |
+
Returns:
|
196 |
+
Returns quantize_config with addon sparsity preserve weight_quantizer.
|
197 |
+
"""
|
198 |
+
if self.supports(layer):
|
199 |
+
if (self._no_trainable_weights(layer) or
|
200 |
+
self._disable_prune_preserve(layer)):
|
201 |
+
return quantize_config
|
202 |
+
if (quantize_config.__class__.__name__
|
203 |
+
in self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs):
|
204 |
+
quantize_config.weight_quantizer = self._config_quantizer_map[
|
205 |
+
quantize_config.__class__.__name__]
|
206 |
+
else:
|
207 |
+
raise ValueError('Configuration {} is not supported for Layer {}.'
|
208 |
+
.format(str(quantize_config.__class__.__name__),
|
209 |
+
str(layer.__class__.__name__)))
|
210 |
+
else:
|
211 |
+
raise ValueError('Layer {} is not supported.'.format(
|
212 |
+
str(layer.__class__.__name__)))
|
213 |
+
|
214 |
+
return quantize_config
|
215 |
+
|
216 |
+
|
217 |
+
class Default8bitPrunePreserveQuantizeRegistry(PrunePreserveQuantizeRegistry):
|
218 |
+
"""Default 8 bit PrunePreserveQuantizeRegistry."""
|
219 |
+
|
220 |
+
def get_quantize_config(self, layer):
|
221 |
+
"""Returns the quantization config with addon sparsity.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
layer: input layer to return quantize config for.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
Returns the quantization config with sparsity preserve weight_quantizer.
|
228 |
+
"""
|
229 |
+
quantize_config = (default_8bit_quantize_registry
|
230 |
+
.Default8BitQuantizeRegistry()
|
231 |
+
.get_quantize_config(layer))
|
232 |
+
prune_aware_quantize_config = self.apply_sparsity_preserve_quantize_config(
|
233 |
+
layer, quantize_config)
|
234 |
+
|
235 |
+
return prune_aware_quantize_config
|
236 |
+
|
237 |
+
|
238 |
+
class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
|
239 |
+
"""Quantize weights while preserve sparsity."""
|
240 |
+
|
241 |
+
def __init__(self, num_bits, per_axis, symmetric, narrow_range):
|
242 |
+
"""Initializes PrunePreserveDefaultWeightsQuantizer.
|
243 |
+
|
244 |
+
Args:
|
245 |
+
num_bits: Number of bits for quantization
|
246 |
+
per_axis: Whether to apply per_axis quantization. The last dimension is
|
247 |
+
used as the axis.
|
248 |
+
symmetric: If true, use symmetric quantization limits instead of training
|
249 |
+
the minimum and maximum of each quantization range separately.
|
250 |
+
narrow_range: In case of 8 bits, narrow_range nudges the quantized range
|
251 |
+
to be [-127, 127] instead of [-128, 127]. This ensures symmetric range
|
252 |
+
has 0 as the centre.
|
253 |
+
"""
|
254 |
+
quantizers.LastValueQuantizer.__init__(self, num_bits, per_axis, symmetric,
|
255 |
+
narrow_range)
|
256 |
+
|
257 |
+
def _build_sparsity_mask(self, name, layer):
|
258 |
+
weights = getattr(layer.layer, name)
|
259 |
+
sparsity_mask = tf.math.divide_no_nan(weights, weights)
|
260 |
+
|
261 |
+
return {'sparsity_mask': sparsity_mask}
|
262 |
+
|
263 |
+
def build(self, tensor_shape, name, layer):
|
264 |
+
"""Constructs mask to preserve weights sparsity.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
tensor_shape: Shape of weights which needs to be quantized.
|
268 |
+
name: Name of weights in layer.
|
269 |
+
layer: quantization wrapped keras layer.
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
Dictionary of constructed sparsity mask and
|
273 |
+
quantization params, the dictionary will be passed
|
274 |
+
to __call__ function.
|
275 |
+
"""
|
276 |
+
result = self._build_sparsity_mask(name, layer)
|
277 |
+
result.update(
|
278 |
+
super(PrunePreserveDefaultWeightsQuantizer,
|
279 |
+
self).build(tensor_shape, name, layer))
|
280 |
+
return result
|
281 |
+
|
282 |
+
def __call__(self, inputs, training, weights, **kwargs):
|
283 |
+
"""Applies sparsity preserved quantization to the input tensor.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
inputs: Input tensor (layer's weights) to be quantized.
|
287 |
+
training: Whether the graph is currently training.
|
288 |
+
weights: Dictionary of weights (params) the quantizer can use to
|
289 |
+
quantize the tensor (layer's weights). This contains the weights
|
290 |
+
created in the `build` function.
|
291 |
+
**kwargs: Additional variables which may be passed to the quantizer.
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
quantized tensor.
|
295 |
+
"""
|
296 |
+
|
297 |
+
prune_preserve_inputs = tf.multiply(inputs, weights['sparsity_mask'])
|
298 |
+
|
299 |
+
return quant_ops.LastValueQuantize(
|
300 |
+
prune_preserve_inputs,
|
301 |
+
weights['min_var'],
|
302 |
+
weights['max_var'],
|
303 |
+
is_training=training,
|
304 |
+
num_bits=self.num_bits,
|
305 |
+
per_channel=self.per_axis,
|
306 |
+
symmetric=self.symmetric,
|
307 |
+
narrow_range=self.narrow_range,
|
308 |
+
)
|
309 |
+
|
310 |
+
|
311 |
+
class PrunePreserveDefault8BitWeightsQuantizer(
|
312 |
+
PrunePreserveDefaultWeightsQuantizer):
|
313 |
+
"""PrunePreserveWeightsQuantizer for default 8bit weights."""
|
314 |
+
|
315 |
+
def __init__(self):
|
316 |
+
super(PrunePreserveDefault8BitWeightsQuantizer,
|
317 |
+
self).__init__(num_bits=8,
|
318 |
+
per_axis=False,
|
319 |
+
symmetric=True,
|
320 |
+
narrow_range=True)
|
321 |
+
|
322 |
+
|
323 |
+
class PrunePreserveDefault8BitConvWeightsQuantizer(
|
324 |
+
PrunePreserveDefaultWeightsQuantizer,
|
325 |
+
default_8bit_quantizers.Default8BitConvWeightsQuantizer,):
|
326 |
+
"""PrunePreserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights."""
|
327 |
+
|
328 |
+
# pylint: disable=super-init-not-called
|
329 |
+
def __init__(self):
|
330 |
+
# Skip PrunePreserveDefaultWeightsQuantizer since they have the same super.
|
331 |
+
default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
|
332 |
+
|
333 |
+
def build(self, tensor_shape, name, layer):
|
334 |
+
result = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
|
335 |
+
self, name, layer)
|
336 |
+
result.update(
|
337 |
+
default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
|
338 |
+
self, tensor_shape, name, layer))
|
339 |
+
return result
|
readme.txt
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ixl iWARP FreeBSD* driver for Intel(R) Ethernet Connection X722
|
2 |
+
================================================================
|
3 |
+
July 9, 2019
|
4 |
+
|
5 |
+
Contents
|
6 |
+
========
|
7 |
+
|
8 |
+
- Prerequisites
|
9 |
+
- Building and Installation
|
10 |
+
- Testing
|
11 |
+
- Configuration
|
12 |
+
- Interoperability
|
13 |
+
- Known Issues
|
14 |
+
|
15 |
+
|
16 |
+
Prerequisites
|
17 |
+
=============
|
18 |
+
|
19 |
+
- FreeBSD version 11.2
|
20 |
+
- Kernel configuration:
|
21 |
+
Please add the following kernel configuration options:
|
22 |
+
include GENERIC
|
23 |
+
options COMPAT_LINUXKPI
|
24 |
+
options IPOIB_CM
|
25 |
+
options IXL_IW
|
26 |
+
|
27 |
+
nodevice ixl
|
28 |
+
nodevice iavf
|
29 |
+
Note: IXL_IW is required for FreeBSD-CURRENT branch.
|
30 |
+
- For the iw_ixl driver to work, an if_ixl driver with iwarp interface
|
31 |
+
is required. The interface is available in if_ixl version 1.7.12 or later.
|
32 |
+
It should be enabled prior to usage, as the setting is switched off by
|
33 |
+
default. To enable iwarp compatibility, add
|
34 |
+
hw.ixl.enable_iwarp=1
|
35 |
+
to
|
36 |
+
/boot/loader.conf
|
37 |
+
|
38 |
+
The lan driver can be downloaded from
|
39 |
+
https://downloadcenter.intel.com/download/25160/Ethernet-Intel-Network-Adapter-D
|
40 |
+
river-for-PCIe-40-Gigabit-Ethernet-Network-Connection-under-FreeBSD
|
41 |
+
Or search on downloadcenter.intel.com using '40 Gigabit Ethernet Network
|
42 |
+
Connection under FreeBSD'. Newer OS releases contain the if_ixl driver in
|
43 |
+
the ixl driver version 1.7.12-k or later.
|
44 |
+
|
45 |
+
There are some known issues with the interface on if_ixl-1.7.12. Please
|
46 |
+
use version 1.7.13 or later.
|
47 |
+
|
48 |
+
- fastreg memory mode in krping needs a patch applied to krping.
|
49 |
+
Refer to the 'Testing' and 'Known Issues' sections for details.
|
50 |
+
|
51 |
+
|
52 |
+
Building and Installation
|
53 |
+
=========================
|
54 |
+
|
55 |
+
1. Untar ixl-<version>.tar.gz and iw_ixl-<version>.tar.gz
|
56 |
+
|
57 |
+
# tar -xf ixl-<version>.tar.gz
|
58 |
+
# tar -xf iw_ixl-<version>.tar.gz
|
59 |
+
|
60 |
+
2. Install the if_ixl driver:
|
61 |
+
|
62 |
+
# cd ixl-<version>/src directory
|
63 |
+
# make
|
64 |
+
# make install
|
65 |
+
|
66 |
+
3. Install the iw_ixl driver:
|
67 |
+
|
68 |
+
# cd iw_ixl-<version>/src
|
69 |
+
# make clean
|
70 |
+
# make IXL_DIR=$PATH_TO_IXL/ixl-<version>/src
|
71 |
+
# make install
|
72 |
+
|
73 |
+
4. Install the man page for the iw_ixl driver by copying the iw_ixl.4.gz file
|
74 |
+
to the directory where manual pages are held on your system. For instance:
|
75 |
+
|
76 |
+
# cp iw_ixl-<version>/doc/iw_ixl.4.gz /usr/share/man/man4/
|
77 |
+
|
78 |
+
For in-tree driver if_ixl-1.7.12-k or later, it is sufficient to follow
|
79 |
+
the instruction from point 3 but ensure the correct path to if_ixl source
|
80 |
+
folder is supplied. For instance:
|
81 |
+
IXL_DIR=/usr/src/sys/dev/ixl/
|
82 |
+
|
83 |
+
|
84 |
+
Testing
|
85 |
+
-------
|
86 |
+
1. To load the iw_ixl driver, call:
|
87 |
+
|
88 |
+
# kldload iw_ixl
|
89 |
+
|
90 |
+
If if_ixl is not already loaded, the system will load it on its own.
|
91 |
+
Please remember to add
|
92 |
+
hw.ixl.enable_iwarp=1
|
93 |
+
to /boot/loader.conf file prior to if_ixl loading, to ensure the ixl
|
94 |
+
driver has the iwarp interface enabled.
|
95 |
+
|
96 |
+
2. To validate the load of the driver, check:
|
97 |
+
|
98 |
+
# sysctl -a | grep infiniband
|
99 |
+
|
100 |
+
A number of sys.class.infiniband should appear, provided at least one
|
101 |
+
port of the X722 is up.
|
102 |
+
|
103 |
+
3. The source code for krping software is provided with the kernel in
|
104 |
+
/usr/src/sys/contrib/rdma/krping/. To compile the software, change directory
|
105 |
+
to /usr/src/sys/modules/rdma/krping/ and invoke the following:
|
106 |
+
|
107 |
+
# make clean
|
108 |
+
# make
|
109 |
+
# make install
|
110 |
+
|
111 |
+
4. Start krping server on one machine:
|
112 |
+
|
113 |
+
# echo size=64,count=1,port=6601,addr=100.0.0.189,server > /dev/krping
|
114 |
+
5. Connect client from another machine:
|
115 |
+
|
116 |
+
# echo size=64,count=1,port=6601,addr=100.0.0.189,client > /dev/krping
|
117 |
+
|
118 |
+
|
119 |
+
Configuration
|
120 |
+
=============
|
121 |
+
The following sysctl options are visible:
|
122 |
+
- hw.iw_ixl.max_ceq
|
123 |
+
determines the maximum number of msix vectors available to the driver
|
124 |
+
for CEQ usage.
|
125 |
+
- hw.iw_ixl.debug
|
126 |
+
defines level of debug messages.
|
127 |
+
- hw.iw_ixl.mpa_version
|
128 |
+
shows the current MPA version used.
|
129 |
+
|
130 |
+
The max_ceq setting may be changed by adding:
|
131 |
+
hw.iw_ixl.max_ceq=$value
|
132 |
+
to /boot/loader.conf file. The final number of CEQ is evaluated depending
|
133 |
+
on the available msix vectors, number of cpu cores, and hardware limits.
|
134 |
+
|
135 |
+
If max_ceq=0, the value is ignored.
|
136 |
+
|
137 |
+
The debug setting may be changed either by adding:
|
138 |
+
hw.iw_ixl.debug=$value
|
139 |
+
to the /boot/loader.conf file or by calling
|
140 |
+
sysctl hw.iw_ixl.debug=$value
|
141 |
+
|
142 |
+
The mpa_version may be changed by adding:
|
143 |
+
hw.iw_ixl.mpa_version=$value
|
144 |
+
to the /boot/loader.conf file.
|
145 |
+
|
146 |
+
|
147 |
+
Interoperability
|
148 |
+
================
|
149 |
+
|
150 |
+
To interoperate with Chelsio iWARP devices:
|
151 |
+
|
152 |
+
1. Load the ixl driver with parameter mpa_version set to 1. Add the line:
|
153 |
+
hw.iw_ixl.mpa_version=1
|
154 |
+
to /boot/loader.conf
|
155 |
+
|
156 |
+
2. Load Chelsio T4/T5 RDMA driver (iw_cxgb4) with parameter dack_mode set to 0.
|
157 |
+
|
158 |
+
|
159 |
+
Known Issues
|
160 |
+
============
|
161 |
+
|
162 |
+
- Loopback is not supported.
|
163 |
+
- MTU changes are not supported.
|
164 |
+
- IPv6 is not supported.
|
165 |
+
- MW memory mode is not supported.
|
166 |
+
- MR memory mode supports only single buffer.
|
167 |
+
- The function ib_cq_resize is not supported.
|
168 |
+
- The max number of registered cq, qp, pd or mr reported by the device may
|
169 |
+
differ from the actual number of registrations achievable.
|
170 |
+
- A kernel crash may occur when trying to run krping without ensuring that the
|
171 |
+
two machines are able to ping each other.
|
172 |
+
- A kernel crash may occur when trying to load the iw_ixl driver when
|
173 |
+
hw.ixl.enable_iwarp=0 (fixed with if_ixl 1.7.13).
|
174 |
+
- A kernel crash may occur when loading the iw_ixl driver on a card that is
|
175 |
+
supported by if_ixl driver, but does not have iWARP capability (fixed with
|
176 |
+
if_ixl 1.7.13).
|
177 |
+
- Krping with fastreg memory mode will not work unless some changes are made
|
178 |
+
to krping. To work around the issue, modify the krping_rdma_rkey function
|
179 |
+
such that, in the case of FASTREG memory mode, the ib_post_send function
|
180 |
+
with &cd->invalidate_wr parameter is not called during the first run of
|
181 |
+
the function.
|
182 |
+
|
183 |
+
|
184 |
+
Support
|
185 |
+
=======
|
186 |
+
For general information, go to the Intel support website at:
|
187 |
+
http://www.intel.com/support/
|
188 |
+
|
189 |
+
If an issue is identified with the released source code on a supported kernel
|
190 |
+
with a supported adapter, email the specific information related to the issue
|
191 |
+
to e1000-rdma@lists.sourceforge.net
|
192 |
+
|
193 |
+
|
194 |
+
Copyright(c) 2017-2019 Intel Corporation.
|
195 |
+
|
196 |
+
|
197 |
+
Trademarks
|
198 |
+
==========
|
199 |
+
Intel is a trademark or registered trademark of Intel Corporation or its
|
200 |
+
subsidiaries in the United States and/or other countries.
|
201 |
+
|
202 |
+
* Other names and brands may be claimed as the property of others.
|
203 |
+
|
204 |
+
|
same_training_and_inference_test.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Tests for when the training and inference graphs are the same."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import tempfile
|
19 |
+
|
20 |
+
import tensorflow as tf
|
21 |
+
|
22 |
+
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import same_training_and_inference as svd
|
23 |
+
from tensorflow_model_optimization.python.core.keras.compat import keras
|
24 |
+
from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
|
25 |
+
|
26 |
+
|
27 |
+
def _build_model():
|
28 |
+
i = keras.layers.Input(shape=(28, 28), name='input')
|
29 |
+
x = keras.layers.Reshape((28, 28, 1))(i)
|
30 |
+
x = keras.layers.Conv2D(
|
31 |
+
20, 5, activation='relu', padding='valid', name='conv1'
|
32 |
+
)(x)
|
33 |
+
x = keras.layers.MaxPool2D(2, 2)(x)
|
34 |
+
x = keras.layers.Conv2D(
|
35 |
+
50, 5, activation='relu', padding='valid', name='conv2'
|
36 |
+
)(x)
|
37 |
+
x = keras.layers.MaxPool2D(2, 2)(x)
|
38 |
+
x = keras.layers.Flatten()(x)
|
39 |
+
x = keras.layers.Dense(500, activation='relu', name='fc1')(x)
|
40 |
+
output = keras.layers.Dense(10, name='fc2')(x)
|
41 |
+
|
42 |
+
model = keras.Model(inputs=[i], outputs=[output])
|
43 |
+
return model
|
44 |
+
|
45 |
+
|
46 |
+
def _get_dataset():
|
47 |
+
mnist = keras.datasets.mnist
|
48 |
+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
49 |
+
x_train, x_test = x_train / 255.0, x_test / 255.0
|
50 |
+
# Use subset of 60000 examples to keep unit test speed fast.
|
51 |
+
x_train = x_train[0:1000]
|
52 |
+
y_train = y_train[0:1000]
|
53 |
+
return (x_train, y_train), (x_test, y_test)
|
54 |
+
|
55 |
+
|
56 |
+
def _train_model(model):
|
57 |
+
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
58 |
+
|
59 |
+
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
|
60 |
+
|
61 |
+
(x_train, y_train), _ = _get_dataset()
|
62 |
+
|
63 |
+
model.fit(x_train, y_train, epochs=1)
|
64 |
+
|
65 |
+
|
66 |
+
def _save_as_saved_model(model):
|
67 |
+
saved_model_dir = tempfile.mkdtemp()
|
68 |
+
model.save(saved_model_dir)
|
69 |
+
return saved_model_dir
|
70 |
+
|
71 |
+
|
72 |
+
# TODO(tfmot): reuse existing test utilities.
|
73 |
+
def _convert_to_tflite(saved_model_dir):
|
74 |
+
_, tflite_file = tempfile.mkstemp()
|
75 |
+
|
76 |
+
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
77 |
+
tflite_model = converter.convert()
|
78 |
+
|
79 |
+
with open(tflite_file, 'wb') as f:
|
80 |
+
f.write(tflite_model)
|
81 |
+
|
82 |
+
return tflite_file
|
83 |
+
|
84 |
+
|
85 |
+
def _get_directory_size_in_bytes(directory):
|
86 |
+
total = 0
|
87 |
+
try:
|
88 |
+
for entry in os.scandir(directory):
|
89 |
+
if entry.is_file():
|
90 |
+
# if it's a file, use stat() function
|
91 |
+
total += entry.stat().st_size
|
92 |
+
elif entry.is_dir():
|
93 |
+
# if it's a directory, recursively call this function
|
94 |
+
total += _get_directory_size_in_bytes(entry.path)
|
95 |
+
except NotADirectoryError:
|
96 |
+
# if `directory` isn't a directory, get the file size then
|
97 |
+
return os.path.getsize(directory)
|
98 |
+
except PermissionError:
|
99 |
+
# if for whatever reason we can't open the folder, return 0
|
100 |
+
return 0
|
101 |
+
return total
|
102 |
+
|
103 |
+
|
104 |
+
class FunctionalTest(tf.test.TestCase):
|
105 |
+
|
106 |
+
# TODO(tfmot): can simplify to single layer test that checks exact
|
107 |
+
# dimensions of weights.
|
108 |
+
def testSVD_ReducesSavedModelSize(self):
|
109 |
+
model = _build_model()
|
110 |
+
|
111 |
+
original_saved_model_dir = _save_as_saved_model(model)
|
112 |
+
|
113 |
+
compressed_model = svd.SVD(rank=16).compress_model(model)
|
114 |
+
|
115 |
+
saved_model_dir = _save_as_saved_model(compressed_model)
|
116 |
+
|
117 |
+
original_size = _get_directory_size_in_bytes(original_saved_model_dir)
|
118 |
+
compressed_size = _get_directory_size_in_bytes(saved_model_dir)
|
119 |
+
|
120 |
+
self.assertLess(compressed_size, original_size / 3)
|
121 |
+
|
122 |
+
def testSVD_HasReasonableAccuracy_TF(self):
|
123 |
+
model = _build_model()
|
124 |
+
|
125 |
+
compressed_model = svd.SVD(rank=16).compress_model(model)
|
126 |
+
|
127 |
+
_train_model(compressed_model)
|
128 |
+
|
129 |
+
_, (x_test, y_test) = _get_dataset()
|
130 |
+
|
131 |
+
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
132 |
+
|
133 |
+
compressed_model.compile(
|
134 |
+
optimizer='adam', loss=loss_fn, metrics=['accuracy'])
|
135 |
+
|
136 |
+
results = compressed_model.evaluate(x_test, y_test)
|
137 |
+
|
138 |
+
self.assertGreater(results[1], 0.60)
|
139 |
+
|
140 |
+
def testSVD_ReducesTFLiteModelSize(self):
|
141 |
+
model = _build_model()
|
142 |
+
|
143 |
+
original_saved_model_dir = _save_as_saved_model(model)
|
144 |
+
original_tflite_file = _convert_to_tflite(original_saved_model_dir)
|
145 |
+
|
146 |
+
compressed_model = svd.SVD(rank=16).compress_model(model)
|
147 |
+
|
148 |
+
saved_model_dir = _save_as_saved_model(compressed_model)
|
149 |
+
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
|
150 |
+
|
151 |
+
original_size = os.path.getsize(original_tflite_file)
|
152 |
+
compressed_size = os.path.getsize(compressed_tflite_file)
|
153 |
+
|
154 |
+
self.assertLess(compressed_size, original_size / 6)
|
155 |
+
|
156 |
+
def testSVD_HasReasonableAccuracy_TFLite(self):
|
157 |
+
model = _build_model()
|
158 |
+
|
159 |
+
compressed_model = svd.SVD(rank=16).compress_model(model)
|
160 |
+
|
161 |
+
_train_model(compressed_model)
|
162 |
+
|
163 |
+
saved_model_dir = _save_as_saved_model(compressed_model)
|
164 |
+
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
|
165 |
+
|
166 |
+
accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
|
167 |
+
|
168 |
+
self.assertGreater(accuracy, 0.60)
|
169 |
+
|
170 |
+
# TODO(tfmot): can simplify to single layer test.
|
171 |
+
def testSVD_BreaksDownLayerWeights(self):
|
172 |
+
model = _build_model()
|
173 |
+
|
174 |
+
first_conv_layer = model.layers[2]
|
175 |
+
self.assertLen(first_conv_layer.weights, 2)
|
176 |
+
|
177 |
+
compressed_model = svd.SVD(rank=16).compress_model(model)
|
178 |
+
|
179 |
+
first_conv_layer = compressed_model.layers[2]
|
180 |
+
|
181 |
+
self.assertLen(first_conv_layer.weights, 3)
|
182 |
+
|
183 |
+
# TODO(tfmot): can simplify to single layer test.
|
184 |
+
def testSVD_PreservesPretrainedWeights(self):
|
185 |
+
i = keras.layers.Input(shape=(2), name='input')
|
186 |
+
output = keras.layers.Dense(3, name='fc1')(i)
|
187 |
+
model = keras.Model(inputs=[i], outputs=[output])
|
188 |
+
|
189 |
+
dense_layer_weights = model.layers[1].get_weights()
|
190 |
+
|
191 |
+
algorithm = svd.SVD(rank=1)
|
192 |
+
compressed_model = algorithm.compress_model(model)
|
193 |
+
|
194 |
+
dense_layer_compressed_weights = compressed_model.layers[1].get_weights()
|
195 |
+
|
196 |
+
# kernel
|
197 |
+
algorithm.weight_reprs = []
|
198 |
+
algorithm.init_training_weights(dense_layer_weights[0])
|
199 |
+
w1_repr, w2_repr = algorithm.weight_reprs
|
200 |
+
assert (w1_repr.kwargs['initializer'](None) == \
|
201 |
+
dense_layer_compressed_weights[0]).numpy().all()
|
202 |
+
assert (w2_repr.kwargs['initializer'](None) == \
|
203 |
+
dense_layer_compressed_weights[1]).numpy().all()
|
204 |
+
|
205 |
+
# bias
|
206 |
+
assert (dense_layer_weights[1] == dense_layer_compressed_weights[2]).all()
|
207 |
+
|
208 |
+
|
209 |
+
if __name__ == '__main__':
|
210 |
+
tf.test.main()
|