File size: 10,948 Bytes
506da10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for rematerialization.

Incubates a version of tf.recompute_grad that is XLA compatible.

This file is based on the recompute_grad.py in the bigbird codebase [1]:
https://github.com/google-research/bigbird/blob/db06498ec8804c6438111938d8654b66ddaccd5d/bigbird/core/recompute_grad.py

[1] Big Bird: Transformers for Longer Sequences, NeurIPS 2020.
      Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris
      Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li
      Yang, Amr Ahmed.
"""
import collections
import os
import threading
from typing import Deque, List, NamedTuple, Optional, Sequence

from absl import logging
import tensorflow.compat.v2 as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.ops import custom_gradient


# Remove when https://github.com/tensorflow/tensorflow/pull/45298
# gets merged
def get_variable_by_name(var_name):
  """Looks up the trainable tf.Variable whose op is named `var_name`.

  Under MirroredStrategy (multi-gpu) the global-variables collection can hold
  several replica copies, so name matches are narrowed down to trainable
  variables before a single result is chosen.

  Args:
    var_name: Op name of the variable to find.

  Returns:
    The unique trainable variable with that name, or None when every
    variable with that name is non-trainable.

  Raises:
    ValueError: if no variable has that name, or if more than one trainable
      variable has it.
  """

  def _has_requested_name(candidate):
    # Collection entries without an `.op.name` can never match.
    try:
      return var_name == candidate.op.name
    except AttributeError:
      return False

  name_matches = [
      candidate
      for candidate in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      if _has_requested_name(candidate)
  ]
  if not name_matches:
    raise ValueError('Unsuccessful at finding variable {}.'.format(var_name))

  trainable_matches = [v for v in name_matches if v.trainable]
  if not trainable_matches:
    # Every match is non-trainable, so there is nothing to differentiate.
    return None
  if len(trainable_matches) > 1:
    raise ValueError(
        'Unsuccessful at finding trainable variable {}. '
        'Number of candidates: {}. '
        'Candidates: {}'.format(var_name, len(trainable_matches),
                                trainable_matches))
  return trainable_matches[0]


# Install the replica-aware lookup in place of TensorFlow's own helper.
custom_gradient.get_variable_by_name = get_variable_by_name


class RecomputeContext(
    NamedTuple('RecomputeContext', [
        ('is_recomputing', bool),
        ('seed', tf.Tensor),
        ('children', Deque['RecomputeContext']),
    ])):
  """Immutable record describing one level of (re)computation.

  Instances double as context managers. Entering one pushes it onto the
  module's thread-local context stack, which makes it the value returned by
  `get_recompute_context()`. Exiting pops it again.

  Attributes:
    is_recomputing: True while `f` is being re-run inside the gradient
      function, False during the original forward pass.
    seed: Scalar integer tensor that should be used with stateless random ops,
      so the recomputation reproduces the forward pass and the gradient is
      computed correctly.
    children: Nested `RecomputeContext` instances. `recompute_grad` appends
      them during the forward pass and pops them during recomputation.
  """

  def __enter__(self):
    pushed = _context_stack.push(self)
    return pushed

  def __exit__(self, exc_type, exc_value, exc_tb):
    # Exceptions are never suppressed: the implicit return value is None.
    del exc_type, exc_value, exc_tb
    _context_stack.pop(self)


# Simplified version of `_DefaultStack` in
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py.
class _ContextStack(threading.local):
  """A thread-local stack for providing implicit recompute contexts.

  Because this subclasses `threading.local`, each thread sees its own
  `_stack`, so contexts entered on one thread are invisible to others.
  """

  def __init__(self):
    super().__init__()
    self._stack = []

  def top(self) -> Optional[RecomputeContext]:
    """Returns the innermost context, or None if the stack is empty."""
    return self._stack[-1] if self._stack else None

  def push(self, context: RecomputeContext):
    """Pushes `context` and returns it, so `__enter__` can hand it back."""
    self._stack.append(context)
    return context

  def pop(self, context: RecomputeContext):
    """Removes `context`, which must be the innermost context.

    Args:
      context: The context expected to be on top of the stack.

    Raises:
      AssertionError: if the stack is empty or `context` is not on top.
    """
    # Check emptiness first, so an unbalanced exit reports the nesting
    # violation rather than an IndexError from `self._stack[-1]`.
    if not self._stack or self._stack[-1] is not context:
      raise AssertionError('Nesting violated for RecomputeContext.')
    self._stack.pop()


# Module-wide stack; each thread still gets independent storage.
_context_stack = _ContextStack()


def get_recompute_context() -> Optional[RecomputeContext]:
  """Looks up the innermost active `RecomputeContext` for this thread.

  Returns:
    The most recently entered and not yet exited context, or None when no
    context is active.
  """
  current = _context_stack.top()
  return current


# Adapted from
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_util.py.
def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
  """Finds the nearest enclosing XLA control-flow context of `graph`.

  Starts at the graph's current control-flow context and follows each
  `outer_context` link outward.

  Args:
    graph: Graph whose current control-flow context is inspected.

  Returns:
    The first context whose `IsXLAContext()` is true, or None if the chain
    ends without one.
  """
  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  while ctxt and not ctxt.IsXLAContext():
    ctxt = ctxt.outer_context
  return ctxt or None


def _in_xla_context(graph: Optional[tf.Graph] = None) -> bool:
  """Reports whether we are in an XLA context.

  Returns True in either of these cases:
  - `TF_XLA_FLAGS` contains `--tf_xla_auto_jit=2`.
  - `graph` (default: the current default graph), or any graph reached
    through successive `outer_graph` attributes, has an XLA control-flow
    context.

  Args:
    graph: Graph to start the search from. Defaults to the default graph.

  Returns:
    Whether an XLA context was detected.
  """
  xla_flags = os.environ.get('TF_XLA_FLAGS', '')
  if '--tf_xla_auto_jit=2' in xla_flags:
    return True
  current = tf.compat.v1.get_default_graph() if graph is None else graph
  while _get_containing_xla_context(current) is None:
    try:
      current = current.outer_graph
    except AttributeError:
      # No `outer_graph` attribute: there is nothing further to search.
      return False
  return True


def _force_data_dependency(
    first_compute: Sequence[tf.Tensor],
    then_compute: Sequence[tf.Tensor]) -> List[tf.Tensor]:
  """Makes every tensor in `then_compute` depend on all of `first_compute`.

  The dependency is a dummy data dependency: a zero derived from the first
  element of each `first_compute` tensor is added to each `then_compute`
  tensor. This is useful when running on TPUs because XLA ignores control
  dependencies. Only floating `first_compute` tensors are supported.

  Args:
    first_compute: Sequence of `Tensor`s to be executed before `then_compute`.
      `None` entries are skipped.
    then_compute: (Possibly nested) structure of `Tensor`s to be executed after
      `first_compute`. `None` entries are passed through unchanged.

  Returns:
    A structure matching `then_compute` whose tensors carry the dependency.

  Raises:
    ValueError: if a `first_compute` rank is unknown, or if the summed first
      elements are not of a floating dtype.
  """

  def _scalar_from_first_entry(tensor):
    rank = tensor.shape.ndims
    if rank is None:
      raise ValueError('Rank of Tensor %s must be known' % tensor)
    # Slice out the [0, 0, ..., 0] element and squeeze it to a scalar.
    corner = tf.slice(tensor, tf.zeros(rank, dtype=tf.int32),
                      tf.ones(rank, dtype=tf.int32))
    return tf.reshape(corner, [])

  anchors = [
      _scalar_from_first_entry(tensor)
      for tensor in first_compute
      if tensor is not None
  ]
  anchor_sum = tf.add_n(anchors)
  if not anchor_sum.dtype.is_floating:
    raise ValueError('_force_data_dependency only supports floating dtypes.')
  # Numerically zero, but data-dependent on every anchor.
  zero = tf.cast(0.0, anchor_sum.dtype) * anchor_sum

  shifted = []
  for tensor in tf.nest.flatten(then_compute):
    shifted.append(None if tensor is None else tensor + tf.cast(zero, tensor.dtype))
  return tf.nest.pack_sequence_as(then_compute, shifted)


def _make_seed_if_none(seed: Optional[tf.Tensor]) -> tf.Tensor:
  """Uses the global generator to make a seed if necessary.

  Args:
    seed: Optional caller-provided seed. Returned unchanged when not None.

  Returns:
    `seed` itself, or a new scalar `tf.int32` tensor drawn with
    `uniform_full_int` from `tf.random.experimental.get_global_generator()`.
  """
  if seed is not None:
    return seed
  generator = tf.random.experimental.get_global_generator()
  # The two seeds for stateless random ops don't have individual semantics and
  # are scrambled together, so providing one seed is fine. This makes it easier
  # for users to provide a local seed without worrying about integer overflow.
  # See `make_seeds` in
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/stateful_random_ops.py.
  try:
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')
  except (RuntimeError, TypeError, ValueError, tf.errors.NotFoundError) as e:
    # For a number of reasons, the above operation can fail, e.g. when using
    # multiple graphs or toggling between eager and graph modes. Reset the
    # generator and retry once; a second failure propagates to the caller.
    # `logging.warn` is a deprecated alias, so use `logging.warning`.
    logging.warning('Resetting the generator. %s: %s', type(e), e)
    tf.random.experimental.set_global_generator(None)
    generator = tf.random.experimental.get_global_generator()
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')


def recompute_grad(f, seed=None):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args, and
  with respect to any variables passed to the gradient function in the
  `variables` keyword argument. Note that for keras layer and model objects,
  this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`. When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.
    seed: Optional seed for random ops. `seed` should be an integer scalar
      `Tensor`. When compiling to XLA, `seed` must have dtype `tf.int32`. If
      `seed` is not provided one will be generated.

  Returns:
   A function `g` that wraps `f`, but which recomputes `f` on the backwards
   pass of a gradient call.
  """

  @tf.custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    # Detect when we're nested and in the backwards pass, so we don't generate
    # an additional seed.
    parent_context = get_recompute_context()
    if parent_context is not None and parent_context.is_recomputing:
      # Use the cached context in the recomputation phase. Children were
      # appended during the forward pass, so popping from the left replays
      # them in forward order (assumes recomputation visits nested calls in
      # the same order as the forward pass).
      with parent_context.children.popleft()._replace(
          is_recomputing=True) as context:
        result = f(*args, **kwargs)
    else:
      with RecomputeContext(
          is_recomputing=False,
          seed=_make_seed_if_none(seed),
          children=collections.deque()) as context:
        result = f(*args, **kwargs)
        # In the forward pass, build up a tree of recomputation contexts.
        if parent_context is not None and not parent_context.is_recomputing:
          parent_context.children.append(context)

    def grad(*dresult, **grad_kwargs):
      """Gradient function calculation for inner function.

      Args:
        *dresult: Gradients with respect to the outputs of `f`.
        **grad_kwargs: May only contain `variables`, the variables to
          differentiate in addition to the positional inputs.

      Returns:
        A pair `(input_grads, variable_grads)`.

      Raises:
        ValueError: if any keyword argument other than `variables` is given.
      """
      variables = grad_kwargs.pop('variables', None)
      if grad_kwargs:
        raise ValueError('Found unexpected kwargs for `grad`: {}'.format(
            list(grad_kwargs.keys())))
      # Named distinctly from the enclosing `seed` argument to avoid shadowing.
      inputs, recompute_seed = list(args), context.seed
      if _in_xla_context():
        # XLA ignores control dependencies, so add a dummy data dependency
        # from the incoming gradients into the inputs and the seed.
        inputs = _force_data_dependency(
            tf.nest.flatten(dresult), inputs + [recompute_seed])
        recompute_seed = inputs.pop()
      with tf.GradientTape() as tape:
        tape.watch(inputs)
        if variables is not None:
          tape.watch(variables)
        # Recompute `f` only once the output gradients are available.
        with tf.control_dependencies(dresult):
          with context._replace(is_recomputing=True, seed=recompute_seed):
            result = f(*inputs, **kwargs)
      kw_vars = [] if variables is None else list(variables)
      grads = tape.gradient(
          result, list(inputs) + kw_vars, output_gradients=dresult)
      return grads[:len(inputs)], grads[len(inputs):]

    return result, grad

  return inner