Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Utility functions for blocks.""" | |
from __future__ import division
from __future__ import unicode_literals

import math
import numbers

import numpy as np
import six
import tensorflow as tf
class RsqrtInitializer(object):
  """Gaussian initializer with standard deviation 1/sqrt(n).

  Here n is the product of the sizes of the selected dimensions of the
  requested shape.

  Note that tf.truncated_normal is used internally. Therefore any random sample
  outside two-sigma will be discarded and re-sampled.
  """

  def __init__(self, dims=(0,), **kwargs):
    """Creates an initializer.

    Args:
      dims: Dimension(s) index to compute standard deviation:
        1.0 / sqrt(product(shape[dims]))
        Either a single integer (Python or NumPy integer scalar) or an
        iterable of integer indices into the shape.
      **kwargs: Extra keyword arguments to pass to tf.truncated_normal.
    """
    # numbers.Integral covers everything six.integer_types does and also
    # NumPy integer scalars (e.g. np.int64), which are not `int` subclasses on
    # Python 3. Without this a NumPy scalar would be stored unwrapped and
    # __call__ would fail when iterating over it.
    if isinstance(dims, numbers.Integral):
      self._dims = [dims]
    else:
      self._dims = dims
    self._kwargs = kwargs

  def __call__(self, shape, dtype):
    """Samples a tensor of the given shape and dtype.

    Args:
      shape: Indexable shape of the tensor to create; it is indexed with each
        entry of `dims` to compute the standard deviation.
      dtype: Type of the returned values, forwarded to tf.truncated_normal.

    Returns:
      The result of tf.truncated_normal with
      stddev = 1.0 / sqrt(product(shape[dims])).
    """
    stddev = 1.0 / np.sqrt(np.prod([shape[x] for x in self._dims]))
    return tf.truncated_normal(
        shape=shape, dtype=dtype, stddev=stddev, **self._kwargs)
class RectifierInitializer(object):
  """Gaussian initializer with standard deviation sqrt(scale / n).

  Here n is the product of the sizes of the selected dimensions of the
  requested shape; with the default scale of 2.0 this is sqrt(2 / n).

  Note that tf.random_normal is used internally to ensure the expected weight
  distribution. This is intended to be used with ReLU activations, especially
  in ResNets.

  For details please refer to:
  Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet
  Classification
  """

  def __init__(self, dims=(0,), scale=2.0, **kwargs):
    """Creates an initializer.

    Args:
      dims: Dimension(s) index to compute standard deviation:
        sqrt(scale / product(shape[dims]))
        Either a single integer (Python or NumPy integer scalar) or an
        iterable of integer indices into the shape.
      scale: A constant scaling for the initialization used as
        sqrt(scale / product(shape[dims])).
      **kwargs: Extra keyword arguments to pass to tf.random_normal.
    """
    # numbers.Integral covers everything six.integer_types does and also
    # NumPy integer scalars (e.g. np.int64), which are not `int` subclasses on
    # Python 3. Without this a NumPy scalar would be stored unwrapped and
    # __call__ would fail when iterating over it.
    if isinstance(dims, numbers.Integral):
      self._dims = [dims]
    else:
      self._dims = dims
    self._kwargs = kwargs
    self._scale = scale

  def __call__(self, shape, dtype):
    """Samples a tensor of the given shape and dtype.

    Args:
      shape: Indexable shape of the tensor to create; it is indexed with each
        entry of `dims` to compute the standard deviation.
      dtype: Type of the returned values, forwarded to tf.random_normal.

    Returns:
      The result of tf.random_normal with
      stddev = sqrt(scale / product(shape[dims])).
    """
    stddev = np.sqrt(self._scale / np.prod([shape[x] for x in self._dims]))
    return tf.random_normal(
        shape=shape, dtype=dtype, stddev=stddev, **self._kwargs)
class GaussianInitializer(object):
  """Initializer drawing from a Gaussian with a fixed standard deviation.

  Sampling is delegated to tf.truncated_normal, so values falling outside
  two-sigma are discarded and drawn again.
  """

  def __init__(self, stddev=1.0):
    """Creates an initializer.

    Args:
      stddev: Standard deviation handed to tf.truncated_normal on every call.
    """
    self._stddev = stddev

  def __call__(self, shape, dtype):
    """Samples a tensor of the given shape and dtype.

    Args:
      shape: Shape of the tensor to create.
      dtype: Type of the returned values.

    Returns:
      The result of tf.truncated_normal for the requested shape and dtype,
      using the configured standard deviation.
    """
    sample = tf.truncated_normal(
        shape=shape, dtype=dtype, stddev=self._stddev)
    return sample