# Pre-encoding a dataset for DALL·E mini

This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).

Adapt it to your own dataset and image encoder.

At the end you should have a dataset of pairs:
* a caption, stored as a string
* an encoded image, stored as a list of ints (VQGAN token ids).

In [None]:
from tqdm.notebook import tqdm

import torchvision.transforms as T

import webdataset as wds

import jax
import braceexpand
from pathlib import Path

## Configuration Parameters

In [3]:
# Input shards, written in the brace-expansion syntax used by webdataset.
shards = "my_images/shard-{0000..0008}.tar"  # defined using braceexpand format as used by webdataset
encoded_output = Path("encoded_data")  # where we will save our encoded data

# VQGAN checkpoint to encode with: model repo id and a pinned commit hash.
VQGAN_REPO, VQGAN_COMMIT_ID = (
    "dalle-mini/vqgan_imagenet_f16_16384",
    "85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384",
)

# good defaults for a TPU v3-8
batch_size = 128  # Per device
num_workers = 8  # For parallel processing
# Per-host batch size. `shard` and `pmap` (used in the encoding step) split each
# batch across the *local* devices, so size it from the local device count; on a
# single host this equals `jax.device_count()`.
total_bs = batch_size * jax.local_device_count()  # You can use a smaller size while testing
save_frequency = 128  # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)

In [5]:
# Expand the brace pattern into a concrete list of shard paths. A materialized
# list (instead of a lazy iterator) has a known length, which gives tqdm a
# better display.
shards = [*braceexpand.braceexpand(shards)]

['XXX/shard-0000.tar',
 'XXX/shard-0001.tar',
 'XXX/shard-0002.tar',
 'XXX/shard-0003.tar',
 'XXX/shard-0004.tar',
 'XXX/shard-0005.tar',
 'XXX/shard-0006.tar',
 'XXX/shard-0007.tar',
 'XXX/shard-0008.tar']

## Load data

We load data using `webdataset`.

In [None]:
def _build_dataset(urls, batch_size):
    """Build the webdataset pipeline yielding (images, captions) batches.

    Samples that fail to read or decode are skipped with a warning
    (`wds.warn_and_continue`). Batching happens inside each loader worker,
    which is faster than batching per sample in the main process.
    """
    pipeline = wds.WebDataset(urls, handler=wds.warn_and_continue)
    pipeline = pipeline.decode("rgb", handler=wds.warn_and_continue)
    # assumes image is in `jpg` and caption in `txt`
    pipeline = pipeline.to_tuple("jpg", "txt")
    return pipeline.batched(batch_size)


ds = _build_dataset(shards, total_bs)

Note:
* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.
* you may need to resize images in your pipeline (with `map_dict`, for example); here we assume they are already 256x256.
* you can also filter out some items using `select`.

We can now inspect our data.

In [None]:
%%time
# Pull a single batch from `ds` to sanity-check the pipeline and time one load.
images, captions = next(iter(ds))

In [None]:
# Batch shape; presumably (batch, 256, 256, 3) given the 256x256 assumption above — confirm.
images.shape

In [None]:
# Display the first ten captions of the batch.
captions[:10]

In [None]:
# Preview the first image. `.permute(2, 0, 1)` moves the last axis first
# (HWC -> CHW), which assumes `images` is a torch tensor in HWC layout.
# NOTE(review): if the pipeline yields numpy arrays, `.permute` does not exist — confirm.
T.ToPILImage()(images[0].permute(2, 0, 1))

Finally we create our dataloader.

In [None]:
# `ds` already yields batches (`.batched(total_bs)`), hence `batch_size=None`.
# Each worker batches its own samples, so each worker's last batch may be
# partial. Unbatching and re-batching in the main process merges those
# leftovers, leaving at most one partial batch at the very end.
dl = (
    wds.WebLoader(ds, batch_size=None, num_workers=num_workers)
    .unbatched()
    .batched(total_bs)
)

## Image encoder

We'll use a VQGAN trained with Taming Transformers and converted to a JAX model.

