datasets >= 1.8.0 jax>=0.2.17 jaxlib>=0.1.68 flax>=0.3.5 optax>=0.0.8