valhalla committed
Commit: f6c4cb2
Parent: 29db327

make checkpointing optional

Files changed (1):
  1. dalle_mini/modeling_bart_flax.py +6 -6
dalle_mini/modeling_bart_flax.py CHANGED
@@ -252,8 +252,7 @@ class FlaxBartEncoderLayer(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
-
-    @nn.remat
+
     def __call__(
         self,
         hidden_states: jnp.ndarray,
@@ -283,8 +282,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
+        layer_module = nn.remat(FlaxBartEncoderLayer) if self.config.gradient_checkpointing else FlaxBartEncoderLayer
         self.layers = [
-            FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
+            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
         ]
 
     def __call__(
@@ -344,8 +344,7 @@ class FlaxBartDecoderLayer(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
-
-    @nn.remat
+
     def __call__(
         self,
         hidden_states: jnp.ndarray,
@@ -394,8 +393,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
+        layer_module = nn.remat(FlaxBartDecoderLayer) if self.config.gradient_checkpointing else FlaxBartDecoderLayer
         self.layers = [
-            FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
+            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
        ]
 
     def __call__(
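
The change replaces the unconditional @nn.remat decorator on each layer's __call__ with a conditional wrap of the layer class itself: when config.gradient_checkpointing is set, nn.remat (Flax's activation-checkpointing / rematerialization transform) makes the backward pass recompute the layer's intermediate activations instead of storing them, trading compute for memory; when it is unset, the plain layer class is used. A minimal, self-contained sketch of the same pattern, not the repo's actual API: TinyConfig and TinyBlock are hypothetical stand-ins for the BART config and encoder/decoder layers.

# Minimal sketch of the optional-remat pattern from this commit.
# TinyConfig and TinyBlock are hypothetical stand-ins, not dalle_mini code.
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.linen as nn


@dataclass(frozen=True)
class TinyConfig:
    hidden: int = 16
    num_layers: int = 2
    gradient_checkpointing: bool = True  # the flag the commit keys off


class TinyBlock(nn.Module):
    config: TinyConfig

    @nn.compact
    def __call__(self, x):
        # Stand-in for one encoder/decoder layer.
        return nn.Dense(self.config.hidden)(jax.nn.gelu(x))


class TinyStack(nn.Module):
    config: TinyConfig

    def setup(self):
        # Wrap the layer *class* before instantiation, as the diff does:
        # nn.remat recomputes this layer's activations in the backward pass
        # instead of storing them, cutting memory at the cost of extra compute.
        block_cls = nn.remat(TinyBlock) if self.config.gradient_checkpointing else TinyBlock
        self.layers = [block_cls(self.config, name=str(i)) for i in range(self.config.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


config = TinyConfig()
model = TinyStack(config)
x = jnp.ones((1, config.hidden))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)

Running this with gradient_checkpointing=True or False produces identical outputs; the flag only changes whether activations are stored or recomputed when gradients are taken, which is why checkpointing can safely be made opt-in.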