NCTCMumbai's picture
Upload 2583 files
97b6013 verified
raw
history blame
No virus
4.19 kB
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizers for use in unrolled optimization.
These optimizers contain a compute_updates function and their own ability to
keep track of internal state.
These functions can be used with a tf.while_loop to perform multiple training
steps per sess.run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import tensorflow as tf
import sonnet as snt
from learning_unsupervised_learning import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
class UnrollableOptimizer(snt.AbstractModule):
    """Base interface for optimizers usable inside unrolled computation.

    Subclasses provide `compute_updates`, `get_state` and `assign_state`;
    `apply_gradients` is derived from those three methods.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base module, then call the module once."""
        super(UnrollableOptimizer, self).__init__(*args, **kwargs)
        # Invoke the module immediately; `_build` below does nothing.
        # NOTE(review): presumably this marks the Sonnet module as connected.
        self()

    @abc.abstractmethod
    def compute_updates(self, xs, gs, state=None):
        """Compute the next-step values for a list of variables.

        Args:
          xs: list of tensors
            The "variables" to update. Their order must match the order that
            was used when `get_state` was originally called.
          gs: list of tensors
            Gradients of `xs` with respect to some loss.
          state: Any
            Optimizer specific state holding accumulators such as momentum
            terms.

        Returns:
          A `(new_xs, new_state)` pair, as unpacked by `apply_gradients`.
        """
        raise NotImplementedError()

    def _build(self):
        """Intentionally empty: the module has nothing to build."""
        pass

    @abc.abstractmethod
    def get_state(self, var_list):
        """Return the optimizer state associated with a list of tf.Variables.

        The state is commonly a NamedTuple holding some mapping between the
        variables and their per-variable state, for example a moving momentum
        value tracked by the optimizer.

        Args:
          var_list: list of tf.Variable

        Returns:
          state: Any
            Optimizer specific state
        """
        raise NotImplementedError()

    def assign_state(self, state):
        """Write `state` into the optimizer's internal variables.

        Args:
          state: Any

        Returns:
          op: tf.Operation
            The operation that performs the assignment.
        """
        raise NotImplementedError()

    def apply_gradients(self, grad_vars):
        """Build one op that applies a single optimizer step.

        Args:
          grad_vars: iterable of (gradient, variable) pairs.

        Returns:
          A grouped op named "apply_gradients" that assigns both the new
          optimizer state and the new variable values.
        """
        grads, params = zip(*grad_vars)
        updated_params, updated_state = self.compute_updates(
            params, grads, self.get_state(params))
        state_assign = self.assign_state(updated_state)
        param_assign = utils.assign_variables(params, updated_params)
        return tf.group(state_assign, param_assign, name="apply_gradients")
class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
    """Stateless gradient-descent optimizer with optional per-variable rates.

    A variable `x` with gradient `g` and rate `lr` is updated to
    `x * (1 - lr) - g * lr`. Variables whose gradient is None are passed
    through unchanged.
    """

    def __init__(self,
                 learning_rate,
                 name="UnrollableGradientDescentRollingOptimizer"):
        """Create the optimizer.

        Args:
          learning_rate: fallback rate used for any variable whose
            per-variable rate is None.
          name: module name forwarded to the base constructor.
        """
        # Stored ahead of the base constructor, which calls the module
        # immediately.
        self.learning_rate = learning_rate
        super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)

    def compute_updates(self, xs, gs, learning_rates, state):
        """Compute the next value of every variable.

        Args:
          xs: list of tensors to update.
          gs: list of gradients aligned with `xs`; an entry may be None.
          learning_rates: list of per-variable rates aligned with `xs`; a None
            entry falls back to `self.learning_rate`.
          state: optimizer state; returned untouched.

        Returns:
          A `(new_vars, state)` pair, where `new_vars` is aligned with `xs`.
        """
        new_vars = []
        for x, g, lr in utils.eqzip(xs, gs, learning_rates):
            if g is None:
                # No gradient for this variable: leave it as it is.
                new_vars.append(x)
                continue
            if lr is None:
                lr = self.learning_rate
            new_vars.append(x * (1 - lr) - g * lr)
        return new_vars, state

    def get_state(self, var_list):
        """Return a placeholder state; this optimizer tracks nothing."""
        return tf.constant(0.0)

    def assign_state(self, state, var_list=None):
        """Return a no-op, since there is no internal state to assign."""
        return tf.no_op()

    def apply_gradients(self, grad_vars):
        """Build one op that applies a single step with the default rate.

        The inherited implementation calls `compute_updates(xs, gs, state)`,
        which does not match this class's four-argument `compute_updates` and
        fails with a TypeError. This override supplies the missing
        `learning_rates` argument.

        Args:
          grad_vars: iterable of (gradient, variable) pairs.

        Returns:
          A grouped op named "apply_gradients".
        """
        gradients, variables = zip(*grad_vars)
        state = self.get_state(variables)
        # No per-variable rates are known here, so every entry is None and
        # `compute_updates` falls back to `self.learning_rate`.
        learning_rates = [None] * len(variables)
        new_vars, new_state = self.compute_updates(
            variables, gradients, learning_rates, state)
        assign_op = self.assign_state(new_state)
        op = utils.assign_variables(variables, new_vars)
        return tf.group(assign_op, op, name="apply_gradients")