# Source: Comparative-Analysis-of-Speech-Synthesis-Models/TensorFlowTTS/tensorflow_tts/utils/weight_norm.py
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""
import warnings | |
import tensorflow as tf | |
class WeightNormalization(tf.keras.layers.Wrapper):
    """Layer wrapper to decouple magnitude and direction of the layer's weights.

    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem. It has an optional data-dependent
    initialization scheme, in which initial values of weights are set as
    functions of the first minibatch of data. Both the weight normalization and
    data-dependent initialization are described in
    [Salimans and Kingma (2016)][1].

    #### Example

    ```python
    net = WeightNormalization(
        tf.keras.layers.Conv2D(2, 2, activation='relu'),
        input_shape=(32, 32, 3), data_init=True)(x)
    net = WeightNormalization(
        tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
        data_init=True)(net)
    net = WeightNormalization(
        tf.keras.layers.Dense(120, activation='relu'),
        data_init=True)(net)
    net = WeightNormalization(
        tf.keras.layers.Dense(num_classes),
        data_init=True)(net)
    ```

    #### References

    [1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
         Reparameterization to Accelerate Training of Deep Neural Networks. In
         _30th Conference on Neural Information Processing Systems_, 2016.
         https://arxiv.org/abs/1602.07868
    """

    # Class names of the wrapped layers that do not trigger a warning. The
    # warning text is generated from this tuple so the two cannot drift apart.
    _TESTED_LAYER_TYPES = (
        "Dense",
        "Conv2D",
        "Conv2DTranspose",
        "Conv1D",
        "GroupConv1D",
    )

    def __init__(self, layer, data_init=True, **kwargs):
        """Initialize WeightNorm wrapper.

        Args:
            layer: A `tf.keras.layers.Layer` instance. Layer types accepted
                without a warning are `Dense`, `Conv2D`, `Conv2DTranspose`,
                `Conv1D` and `GroupConv1D` (matched by class name). Layers
                with multiple inputs are not supported.
            data_init: `bool`, if `True` use data dependent variable
                initialization.
            **kwargs: Additional keyword args passed to
                `tf.keras.layers.Wrapper`.

        Raises:
            ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
        """
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                "Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` "
                "instance. You passed: {input}".format(input=layer)
            )

        layer_type = type(layer).__name__
        if layer_type not in self._TESTED_LAYER_TYPES:
            tested = ", ".join(
                "`{}`".format(name) for name in self._TESTED_LAYER_TYPES
            )
            warnings.warn(
                "`WeightNorm` is tested only for {} layers. "
                "You passed a layer of type `{}`".format(tested, layer_type)
            )

        super().__init__(layer, **kwargs)
        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        # `Conv2DTranspose` is the one layer type whose filter axis is taken
        # to be the second-to-last kernel axis; every other type uses the last.
        self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1

    def _compute_weights(self):
        """Generate weights with normalization."""
        # Determine the axis along which to expand `g` so that `g` broadcasts
        # to the shape of `v`: -2 when the filter axis is -1, and -1 when the
        # filter axis is -2.
        new_axis = -self.filter_axis - 3

        self.layer.kernel = tf.nn.l2_normalize(
            self.v, axis=self.kernel_norm_axes
        ) * tf.expand_dims(self.g, new_axis)

    def _init_norm(self):
        """Set the norm of the weight vector."""
        kernel_norm = tf.sqrt(
            tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
        )
        self.g.assign(kernel_norm)

    def _data_dep_init(self, inputs):
        """Data dependent initialization.

        Runs the wrapped layer once on `inputs` with its activation removed
        and its bias replaced by zeros, then rescales `g` (and sets the bias,
        if the layer has one) from the moments of that output.

        Args:
            inputs: The input passed to the wrapped layer.
        """
        # Normalize kernel first so that calling the layer calculates
        # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
        self._compute_weights()

        activation = self.layer.activation
        bias = self.layer.bias
        use_bias = bias is not None

        # The wrapped layer is mutated temporarily; restore it in `finally`
        # so a failing forward pass cannot leave it without its activation
        # or with a constant tensor in place of its bias variable.
        try:
            self.layer.activation = None
            if use_bias:
                self.layer.bias = tf.zeros_like(bias)

            # Since the bias is zero and the activation is disabled, calling
            # the layer (with normalized kernel) yields the correct
            # computation ((5) in Salimans and Kingma (2016)).
            x_init = self.layer(inputs)
        finally:
            self.layer.activation = activation
            if use_bias:
                self.layer.bias = bias

        # Moments are taken over every axis except the last one.
        norm_axes_out = list(range(x_init.shape.rank - 1))
        m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
        # Small constant keeps the division finite when the variance is zero.
        scale_init = 1.0 / tf.sqrt(v_init + 1e-10)

        self.g.assign(self.g * scale_init)
        if use_bias:
            self.layer.bias.assign(-m_init * scale_init)

    def build(self, input_shape=None):
        """Build `Layer`.

        Args:
            input_shape: The shape of the input to `self.layer`.

        Raises:
            ValueError: If `Layer` does not contain a `kernel` of weights
        """
        # Only building the wrapped layer is conditional; the wrapper's own
        # variables below are needed by `call` in either case.
        if not self.layer.built:
            self.layer.build(input_shape)

        if not hasattr(self.layer, "kernel"):
            raise ValueError(
                "`WeightNorm` must wrap a layer that"
                " contains a `kernel` for weights"
            )

        # Normalize over every kernel axis except the filter axis.
        self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
        self.kernel_norm_axes.pop(self.filter_axis)

        self.v = self.layer.kernel

        # to avoid a duplicate `kernel` variable after `build` is called
        self.layer.kernel = None
        self.g = self.add_weight(
            name="g",
            shape=(int(self.v.shape[self.filter_axis]),),
            initializer="ones",
            dtype=self.v.dtype,
            trainable=True,
        )
        # Flag recording whether the one-off initialization in `call` has run.
        self.initialized = self.add_weight(
            name="initialized", dtype=tf.bool, trainable=False
        )
        self.initialized.assign(False)

        super().build()

    def call(self, inputs):
        """Call `Layer`.

        On the first call, initializes `g` (data-dependently when
        `data_init` is true, otherwise from the kernel norm); on every call,
        recomputes the normalized kernel and applies the wrapped layer.

        Args:
            inputs: The input passed to the wrapped layer.

        Returns:
            The output of the wrapped layer.
        """
        if not self.initialized:
            if self.data_init:
                self._data_dep_init(inputs)
            else:
                # initialize `g` as the norm of the initialized kernel
                self._init_norm()

            self.initialized.assign(True)

        self._compute_weights()
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape as a `tf.TensorShape`.

        Args:
            input_shape: The shape of the input to `self.layer`.

        Returns:
            A `tf.TensorShape` built from the wrapped layer's computed shape.
        """
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())