In [None]:
from vqgan_jax.modeling_flax_vqgan import VQModel
from flax.jax_utils import replicate

# Load the checkpoint configured above. Pinning the revision keeps encodings
# reproducible even if the repo is updated later.
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
# `pmap` maps every argument over its leading axis, so the weights need one copy
# per local device.
vqgan_params = replicate(vqgan.params)

## Encoding

Encoding is simple: `shard` splits each batch evenly across the local devices, and `pmap` runs the encoder on all of them in parallel.

In [None]:
from flax.training.common_utils import shard
from functools import partial


@partial(jax.pmap, axis_name="batch")
def p_encode(batch, params):
    """Encode a sharded image batch into VQGAN token indices on every device.

    `jax.pmap` maps each argument over its leading axis. `batch` is therefore
    expected to come from `shard(...)` (leading axis = local device count), and
    `params` from `replicate(...)` so that every device receives a full copy of
    the weights. Un-replicated params would instead be split along their first
    axis, so the replication is required, not optional.

    Returns the second element of `vqgan.encode`'s result (the indices), stacked
    by `pmap` along a leading device axis.
    """
    _discarded, token_ids = vqgan.encode(batch, params=params)
    return token_ids

In [None]:
import pandas as pd


def _save_encodings(captions, encodings, path):
    """Write matching lists of captions and encodings to the parquet file `path`."""
    batch_df = pd.DataFrame.from_dict({"caption": captions, "encoding": encodings})
    batch_df.to_parquet(path)


def encode_dataset(dataloader, output_dir, save_frequency):
    """Encode every image from `dataloader` with the VQGAN and save parquet files.

    Args:
        dataloader: iterable of `(images, captions)` batches. `images` must be
            convertible with `np.asarray` (a CPU torch tensor or a numpy array);
            `captions` must support `len()` and slicing.
        output_dir: `pathlib.Path` or `str`; created if missing. Files are
            written inside it as `001.parquet`, `002.parquet`, ...
        save_frequency: a new file is written every `save_frequency` batches,
            counted by position in `dataloader` (skipped batches included).
            Any remainder goes to a final file.

    Each file has a `caption` column and an `encoding` column. The `encoding`
    column holds the token ids returned by `p_encode` for the matching image,
    as a list of ints.
    """
    import numpy as np

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # `shard` splits each batch evenly over the local devices, so every batch
    # fed to `p_encode` must be a multiple of the local device count.
    n_devices = jax.local_device_count()
    all_captions = []
    all_encoding = []
    n_file = 1
    for idx, (images, captions) in enumerate(tqdm(dataloader)):
        # Accept torch tensors as well as plain numpy batches.
        images = np.asarray(images)
        n = len(images) // n_devices * n_devices
        if n != len(images):
            # keep the largest prefix whose size is a multiple of n_devices
            print(f"Different sizes {n} vs {len(images)}")
            images = images[:n]
            captions = captions[:n]
        if not len(captions):
            print("No images/captions in batch...")
            continue
        encoded = p_encode(shard(images), vqgan_params)
        # Fold the leading device axis back into the batch axis, one row per
        # image (assumes indices are (devices, per_device, seq_len) — TODO confirm).
        encoded = encoded.reshape(-1, encoded.shape[-1])
        all_captions.extend(captions)
        all_encoding.extend(encoded.tolist())

        # Start a new file every `save_frequency` batches.
        if (idx + 1) % save_frequency == 0:
            print(f"Saving file {n_file}")
            _save_encodings(
                all_captions, all_encoding, output_dir / f"{n_file:03d}.parquet"
            )
            all_captions = []
            all_encoding = []
            n_file += 1

    # Flush whatever is left after the last full file.
    if all_captions:
        print(f"Saving final file {n_file}")
        _save_encodings(
            all_captions, all_encoding, output_dir / f"{n_file:03d}.parquet"
        )

In [None]:
# Run the full encoding pass; parquet files are written under `encoded_output`.
encode_dataset(
    dataloader=dl,
    output_dir=encoded_output,
    save_frequency=save_frequency,
)

----