Pedro Cuenca

VQGAN-f16-16384

Model Description

This is a Flax/JAX implementation of VQGAN, which learns a codebook of context-rich visual parts by combining convolutional methods with transformers. It was introduced in Taming Transformers for High-Resolution Image Synthesis (CVPR paper).

The model allows the encoding of images as a fixed-length sequence of tokens taken from the codebook.
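Conceptually, the tokens come from vector quantization: the encoder maps each spatial position to a latent vector, and that vector is replaced by the index of its nearest codebook entry. The NumPy sketch below illustrates only this lookup step, using a toy random codebook and hypothetical names and sizes (the real model uses higher-dimensional latents); it is not the vqgan_jax implementation.

```python
import numpy as np

def quantize(latents: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Return the index of the nearest codebook vector for each latent vector.

    latents:  (num_positions, dim) encoder outputs (toy input in this sketch)
    codebook: (vocab_size, dim) codebook embeddings
    """
    # Squared L2 distances expanded as ||z||^2 - 2 z.e + ||e||^2,
    # giving a (num_positions, vocab_size) matrix
    distances = (
        np.sum(latents ** 2, axis=1, keepdims=True)
        - 2.0 * latents @ codebook.T
        + np.sum(codebook ** 2, axis=1)
    )
    # Token id = position of the closest codebook entry
    return np.argmin(distances, axis=1)

# Toy example: a 16 x 16 latent grid (as for a 256x256 image with f=16),
# hypothetical 8-dimensional latents and a 16,384-entry codebook
rng = np.random.default_rng(0)
codebook = rng.normal(size=(16384, 8))
# Latents built as slightly perturbed copies of known codebook rows
true_ids = rng.integers(0, 16384, size=16 * 16)
latents = codebook[true_ids] + 1e-3 * rng.normal(size=(16 * 16, 8))
tokens = quantize(latents, codebook)
print(tokens.shape)  # (256,)
```

Because each latent is a tiny perturbation of a known codebook row, the lookup recovers exactly those row indices, one token per grid position.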

This version of the model uses a reduction factor f=16 and a vocabulary of 16,384 tokens.

As an example of how the reduction factor works, images of size 256x256 are encoded to sequences of 256 tokens: 256/16 * 256/16 = 16 * 16 = 256. Images of size 512x512 would result in sequences of 1,024 tokens (32 * 32).
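The same arithmetic applies to other sizes: the encoder produces a latent grid of (H/f) x (W/f) positions, with one token per position. A small helper reproducing the numbers above; num_tokens is a hypothetical name, not part of vqgan_jax, and rejecting sizes that are not multiples of f is an assumption of this sketch.

```python
def num_tokens(height: int, width: int, f: int = 16) -> int:
    """Number of codebook tokens for an image encoded with reduction factor f."""
    # Assumption of this sketch: only sizes divisible by f are accepted
    if height % f or width % f:
        raise ValueError("height and width should be multiples of the reduction factor")
    # Latent grid is (height / f) x (width / f), one token per grid position
    return (height // f) * (width // f)

print(num_tokens(256, 256))  # 256
print(num_tokens(512, 512))  # 1024
```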

Datasets Used for Training

We fine-tuned on a subset of 2,268,720 images from CC3M and YFCC100M to improve the encoding quality of people and faces, which are underrepresented in ImageNet.

Training Process

Fine-tuning was performed in PyTorch using taming-transformers. The full training process and model preparation consisted of these steps:

  • Pre-training on ImageNet. This was performed previously; we started from this checkpoint.
  • Fine-tuning, Part 1.
  • Fine-tuning, Part 2 – continuation from Part 1. The final checkpoint was uploaded to boris/vqgan_f16_16384.
  • Conversion to JAX, which is the model described in this card.

How to Use

The checkpoint can be loaded using Suraj Patil's implementation of VQModel.

  • Example notebook, heavily based on work by Suraj: Open In Colab

  • Batch encoding using JAX pmap, complete example including data loading with PyTorch:

```python
# VQGAN-JAX - pmap encoding HowTo

import numpy as np

# For data loading
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader
from torchvision.transforms import InterpolationMode

# For data saving
from pathlib import Path
import pandas as pd
from tqdm import tqdm

import jax
from jax import pmap

from vqgan_jax.modeling_flax_vqgan import VQModel

## Params and arguments

# TSV file listing the (absolute) paths of the images to encode, one per line
image_list = '/sddata/dalle-mini/CC12M/10k.tsv'
output_tsv = 'output.tsv'                # Encoded results
batch_size = 64
num_workers = 4    # DataLoader worker processes; TPU v3-8 hosts have 96 CPU cores, so feel free to increase this number when necessary

# Load model
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

## Data Loading.

# Simple torch Dataset to load images from paths.
# You can use your own pipeline instead.
class ImageDataset(Dataset):
    def __init__(self, image_list_path: str, image_size: int, max_items=None):
        """
        :param image_list_path: Path to a file containing a list of all images. We assume absolute paths for now.
        :param image_size: Image size. Source images will be resized and center-cropped.
        :param max_items: Limit dataset size for debugging
        """
        self.image_list = pd.read_csv(image_list_path, sep='\t', header=None)
        if max_items is not None:
            self.image_list = self.image_list[:max_items]
        self.image_size = image_size

    def __len__(self):
        return len(self.image_list)

    def _get_raw_image(self, i):
        image_path = Path(self.image_list.iloc[i, 0])
        return default_loader(image_path)    # PIL image converted to RGB

    def resize_image(self, image):
        # Resize so that the shorter side equals image_size, keeping the aspect ratio.
        # PIL's image.size is (width, height), while TF.resize expects (height, width).
        s = min(image.size)
        r = self.image_size / s
        s = (round(r * image.size[1]), round(r * image.size[0]))
        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
        # Keep the central image_size x image_size square
        image = TF.center_crop(image, output_size=2 * [self.image_size])
        # Shape (1, H, W, C); the extra axis is squeezed out when building superbatches
        image = np.expand_dims(np.array(image), axis=0)
        return image

    def __getitem__(self, i):
        image = self._get_raw_image(i)
        return self.resize_image(image)

## Encoding

# Encoding function to be parallelized with `pmap`
# Note: images have to be square
def encode(model, batch):
    _, indices = model.encode(batch)
    return indices

# Group `num_tpus` consecutive batches into a superbatch of shape
# (num_tpus, batch_size, H, W, C), so that `pmap` sends one batch to each device.
# Alternative: create a batch with num_tpus*batch_size and use `shard` to distribute.
def superbatch_generator(dataloader, num_tpus):
    superbatch = []
    for batch in dataloader:
        superbatch.append(batch.squeeze(1))    # (batch_size, 1, H, W, C) -> (batch_size, H, W, C)
        if len(superbatch) == num_tpus:
            yield torch.stack(superbatch, dim=0)
            superbatch = []
    # An incomplete trailing superbatch is dropped

def encode_dataset(dataset, batch_size=32):
    num_tpus = jax.device_count()
    # drop_last=True skips the incomplete last batch
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)
    superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)

    p_encoder = pmap(lambda batch: encode(model, batch))

    # Save each superbatch to avoid reallocation of buffers as we process them.
    # Keep the file open to prevent excessive file seeks.
    with open(output_tsv, "w") as file:
        iterations = len(dataset) // (batch_size * num_tpus)
        for n in tqdm(range(iterations)):
            superbatch = next(superbatches)
            encoded = p_encoder(superbatch.numpy())
            # (num_tpus, batch_size, seq_len) -> (num_tpus * batch_size, seq_len)
            encoded = np.asarray(encoded).reshape(-1, encoded.shape[-1])

            # Extract paths from the dataset, save paths and encodings (as string)
            start_index = n * batch_size * num_tpus
            end_index = (n+1) * batch_size * num_tpus
            paths = dataset.image_list[start_index:end_index][0].values
            # Format each row as "[i0,i1,...]"; unlike np.array2string with default
            # print options, this does not truncate sequences longer than 1,000 tokens
            encoded_as_string = ["[" + ",".join(map(str, row)) + "]" for row in encoded]
            batch_df = pd.DataFrame.from_dict({"image_file": paths, "encoding": encoded_as_string})
            batch_df.to_csv(file, sep='\t', header=(n==0), index=False)

dataset = ImageDataset(image_list, image_size=256)
encode_dataset(dataset, batch_size=batch_size)
```
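The script writes one row per image to output.tsv, with an image_file column and an encoding column formatted as "[i0,i1,...]". That string is also a valid JSON list, so the encodings can be read back with the standard library's csv and json modules, plus NumPy to stack the rows. A minimal sketch, assuming that format; load_encodings is a hypothetical helper and the sample rows are made up.

```python
import csv
import io
import json

import numpy as np

def load_encodings(file):
    """Read (image_file, tokens) pairs from a TSV produced by the encoding script."""
    reader = csv.DictReader(file, delimiter="\t")
    paths, tokens = [], []
    for row in reader:
        paths.append(row["image_file"])
        # "[12,3057,...]" is valid JSON, so json.loads recovers the list of ints
        tokens.append(json.loads(row["encoding"]))
    # Assumes every row has the same sequence length (e.g. 256 for 256x256 images)
    return paths, np.array(tokens, dtype=np.int32)

# Made-up two-row example with 4 tokens per image, only to show the format
sample = "image_file\tencoding\n/data/a.jpg\t[1,2,3,4]\n/data/b.jpg\t[16383,0,7,42]\n"
paths, tokens = load_encodings(io.StringIO(sample))
print(paths)         # ['/data/a.jpg', '/data/b.jpg']
print(tokens.shape)  # (2, 4)
```

For the real output, open the file with newline="" (as the csv module recommends) and pass the file object to the same function.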

Related Models in the Hub

  • PyTorch version of VQGAN, trained on the same datasets described here: boris/vqgan_f16_16384.
  • DALL·E mini, a simplified Flax/JAX implementation of OpenAI's DALL·E.

Other

This model was successfully used as part of the implementation of DALL·E mini. Our report contains more details on how to leverage it in an image encoding / generation pipeline.