boris commited on
Commit
44b7c3e
1 Parent(s): d483294

fix: load from checkpoint

Browse files
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +8 -11
src/dalle_mini/model/modeling.py CHANGED
@@ -334,22 +334,19 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
334
 
335
  # init weights on CPU
336
  if load_on_cpu:
337
- init_fn = jax.jit(
338
- self.init_weights, static_argnames="input_shape", backend="cpu"
339
- )
340
  else:
341
- init_fn = self.init_weights
342
 
343
  # randomly initialized parameters
 
344
  if abstract_init:
345
- # init the model weights only abstractly, eval_shape will return a pytree
346
- # with the structure as weights but without any actual values, this will just contain
347
- # the shape information. Weights need to be loaded later.
348
- random_params = jax.eval_shape(
349
- init_fn, rng=self.key, input_shape=input_shape
350
- )
351
  else:
352
- random_params = init_fn(rng=self.key, input_shape=input_shape)
353
 
354
  # save required_params as set
355
  self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
 
334
 
335
  # init weights on CPU
336
  if load_on_cpu:
337
+ # init weights on CPU
338
+ init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
 
339
  else:
340
+ init_fn = self.init_weigths
341
 
342
  # randomly initialized parameters
343
+ random_params = self.init_weights(self.key, input_shape)
344
  if abstract_init:
345
+ # only set shape and dtype, load parameters separately
346
+ init_fn = partial(init_fn, input_shape=input_shape)
347
+ random_params = jax.eval_shape(init_fn, self.key)
 
 
 
348
  else:
349
+ random_params = init_fn(self.key, input_shape)
350
 
351
  # save required_params as set
352
  self._required_params = set(flatten_dict(unfreeze(random_params)).keys())