jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
optax>=0.0.8
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cpu
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.0+cpu