# Image that layers Python 3, JAX (CUDA build), dalle-mini, vqgan-jax and
# Jupyter on top of an NVIDIA CUDA base image.
#
# NOTE(review): none of the Python packages are version-pinned and the two
# git installs track each repository's default branch, so rebuilds are not
# reproducible. Consider pinning versions / commit SHAs once known-good ones
# are identified.

# CUDA 11.6.2 + cuDNN 8 development image on Ubuntu 20.04 (per the image tag).
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04

# System packages. git is required by the `git+https://` pip installs below;
# python3 / python3-pip provide the interpreter and pip.
# DEBIAN_FRONTEND is set inline (not via ENV) so apt does not wait on an
# interactive configuration prompt during the build, without leaking the
# setting into the runtime environment.
# The apt package lists are removed in the same layer so they do not persist
# in the image.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
        git \
        python3 \
        python3-pip \
    && rm -rf /var/lib/apt/lists/*

# JAX with CUDA support, resolved against the extra wheel index given by -f,
# followed by dalle-mini and vqgan-jax installed straight from GitHub.
# --no-cache-dir keeps pip's download/wheel cache out of the image layer.
RUN pip install --no-cache-dir --upgrade "jax[cuda]" \
        -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    && pip install --no-cache-dir -q \
        git+https://github.com/borisdayma/dalle-mini.git \
        git+https://github.com/patil-suraj/vqgan-jax.git

# Jupyter is installed in its own layer, separate from the ML dependencies.
RUN pip install --no-cache-dir jupyter

# Default working directory for subsequent instructions and for containers
# started from this image.
WORKDIR /workspace