boris commited on
Commit
2be9847
1 Parent(s): 80b41d1

feat: don't ignore mismatched

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +0 -1
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -531,7 +531,6 @@ def main():
531
  config=config,
532
  seed=training_args.seed_model,
533
  dtype=getattr(jnp, model_args.dtype),
534
- ignore_mismatched_sizes=True,
535
  )
536
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
537
  print(model.params)
 
531
  config=config,
532
  seed=training_args.seed_model,
533
  dtype=getattr(jnp, model_args.dtype),
 
534
  )
535
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
536
  print(model.params)