# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""

import warnings

import tensorflow as tf


class WeightNormalization(tf.keras.layers.Wrapper):
    """Layer wrapper to decouple magnitude and direction of the layer's weights.

    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem. It has an optional data-dependent
    initialization scheme, in which initial values of weights are set as
    functions of the first minibatch of data. Both the weight normalization and
    data-dependent initialization are described in [Salimans and Kingma (2016)][1].
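
    The reparameterization, following the paper, is `w = g * v / ||v||`, where
    `v` is a kernel-shaped direction variable, `||v||` is its Euclidean norm
    over every axis except the output-filter axis, and `g` is a learned
    per-filter magnitude.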

    #### Example

    ```python
    net = WeightNormalization(tf.keras.layers.Conv2D(2, 2, activation='relu'),
                              input_shape=(32, 32, 3), data_init=True)(x)
    net = WeightNormalization(tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
                              data_init=True)(net)
    net = WeightNormalization(tf.keras.layers.Dense(120, activation='relu'),
                              data_init=True)(net)
    net = WeightNormalization(tf.keras.layers.Dense(num_classes),
                              data_init=True)(net)
    ```

    #### References

    [1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
         Reparameterization to Accelerate Training of Deep Neural Networks. In
         _30th Conference on Neural Information Processing Systems_, 2016.
         https://arxiv.org/abs/1602.07868
    """

    def __init__(self, layer, data_init=True, **kwargs):
        """Initialize WeightNormalization wrapper.

        Args:
          layer: A `tf.keras.layers.Layer` instance. Supported layer types are
            `Dense`, `Conv1D`, `GroupConv1D`, `Conv2D`, and `Conv2DTranspose`.
            Layers with multiple inputs are not supported.
          data_init: `bool`, if `True` use data dependent variable initialization.
          **kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`.

        Raises:
          ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
        """
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                "Please initialize `WeightNormalization` layer with a "
                "`tf.keras.layers.Layer` instance. You passed: {input}".format(
                    input=layer
                )
            )
        layer_type = type(layer).__name__
        if layer_type not in [
            "Dense",
            "Conv1D",
            "GroupConv1D",
            "Conv2D",
            "Conv2DTranspose",
        ]:
            warnings.warn(
                "`WeightNormalization` is tested only for `Dense`, `Conv1D`, "
                "`GroupConv1D`, `Conv2D`, and `Conv2DTranspose` layers. "
                "You passed a layer of type `{}`".format(layer_type)
            )
        super().__init__(layer, **kwargs)
        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        # `Conv2DTranspose` stores its kernel as (h, w, out_ch, in_ch), so its
        # output-filter axis is -2; every other supported layer keeps the
        # output filters on the last axis.
        self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1

    def _compute_weights(self):
        """Generate weights with normalization."""
        # Determine the axis along which to expand `g` so that `g` broadcasts
        # to the shape of `v`.
        new_axis = -self.filter_axis - 3
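        # Worked example: a `Conv2D` kernel has shape (h, w, in_ch, out_ch), so
        # `filter_axis` is -1, `new_axis` is -2, and `g` of shape (out_ch,) is
        # expanded to (1, out_ch), which broadcasts over the normalized kernel.
        # For `Conv2DTranspose`, the kernel is (h, w, out_ch, in_ch), so
        # `filter_axis` is -2, `new_axis` is -1, and `g` becomes (out_ch, 1).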
        self.layer.kernel = tf.nn.l2_normalize(
            self.v, axis=self.kernel_norm_axes
        ) * tf.expand_dims(self.g, new_axis)

    def _init_norm(self):
        """Initialize `g` with the norm of the weight vector `v`."""
        # With `g == ||v||`, the effective kernel `g * v / ||v||` equals the
        # original kernel, so this initialization leaves the wrapped layer's
        # behavior unchanged.
        kernel_norm = tf.sqrt(
            tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
        )
        self.g.assign(kernel_norm)

    def _data_dep_init(self, inputs):
        """Data dependent initialization."""
        # Normalize the kernel first so that calling the layer computes
        # `tf.dot(v, x) / tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
        self._compute_weights()

        activation = self.layer.activation
        self.layer.activation = None

        use_bias = self.layer.bias is not None
        if use_bias:
            bias = self.layer.bias
            self.layer.bias = tf.zeros_like(bias)

        # Since the bias is zeroed and the activation set to `None`, calling
        # the initialized layer (with normalized kernel) yields the correct
        # computation ((5) in Salimans and Kingma (2016)).
        x_init = self.layer(inputs)
        norm_axes_out = list(range(x_init.shape.rank - 1))
        m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
        scale_init = 1.0 / tf.sqrt(v_init + 1e-10)
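        # Per eq. (6) in Salimans and Kingma (2016): rescale `g` by
        # 1 / sqrt(var) and shift the bias by -mean / sqrt(var) so the initial
        # pre-activations have zero mean and unit variance on this minibatch.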
        self.g.assign(self.g * scale_init)
        if use_bias:
            self.layer.bias = bias
            self.layer.bias.assign(-m_init * scale_init)
        self.layer.activation = activation

    def build(self, input_shape=None):
        """Build `Layer`.

        Args:
          input_shape: The shape of the input to `self.layer`.

        Raises:
          ValueError: If `self.layer` does not contain a `kernel` of weights.
        """
        if not self.layer.built:
            self.layer.build(input_shape)

            if not hasattr(self.layer, "kernel"):
                raise ValueError(
                    "`WeightNormalization` must wrap a layer that "
                    "contains a `kernel` for weights"
                )

            self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
            self.kernel_norm_axes.pop(self.filter_axis)
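            # e.g. a `Conv2D` kernel of shape (h, w, in_ch, out_ch) with
            # `filter_axis == -1` gives `kernel_norm_axes == [0, 1, 2]`, so the
            # norm spans every axis except the output-filter axis.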

            self.v = self.layer.kernel

            # to avoid a duplicate `kernel` variable after `build` is called
            self.layer.kernel = None
            self.g = self.add_weight(
                name="g",
                shape=(int(self.v.shape[self.filter_axis]),),
                initializer="ones",
                dtype=self.v.dtype,
                trainable=True,
            )
            self.initialized = self.add_weight(
                name="initialized", dtype=tf.bool, trainable=False
            )
            self.initialized.assign(False)

        super().build()

    def call(self, inputs):
        """Call `Layer`."""
        if not self.initialized:
            if self.data_init:
                self._data_dep_init(inputs)
            else:
                # initialize `g` as the norm of the initialized kernel
                self._init_norm()
            self.initialized.assign(True)

        self._compute_weights()
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
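

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module): wrap a
    # Keras `Conv1D` and run one batch so the data-dependent initialization
    # path in `call` fires. The layer sizes and input shape below are
    # arbitrary assumptions chosen for the demo.
    wrapped = WeightNormalization(
        tf.keras.layers.Conv1D(64, 3, padding="same", activation="relu"),
        data_init=True,
    )
    x = tf.random.normal([8, 100, 32])  # (batch, time, channels)
    y = wrapped(x)  # the first call triggers `_data_dep_init`
    assert y.shape == (8, 100, 64)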