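jax~=0.4.13
jaxlib~=0.4.13
flax~=0.7.1
fjutils~=0.0.15
numpy==1.25.2
# typing~=3.7.4.3  # disabled: `typing` is stdlib since Python 3.5 (numpy 1.25.2 already needs Python >=3.9); the PyPI backport is redundant and can shadow the stdlib module
transformers~=4.31.0
einops~=0.6.1
optax~=0.1.7
msgpack~=1.0.5
ipython~=8.14.0
tqdm==4.65.0
datasets==2.14.3
setuptools~=60.0.0
easydel
wandb
tensorboard
torch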