DALL-E Mini - Generate images from text
Tentative training strategy (proposed by Luke and Suraj)
Data:
- Conceptual 12M dataset (already loaded and preprocessed on the TPU VM by Luke).
- YFCC100M Subset
- Conceptual Captions 3M
Architecture:
- Use the Taming Transformers VQ-GAN (with a 16384-token codebook)
- Use a seq2seq model (text encoder --> image token decoder) with a pretrained non-autoregressive encoder (e.g. BERT) and an autoregressive decoder (like GPT); a toy sketch of this data flow follows.
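A minimal Flax sketch of the encoder-decoder idea is below. Single attention layers stand in for the pretrained text encoder and the full autoregressive decoder, and every name and size (Seq2SeqSketch, the vocab sizes, d_model) is illustrative rather than the project's actual model:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Seq2SeqSketch(nn.Module):
    """Toy stand-in: encode caption tokens, predict VQ-GAN image codes."""
    text_vocab: int = 30522    # BERT-like vocab size (assumption)
    image_vocab: int = 16384   # VQ-GAN codebook size
    d_model: int = 256

    @nn.compact
    def __call__(self, text_ids, image_ids):
        # "Encoder": one self-attention layer over caption embeddings
        # (the real model would use a pretrained BERT/RoBERTa here).
        enc = nn.SelfAttention(num_heads=4)(
            nn.Embed(self.text_vocab, self.d_model)(text_ids))
        # "Decoder": cross-attention from image-token embeddings to the
        # encoded text (teacher forcing; no causal mask in this toy version).
        dec = nn.MultiHeadDotProductAttention(num_heads=4)(
            nn.Embed(self.image_vocab, self.d_model)(image_ids), enc)
        # Logits over the 16384 VQ-GAN codes at each image position.
        return nn.Dense(self.image_vocab)(dec)

model = Seq2SeqSketch()
params = model.init(jax.random.PRNGKey(0),
                    jnp.ones((1, 32), jnp.int32),    # caption tokens
                    jnp.ones((1, 256), jnp.int32))   # 16x16 grid of image codes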
Remaining Architecture Questions:
- Should the text encoder be frozen?
- Should the VQ-GAN be fine-tuned?
- Which text encoder should we use (e.g. BERT, RoBERTa)?
- Hyperparameter choices for the decoder (e.g. positional embeddings, initialization)
TODO
- experiment with Flax/JAX and set up the TPU instance that we should get shortly
- work on dataset loading (see the suggested datasets above)
- optionally create the OpenAI YFCC100M subset (see this post)
- work on text/image encoding
- concatenate inputs (not sure if we need a fixed text length or a special token separating text & image; see the sketch after this list)
- adapt training script
- create inference function
- integrate CLIP for better results (only if we have the time)
- work on a demo (Streamlit, Colab, or maybe just an HF widget)
- document the project (set up the repo on the Model Hub per instructions, start on the README writeup…)
- help with coordinating activities & progress
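For the input-concatenation item above, here is a minimal sketch of the separator-token option. The BOS/SEP token IDs, the example caption tokens, and the 16x16 image-token grid are placeholder assumptions, not the project's actual values:

import numpy as np

# Placeholder special-token IDs; real values depend on the tokenizer and on
# how the text and image vocabularies are laid out.
BOS, SEP = 0, 1

def build_example(text_ids, image_ids):
    """Concatenate caption tokens and VQ-GAN image codes into one sequence,
    using a special token to separate text from image (one of the two
    options mentioned in the TODO list)."""
    return np.array([BOS] + list(text_ids) + [SEP] + list(image_ids),
                    dtype=np.int32)

# e.g. a 4-token caption followed by a 16x16 grid of VQ-GAN codes
example = build_example([101, 2023, 2003, 102], range(256))
print(example.shape)  # (1 + 4 + 1 + 256,) = (262,)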
Dependencies Installation
You should create a new Python virtual environment and install the project dependencies inside it. You need to use the -f (--find-links) option for pip so it can find the appropriate libtpu required for the TPU hardware:
$ pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
If you use conda, you can create the virtual env and install everything using:
$ conda env update -f environments.yaml
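After installation on the TPU VM, a quick sanity check is to confirm that JAX can actually see the TPU cores (assuming the libtpu wheel installed correctly):

import jax

# Should list the TPU cores (e.g. 8 TpuDevice entries on a v3-8 VM).
print(jax.devices())
print(jax.device_count())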