valhalla committed
Commit 95a8ed2
Parent: 7774483

add gradient checkpointing

Files changed (1):
  dalle_mini/modeling_bart_flax.py  +4 -2
dalle_mini/modeling_bart_flax.py CHANGED

```diff
@@ -252,7 +252,8 @@ class FlaxBartEncoderLayer(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
-
+
+    @nn.remat
     def __call__(
         self,
         hidden_states: jnp.ndarray,
@@ -343,7 +344,8 @@ class FlaxBartDecoderLayer(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
-
+
+    @nn.remat
     def __call__(
         self,
         hidden_states: jnp.ndarray,
```
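For reference, `nn.remat` is Flax's lifted version of `jax.checkpoint`: the decorated method's intermediate activations are discarded after the forward pass and recomputed during the backward pass, trading extra compute for lower peak memory. The sketch below is a minimal illustration of the same decorate-`__call__` pattern this commit applies; `CheckpointedMLP` and its sizes are hypothetical and not part of dalle_mini.

```python
# Minimal sketch of gradient checkpointing in Flax, assuming the same
# @nn.remat-on-__call__ pattern as the commit. CheckpointedMLP and the
# sizes below are hypothetical, for illustration only.
import jax
import jax.numpy as jnp
import flax.linen as nn


class CheckpointedMLP(nn.Module):
    hidden_dim: int = 512  # hypothetical width

    @nn.remat  # recompute this method's intermediates on the backward pass
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.gelu(x)
        return nn.Dense(self.hidden_dim)(x)


# Usage: init/apply and gradient computation are unchanged; only the
# memory/compute trade-off of the backward pass differs.
model = CheckpointedMLP()
x = jnp.ones((4, 512))
params = model.init(jax.random.PRNGKey(0), x)
grads = jax.grad(lambda p: jnp.sum(model.apply(p, x) ** 2))(params)
```

Decorating each encoder/decoder layer's `__call__`, as the diff does, checkpoints at layer granularity, so peak activation memory scales with a single layer rather than the full stack.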