VQGAN-f16-16384
Model Description
This is a Flax/JAX implementation of VQGAN, which learns a codebook of context-rich visual parts by leveraging both convolutional methods and transformers. It was introduced in Taming Transformers for High-Resolution Image Synthesis (CVPR paper).
The model allows the encoding of images as a fixed-length sequence of tokens taken from the codebook.
This version of the model uses a reduction factor f=16 and a vocabulary of 16,384 tokens.
As an example of how the reduction factor works, images of size 256x256 are encoded to sequences of 256 tokens: 256/16 * 256/16. Images of size 512x512 would result in sequences of 1024 tokens.
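As a quick check of this arithmetic, here is a minimal sketch; the helper below is illustrative only and not part of the model's API, and it assumes the image side lengths are multiples of f:

```python
# Illustrative helper (not part of the model's API): number of codebook
# tokens produced for an image of the given size with reduction factor f.
def num_tokens(height: int, width: int, f: int = 16) -> int:
    # Each f x f patch of the input image maps to one codebook token.
    return (height // f) * (width // f)

print(num_tokens(256, 256))  # 256
print(num_tokens(512, 512))  # 1024
```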
Datasets Used for Training
- ImageNet. We didn't train this model from scratch. Instead, we started from a checkpoint pre-trained on ImageNet.
- Conceptual Captions 3M (CC3M).
- OpenAI subset of YFCC100M.
We fine-tuned on CC3M and YFCC100M to improve the encoding quality of people and faces, which are not very well represented in ImageNet. We used a subset of 2,268,720 images from CC3M and YFCC100M for this purpose.
Training Process
Fine-tuning was performed in PyTorch using taming-transformers. The full training process and model preparation included the following steps:
- Pre-training on ImageNet. Previously performed. We used this checkpoint.
- Fine-tuning, Part 1.
- Fine-tuning, Part 2 – continuation from Part 1. The final checkpoint was uploaded to boris/vqgan_f16_16384.
- Conversion to JAX, which is the model described in this card.
How to Use
The checkpoint can be loaded using Suraj Patil's implementation of VQModel.
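A minimal loading sketch, assuming Suraj Patil's vqgan-jax package is installed; the module path and the repository id below are assumptions (the id is a placeholder, so substitute the actual id of this model):

```python
# Sketch only: assumes the vqgan-jax implementation is installed.
from vqgan_jax.modeling_flax_vqgan import VQModel  # assumed module path

# Placeholder repository id; replace with the actual id of this model.
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
```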
- Encoding: coming soon (see the sketch after this list for one possible approach).
- Decoding: coming soon (see the sketch after this list for one possible approach).
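Until the official snippets are added, here is a hedged sketch of encoding and decoding under the same assumptions as above. The `encode`/`decode_code` method names, the return signature, and the expected [0, 1] pixel range are assumptions based on the vqgan-jax implementation and may differ in the version you install:

```python
import jax.numpy as jnp
import numpy as np
from PIL import Image

from vqgan_jax.modeling_flax_vqgan import VQModel  # assumed module path

# Placeholder repository id; replace with the actual id of this model.
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

# Encoding: a 256x256 RGB image -> a sequence of 256 codebook indices.
image = Image.open("example.png").convert("RGB").resize((256, 256))
pixels = jnp.asarray(np.asarray(image), dtype=jnp.float32) / 255.0  # assumed [0, 1] input range
pixels = pixels[None]  # add a batch dimension: (1, 256, 256, 3)
quant_states, indices = model.encode(pixels)  # assumed return signature
print(indices.shape)  # expected: (1, 256)

# Decoding: codebook indices -> reconstructed image.
reconstruction = model.decode_code(indices)
reconstruction = np.asarray(reconstruction[0]).clip(0.0, 1.0)
Image.fromarray((reconstruction * 255).astype(np.uint8)).save("reconstructed.png")
```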
Other
This model was successfully used as part of the implementation of DALL·E mini. Our report contains more details on how to leverage it in an image encoding / generation pipeline.