{ "cells": [ { "cell_type": "markdown", "id": "d0b72877", "metadata": {}, "source": [ "# Pre-encoding a dataset for DALLE·mini" ] }, { "cell_type": "markdown", "id": "ba7b31e6", "metadata": {}, "source": [ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n", "\n", "Adapt it to your own dataset and image encoder.\n", "\n", "At the end you should have a dataset of pairs:\n", "* a caption defined as a string\n", "* an encoded image defined as a list of int." ] }, { "cell_type": "code", "execution_count": null, "id": "3b59489e", "metadata": {}, "outputs": [], "source": [ "from tqdm.notebook import tqdm\n", "\n", "import torchvision.transforms as T\n", "\n", "import webdataset as wds\n", "\n", "import jax\n", "import braceexpand\n", "from pathlib import Path" ] }, { "cell_type": "markdown", "id": "c7c4c1e6", "metadata": {}, "source": [ "## Configuration Parameters" ] }, { "cell_type": "code", "execution_count": 3, "id": "1265dbfe", "metadata": {}, "outputs": [], "source": [ "shards = \"my_images/shard-{0000..0008}.tar\"  # defined using braceexpand format as used by webdataset\n", "encoded_output = Path(\"encoded_data\")  # where we will save our encoded data\n", "\n", "VQGAN_REPO, VQGAN_COMMIT_ID = (\n", "    \"dalle-mini/vqgan_imagenet_f16_16384\",\n", "    \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n", ")\n", "\n", "# good defaults for a TPU v3-8\n", "batch_size = 128  # Per device\n", "num_workers = 8  # For parallel processing\n", "total_bs = batch_size * jax.device_count()  # You can use a smaller size while testing\n", "save_frequency = 128  # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)" ] }, { "cell_type": "code", "execution_count": 5, "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['XXX/shard-0000.tar',\n", " 'XXX/shard-0001.tar',\n", " 
'XXX/shard-0002.tar',\n", " 'XXX/shard-0003.tar',\n", " 'XXX/shard-0004.tar',\n", " 'XXX/shard-0005.tar',\n", " 'XXX/shard-0006.tar',\n", " 'XXX/shard-0007.tar',\n", " 'XXX/shard-0008.tar']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shards = list(\n", "    braceexpand.braceexpand(shards)\n", ")  # better display for tqdm with known length" ] }, { "cell_type": "markdown", "id": "75dba8e2", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "id": "a1e8fb95", "metadata": {}, "source": [ "We load data using `webdataset`." ] }, { "cell_type": "code", "execution_count": null, "id": "9ef5de9e", "metadata": {}, "outputs": [], "source": [ "ds = (\n", "    wds.WebDataset(shards, handler=wds.warn_and_continue)\n", "    .decode(\"rgb\", handler=wds.warn_and_continue)\n", "    .to_tuple(\"jpg\", \"txt\")  # assumes image is in `jpg` and caption in `txt`\n", "    .batched(total_bs)  # load in batch per worker (faster)\n", ")" ] }, { "cell_type": "markdown", "id": "90981824", "metadata": {}, "source": [ "Note:\n", "* You can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n", "* You may need to resize images in your pipeline (with `map_dict`, for example); we assume they are already 256x256.\n", "* You can also filter out some items using `select`." ] }, { "cell_type": "markdown", "id": "129c377d", "metadata": {}, "source": [ "We can now inspect our data." 
] }, { "cell_type": "code", "execution_count": null, "id": "8cac98cb", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "images, captions = next(iter(ds))" ] }, { "cell_type": "code", "execution_count": null, "id": "cd268fbf", "metadata": {}, "outputs": [], "source": [ "images.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "5acfc4d8", "metadata": {}, "outputs": [], "source": [ "captions[:10]" ] }, { "cell_type": "code", "execution_count": null, "id": "c24693c0", "metadata": {}, "outputs": [], "source": [ "T.ToPILImage()(images[0].permute(2, 0, 1))" ] }, { "cell_type": "markdown", "id": "3059ffb1", "metadata": {}, "source": [ "Finally, we create our dataloader." ] }, { "cell_type": "code", "execution_count": null, "id": "c227c551", "metadata": {}, "outputs": [], "source": [ "dl = (\n", "    wds.WebLoader(ds, batch_size=None, num_workers=num_workers).unbatched().batched(total_bs)\n", ")  # avoid partial batch at the end of each worker" ] }, { "cell_type": "markdown", "id": "a354472b", "metadata": {}, "source": [ "## Image encoder\n", "\n", "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model." ] }, { "cell_type": "code", "execution_count": null, "id": "47a8b818", "metadata": { "scrolled": true }, "outputs": [], "source": [ "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from flax.jax_utils import replicate\n", "\n", "# load the checkpoint pinned in the configuration cell\n", "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n", "vqgan_params = replicate(vqgan.params)" ] }, { "cell_type": "markdown", "id": "62ad01c3", "metadata": {}, "source": [ "## Encoding" ] }, { "cell_type": "markdown", "id": "20357f74", "metadata": {}, "source": [ "Encoding is simple: `shard` distributes each batch across devices, and `pmap` runs the encoder on all of them in parallel." 
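, "\n",
"\n",
"As a minimal sketch of this pattern on dummy data, the snippet below shards a batch, runs a toy per-device function in place of `vqgan.encode`, and flattens the result. The array sizes, `per_device_mean` and `scale` are illustrative assumptions, not part of the pipeline.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from flax.jax_utils import replicate\n",
"from flax.training.common_utils import shard\n",
"\n",
"n_devices = jax.local_device_count()\n",
"\n",
"# Dummy batch: 2 images per device, 256x256 RGB (sizes are illustrative)\n",
"images = np.zeros((2 * n_devices, 256, 256, 3), dtype=np.float32)\n",
"sharded = shard(images)  # shape: (n_devices, 2, 256, 256, 3)\n",
"\n",
"\n",
"@jax.pmap\n",
"def per_device_mean(x, scale):\n",
"    # Toy stand-in for the encoder: one value per image, computed on each device\n",
"    return x.mean(axis=(1, 2, 3)) * scale\n",
"\n",
"\n",
"# Arguments are mapped over their leading axis, so constants are replicated\n",
"scale = replicate(jnp.ones(()))  # shape: (n_devices,)\n",
"out = per_device_mean(sharded, scale)  # shape: (n_devices, 2)\n",
"flat = out.reshape(-1)  # back to one row per image: (2 * n_devices,)\n",
"```\n",
"\n",
"By default every argument of a `pmap`-ed function needs a leading device axis, which is why the model parameters are passed through `replicate` before encoding."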
] }, { "cell_type": "code", "execution_count": null, "id": "322a4619", "metadata": {}, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "from functools import partial\n", "\n", "\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_encode(batch, params):\n", "    # `pmap` maps over the leading axis of every argument by default,\n", "    # so `params` must be replicated across devices (see `vqgan_params`)\n", "    _, indices = vqgan.encode(batch, params=params)\n", "    return indices" ] }, { "cell_type": "code", "execution_count": null, "id": "ff6c10d4", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "def encode_dataset(dataloader, output_dir, save_frequency):\n", "    output_dir.mkdir(parents=True, exist_ok=True)\n", "    n_devices = jax.local_device_count()  # `shard` splits across local devices\n", "    all_captions = []\n", "    all_encoding = []\n", "    n_file = 1\n", "    for idx, (images, captions) in enumerate(tqdm(dataloader)):\n", "        images = images.numpy()\n", "        n = len(images) // n_devices * n_devices\n", "        if n != len(images):\n", "            # keep the largest number of images divisible by the device count\n", "            print(f\"Different sizes {n} vs {len(images)}\")\n", "            images = images[:n]\n", "            captions = captions[:n]\n", "        if not len(captions):\n", "            print(\"No images/captions in batch...\")\n", "            continue\n", "        images = shard(images)\n", "        encoded = p_encode(images, vqgan_params)\n", "        encoded = encoded.reshape(-1, encoded.shape[-1])\n", "        all_captions.extend(captions)\n", "        all_encoding.extend(encoded.tolist())\n", "\n", "        # save files\n", "        if (idx + 1) % save_frequency == 0:\n", "            print(f\"Saving file {n_file}\")\n", "            batch_df = pd.DataFrame.from_dict(\n", "                {\"caption\": all_captions, \"encoding\": all_encoding}\n", "            )\n", "            batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n", "            all_captions = []\n", "            all_encoding = []\n", "            n_file += 1\n", "\n", "    if len(all_captions):\n", "        print(f\"Saving final file {n_file}\")\n", "        batch_df = pd.DataFrame.from_dict(\n", "            {\"caption\": all_captions, \"encoding\": all_encoding}\n", "        )\n", "        batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7704863d", "metadata": {}, "outputs": [], "source": [ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)" ] }, { "cell_type": "markdown", "id": "8953dd84", "metadata": {}, "source": [ "----" ]