especially for the versions of jax, jaxlib and flax.
jax[tpu]==0.4.5
flax==0.6.7
Thanks