boris committed on
Commit
fa72aa7
1 Parent(s): 0952927

feat(modeling): simplify abstract_init

Browse files
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +7 -4
src/dalle_mini/model/modeling.py CHANGED
@@ -334,7 +334,9 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
334
 
335
  # init weights on CPU
336
  if load_on_cpu:
337
- init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
 
 
338
  else:
339
  init_fn = self.init_weights
340
 
@@ -343,10 +345,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
343
  # init the model weights only abstractly, eval_shape will return a pytree
344
  # with the structure as weights but without any actual values, this will just contain
345
  # the shape information. Weights need to be loaded later.
346
- init_fn = partial(init_fn, input_shape=input_shape)
347
- random_params = jax.eval_shape(init_fn, self.key)
 
348
  else:
349
- random_params = init_fn(self.key, input_shape)
350
 
351
  # save required_params as set
352
  self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
 
334
 
335
  # init weights on CPU
336
  if load_on_cpu:
337
+ init_fn = jax.jit(
338
+ self.init_weights, static_argnames="input_shape", backend="cpu"
339
+ )
340
  else:
341
  init_fn = self.init_weights
342
 
 
345
  # init the model weights only abstractly, eval_shape will return a pytree
346
  # with the structure as weights but without any actual values, this will just contain
347
  # the shape information. Weights need to be loaded later.
348
+ random_params = jax.eval_shape(
349
+ init_fn, rng=self.key, input_shape=input_shape
350
+ )
351
  else:
352
+ random_params = init_fn(rng=self.key, input_shape=input_shape)
353
 
354
  # save required_params as set
355
  self._required_params = set(flatten_dict(unfreeze(random_params)).keys())