
VQGAN-f16-16384

Model Description

This is a Flax/JAX implementation of VQGAN, which learns a codebook of context-rich visual parts by leveraging both convolutional methods and transformers. It was introduced in Taming Transformers for High-Resolution Image Synthesis (CVPR paper).

The model allows the encoding of images as a fixed-length sequence of tokens taken from the codebook.

This version of the model uses a reduction factor f=16 and a vocabulary of 16,384 tokens.

As an example of how the reduction factor works, images of size 256x256 are encoded to sequences of 256 tokens: 256/16 * 256/16 = 16 * 16 = 256. Images of 512x512 would result in sequences of 1024 tokens.
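
As a quick illustration of this arithmetic, the snippet below computes the token count from the reduction factor (a standalone example, not part of the model code):

```python
def num_tokens(height: int, width: int, reduction_factor: int = 16) -> int:
    """Number of codebook tokens produced for an image of the given size."""
    return (height // reduction_factor) * (width // reduction_factor)

assert num_tokens(256, 256) == 256   # 16 * 16 tokens for a 256x256 image
assert num_tokens(512, 512) == 1024  # 32 * 32 tokens for a 512x512 image
```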

Datasets Used for Training

We fine-tuned on CC3M and YFCC100M to improve the encoding quality of people and faces, which are not very well represented in ImageNet. We used a subset of 2,268,720 images from these two datasets for this purpose.

Training Process

Fine-tuning was performed in PyTorch using taming-transformers. The full training process and model preparation included these steps:

  • Pre-training on ImageNet. Previously performed. We used this checkpoint.
  • Fine-tuning, Part 1.
  • Fine-tuning, Part 2 – continuation from Part 1. The final checkpoint was uploaded to boris/vqgan_f16_16384.
  • Conversion to JAX, which is the model described in this card.

How to Use

The checkpoint can be loaded using Suraj Patil's implementation of VQModel.
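
Loading might look like the following minimal sketch. It assumes that implementation is importable as `vqgan_jax` and that this checkpoint is hosted on the Hugging Face Hub under the id shown below; both are assumptions rather than details confirmed by this card.

```python
# Minimal loading sketch (assumptions: Suraj Patil's vqgan-jax implementation is
# installed and importable as `vqgan_jax`; the Hub id below matches this repository).
from vqgan_jax.modeling_flax_vqgan import VQModel

vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
```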

  • Encoding: coming soon; a tentative sketch is shown below.
  • Decoding: coming soon; a tentative sketch is shown below.
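
In the meantime, the sketch below outlines what encoding and decoding could look like with the model loaded above. The method names (`encode`, `decode_code`), the channels-last input layout, the [-1, 1] pixel scaling, and the file names are assumptions based on how DALL·E mini uses this implementation, not confirmed details of this card; check the VQModel source before relying on them.

```python
import jax.numpy as jnp
import numpy as np
from PIL import Image

# --- Encoding (sketch) ---
# Prepare a 256x256 RGB image as a (1, 256, 256, 3) array scaled to [-1, 1]
# (the channels-last layout and the scaling are assumptions).
image = Image.open("example.jpg").convert("RGB").resize((256, 256))
pixel_values = jnp.asarray(np.array(image), dtype=jnp.float32) / 127.5 - 1.0
pixel_values = pixel_values[None, ...]

# `encode` is assumed to return (quantized latents, codebook indices);
# a 256x256 input should yield 256 token indices per image.
_, indices = vqgan.encode(pixel_values)

# --- Decoding (sketch) ---
# `decode_code` is assumed to map codebook indices back to pixel space in [-1, 1].
reconstruction = vqgan.decode_code(indices)
pixels = np.clip((np.asarray(reconstruction[0]) + 1.0) / 2.0, 0.0, 1.0)
Image.fromarray((pixels * 255).astype(np.uint8)).save("reconstruction.jpg")
```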

Other

This model was successfully used as part of the implementation of DALL·E mini. Our report contains more details on how to leverage it in an image encoding / generation pipeline.