valhalla committed on
Commit
e5a52b9
1 Parent(s): a265819

add property to get num params

Browse files
Files changed (1) hide show
  1. dalle_mini/modeling_bart_flax.py +6 -0
dalle_mini/modeling_bart_flax.py CHANGED
@@ -24,6 +24,7 @@ import flax.linen as nn
24
  import jax
25
  import jax.numpy as jnp
26
  from flax.core.frozen_dict import FrozenDict, unfreeze
 
27
  from flax.linen import combine_masks, make_causal_mask
28
  from flax.linen.attention import dot_product_attention_weights
29
  from jax import lax
@@ -622,6 +623,11 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
622
  module = self.module_class(config=config, dtype=dtype)
623
  super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
624
 
 
 
 
 
 
625
  def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
626
  # init input tensors
627
  input_ids = jnp.zeros(input_shape, dtype="i4")
 
24
  import jax
25
  import jax.numpy as jnp
26
  from flax.core.frozen_dict import FrozenDict, unfreeze
27
+ from flax.traverse_util import flatten_dict
28
  from flax.linen import combine_masks, make_causal_mask
29
  from flax.linen.attention import dot_product_attention_weights
30
  from jax import lax
 
623
  module = self.module_class(config=config, dtype=dtype)
624
  super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
625
 
626
+ @property
627
+ def num_params(self):
628
+ num_params = jax.tree_map(lambda param: param.size, flatten_dict(unfreeze(self.params))).values()
629
+ return sum(list(num_params))
630
+
631
  def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
632
  # init input tensors
633
  input_ids = jnp.zeros(input_shape, dtype="i4")