optax flax matplotlib jax2d