"""Optimizers for use in unrolled optimization. |
|
|
|
These optimizers contain a compute_updates function and its own ability to keep |
|
track of internal state. |
|
These functions can be used with a tf.while_loop to perform multiple training |
|
steps per sess.run. |
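
Example:
  A minimal sketch of unrolling inside a tf.while_loop (`loss_fn`,
  `num_steps`, and the initial list of tensors `xs` are illustrative
  assumptions, not part of this module):

    opt = UnrollableGradientDescentRollingOptimizer(learning_rate=0.1)

    def body(t, xs, state):
      # One unrolled training step: differentiate, then apply the update.
      gs = tf.gradients(loss_fn(xs), xs)
      new_xs, new_state = opt.compute_updates(xs, gs, state=state)
      return t + 1, new_xs, new_state

    _, final_xs, _ = tf.while_loop(
        cond=lambda t, xs, state: t < num_steps,
        body=body,
        loop_vars=[tf.constant(0), xs, opt.get_state(xs)])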
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import sonnet as snt
import tensorflow as tf

from learning_unsupervised_learning import utils

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class UnrollableOptimizer(snt.AbstractModule):
  """Interface for optimizers that can be used in unrolled computation.

  apply_gradients is derived from compute_updates and assign_state.
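
  Example:
    A rough sketch of how a subclass is driven (`opt`, `loss`, and `xs` are
    illustrative assumptions); apply_gradients below composes exactly these
    calls:

      state = opt.get_state(xs)
      new_xs, new_state = opt.compute_updates(
          xs, tf.gradients(loss, xs), state)
      update_op = opt.assign_state(new_state)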
  """

  def __init__(self, *args, **kwargs):
    super(UnrollableOptimizer, self).__init__(*args, **kwargs)
    # Connect the Sonnet module immediately; _build below is a no-op, so this
    # only registers the module without creating variables.
    self()

  @abc.abstractmethod
  def compute_updates(self, xs, gs, state=None):
    """Computes the next-step updates for the given variables and state.

    Args:
      xs: list of tensors
        The "variables" to perform an update on.
        Note: these must be in the same order as the `var_list` originally
        passed to get_state.
      gs: list of tensors
        Gradients of `xs` with respect to some loss.
      state: Any
        Optimizer specific state used to keep track of accumulators, such as
        momentum terms.

    Returns:
      new_xs: list of tensors
        The updated values of `xs`.
      new_state: Any
        The updated optimizer state.
    """
    raise NotImplementedError()

  def _build(self):
    pass

  @abc.abstractmethod
  def get_state(self, var_list):
    """Gets the state value associated with a list of tf.Variables.

    The state is commonly a namedtuple holding a mapping between the variables
    and the state associated with them, e.g. a momentum accumulator tracked by
    the optimizer.

    Args:
      var_list: list of tf.Variable

    Returns:
      state: Any
        Optimizer specific state.
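
    Example:
      An illustrative (hypothetical) state container for a momentum
      optimizer; nothing in this module requires this exact structure:

        MomentumState = collections.namedtuple("MomentumState", ["momentums"])
        state = MomentumState(
            momentums=[tf.zeros_like(v) for v in var_list])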
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def assign_state(self, state):
    """Assigns the given state to the optimizer's internal variables.

    Args:
      state: Any
        Optimizer specific state.

    Returns:
      op: tf.Operation
        The operation that performs the assignment.
    """
    raise NotImplementedError()

  def apply_gradients(self, grad_vars):
    """Applies gradients to variables and updates the optimizer state.

    Args:
      grad_vars: list of (gradient, tf.Variable) pairs, as accepted by
        tf.train.Optimizer.apply_gradients.

    Returns:
      op: tf.Operation that applies both the variable and the state updates.
    """
    gradients, variables = zip(*grad_vars)
    state = self.get_state(variables)
    new_vars, new_state = self.compute_updates(variables, gradients, state)
    assign_op = self.assign_state(new_state)
    op = utils.assign_variables(variables, new_vars)
    return tf.group(assign_op, op, name="apply_gradients")


class UnrollableGradientDescentRollingOptimizer(UnrollableOptimizer):
  """Stateless gradient descent whose update decays parameters toward zero."""

  def __init__(self,
               learning_rate,
               name="UnrollableGradientDescentRollingOptimizer"):
    self.learning_rate = learning_rate
    super(UnrollableGradientDescentRollingOptimizer, self).__init__(name=name)

  def compute_updates(self, xs, gs, state=None, learning_rates=None):
    """Computes updated values for `xs`; see the base class for details.

    learning_rates, if provided, gives a per-variable learning rate. Entries
    may be None, in which case self.learning_rate is used.
    """
    if learning_rates is None:
      learning_rates = [None] * len(xs)
    new_vars = []
    for x, g, lr in utils.eqzip(xs, gs, learning_rates):
      if lr is None:
        lr = self.learning_rate
      if g is not None:
        # x * (1 - lr) - g * lr == x - lr * (g + x): a gradient step plus a
        # decay that "rolls" the parameter toward zero.
        new_vars.append(x * (1 - lr) - g * lr)
      else:
        # No gradient flows to this variable; leave it unchanged.
        new_vars.append(x)
    return new_vars, state
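
  # A sketch of supplying per-variable learning rates (the values below are
  # purely illustrative); entries left as None fall back to
  # self.learning_rate:
  #   new_xs, state = opt.compute_updates(
  #       xs, gs, state=state, learning_rates=[0.1, None, 0.01])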

  def get_state(self, var_list):
    # This optimizer is stateless; return a dummy constant as the state.
    return tf.constant(0.0)

  def assign_state(self, state, var_list=None):
    # Nothing to assign for a stateless optimizer.
    return tf.no_op()