Spaces:
Running
Running
fix: load from checkpoint
Browse files
src/dalle_mini/model/modeling.py
CHANGED
@@ -334,22 +334,19 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
334 |
|
335 |
# init weights on CPU
|
336 |
if load_on_cpu:
|
337 |
-
|
338 |
-
|
339 |
-
)
|
340 |
else:
|
341 |
-
init_fn = self.
|
342 |
|
343 |
# randomly initialized parameters
|
|
|
344 |
if abstract_init:
|
345 |
-
#
|
346 |
-
|
347 |
-
|
348 |
-
random_params = jax.eval_shape(
|
349 |
-
init_fn, rng=self.key, input_shape=input_shape
|
350 |
-
)
|
351 |
else:
|
352 |
-
random_params = init_fn(
|
353 |
|
354 |
# save required_params as set
|
355 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
|
|
334 |
|
335 |
# init weights on CPU
|
336 |
if load_on_cpu:
|
337 |
+
# init weights on CPU
|
338 |
+
init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
|
|
|
339 |
else:
|
340 |
+
init_fn = self.init_weigths
|
341 |
|
342 |
# randomly initialized parameters
|
343 |
+
random_params = self.init_weights(self.key, input_shape)
|
344 |
if abstract_init:
|
345 |
+
# only set shape and dtype, load parameters separately
|
346 |
+
init_fn = partial(init_fn, input_shape=input_shape)
|
347 |
+
random_params = jax.eval_shape(init_fn, self.key)
|
|
|
|
|
|
|
348 |
else:
|
349 |
+
random_params = init_fn(self.key, input_shape)
|
350 |
|
351 |
# save required_params as set
|
352 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|