# pip requirements file. Each option and each requirement must be on its own
# line; pip does not read several requirements from a single line.

# Extra page that pip scans for links to wheels. The URL points at the libtpu
# release listing, which the jax[tpu] extra depends on.
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]==0.4.5
flax==0.6.7
transformers
chex
datasets
optax
orbax
ftfy
tensorboard
Jinja2

# Additional package index serving the CPU builds of PyTorch.
# NOTE: pip applies index options to the whole file, not only to the lines
# that follow them.
--extra-index-url https://download.pytorch.org/whl/cpu
torch
torchvision