# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""
import warnings
import tensorflow as tf
class WeightNormalization(tf.keras.layers.Wrapper):
"""Layer wrapper to decouple magnitude and direction of the layer's weights.
This wrapper reparameterizes a layer by decoupling the weight's
magnitude and direction. This speeds up convergence by improving the
conditioning of the optimization problem. It has an optional data-dependent
initialization scheme, in which initial values of weights are set as functions
of the first minibatch of data. Both the weight normalization and data-
dependent initialization are described in [Salimans and Kingma (2016)][1].
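Concretely, the kernel is reparameterized as `w = g * v / ||v||`, where `v` keeps
the shape of the original kernel and `g` holds one scalar scale per output filter.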
#### Example
```python
net = WeightNormalization(tf.keras.layers.Conv2D(2, 2, activation='relu'),
                          input_shape=(32, 32, 3), data_init=True)(x)
net = WeightNormalization(tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
                          data_init=True)(net)
net = WeightNormalization(tf.keras.layers.Dense(120, activation='relu'),
                          data_init=True)(net)
net = WeightNormalization(tf.keras.layers.Dense(num_classes),
                          data_init=True)(net)
```
#### References
[1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
Reparameterization to Accelerate Training of Deep Neural Networks. In
_30th Conference on Neural Information Processing Systems_, 2016.
https://arxiv.org/abs/1602.07868
"""
def __init__(self, layer, data_init=True, **kwargs):
"""Initialize WeightNorm wrapper.
Args:
layer: A `tf.keras.layers.Layer` instance. Supported layer types are
`Dense`, `Conv1D`, `Conv2D`, `Conv2DTranspose`, and `GroupConv1D`. Layers
with multiple inputs are not supported.
data_init: `bool`, if `True` use data dependent variable initialization.
**kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`.
Raises:
ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
"""
if not isinstance(layer, tf.keras.layers.Layer):
raise ValueError(
"Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` "
"instance. You passed: {input}".format(input=layer)
)
layer_type = type(layer).__name__
if layer_type not in [
"Dense",
"Conv2D",
"Conv2DTranspose",
"Conv1D",
"GroupConv1D",
]:
warnings.warn(
"`WeightNorm` is tested only for `Dense`, `Conv1D`, `Conv2D`, `Conv2DTranspose`, "
"and `GroupConv1D` layers. You passed a layer of type `{}`.".format(
layer_type
)
)
super().__init__(layer, **kwargs)
self.data_init = data_init
self._track_trackable(layer, name="layer")
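# `Conv2DTranspose` stores its kernel as [spatial..., out_filters, in_channels],
# so its per-filter axis is -2; every other supported layer keeps the output
# filters on the last axis.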
self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1
def _compute_weights(self):
"""Generate weights with normalization."""
# Determine the axis along which to expand `g` so that `g` broadcasts to
# the shape of `v`.
new_axis = -self.filter_axis - 3
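# For example, a `Conv2D` kernel `v` has shape [k_h, k_w, c_in, c_out] and
# `filter_axis` is -1, so `new_axis` is -2 and `g` (shape [c_out]) is expanded
# to shape [1, c_out], which broadcasts across the normalized kernel.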
self.layer.kernel = tf.nn.l2_normalize(
self.v, axis=self.kernel_norm_axes
) * tf.expand_dims(self.g, new_axis)
def _init_norm(self):
"""Set the norm of the weight vector."""
kernel_norm = tf.sqrt(
tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
)
self.g.assign(kernel_norm)
def _data_dep_init(self, inputs):
"""Data dependent initialization."""
# Normalize kernel first so that calling the layer calculates
# `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
self._compute_weights()
activation = self.layer.activation
self.layer.activation = None
use_bias = self.layer.bias is not None
if use_bias:
bias = self.layer.bias
self.layer.bias = tf.zeros_like(bias)
# Since the bias is initialized as zero and the activation has been removed,
# calling the initialized layer (with normalized kernel) yields the correct
# computation ((5) in Salimans and Kingma (2016)).
x_init = self.layer(inputs)
norm_axes_out = list(range(x_init.shape.rank - 1))
m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
scale_init = 1.0 / tf.sqrt(v_init + 1e-10)
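# As in Salimans and Kingma (2016), scale `g` by 1/std and shift the bias by
# -mean/std so the pre-activations of the first minibatch start with zero mean
# and unit variance.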
self.g.assign(self.g * scale_init)
if use_bias:
self.layer.bias = bias
self.layer.bias.assign(-m_init * scale_init)
self.layer.activation = activation
def build(self, input_shape=None):
"""Build `Layer`.
Args:
input_shape: The shape of the input to `self.layer`.
Raises:
ValueError: If the wrapped layer does not contain a `kernel` weight.
"""
if not self.layer.built:
self.layer.build(input_shape)
if not hasattr(self.layer, "kernel"):
raise ValueError(
"`WeightNorm` must wrap a layer that"
" contains a `kernel` for weights"
)
self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
self.kernel_norm_axes.pop(self.filter_axis)
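# The norm in `_compute_weights` is therefore taken over every axis except the
# per-filter axis, giving one scalar norm (and one `g` entry) per output filter.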
self.v = self.layer.kernel
# to avoid a duplicate `kernel` variable after `build` is called
self.layer.kernel = None
self.g = self.add_weight(
name="g",
shape=(int(self.v.shape[self.filter_axis]),),
initializer="ones",
dtype=self.v.dtype,
trainable=True,
)
self.initialized = self.add_weight(
name="initialized", dtype=tf.bool, trainable=False
)
self.initialized.assign(False)
super().build()
def call(self, inputs):
"""Call `Layer`."""
if not self.initialized:
if self.data_init:
self._data_dep_init(inputs)
else:
# initialize `g` as the norm of the initialized kernel
self._init_norm()
self.initialized.assign(True)
self._compute_weights()
output = self.layer(inputs)
return output
def compute_output_shape(self, input_shape):
return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
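

if __name__ == "__main__":
    # Minimal usage sketch (assumes TF2 eager execution; the layer sizes and
    # input shapes below are arbitrary illustrations, not values required by
    # this module): wrap a Conv1D layer, run a dummy batch through it to
    # trigger the data-dependent initialization, and inspect the output shape.
    conv = tf.keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")
    wn_conv = WeightNormalization(conv, data_init=True)
    dummy_inputs = tf.random.normal([2, 100, 80])  # [batch, time, channels]
    outputs = wn_conv(dummy_inputs)
    print(outputs.shape)  # (2, 100, 64)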