{ "cells": [ { "cell_type": "markdown", "id": "d0b72877", "metadata": {}, "source": [ "# vqgan-jax-encoding-yfcc100m" ] }, { "cell_type": "markdown", "id": "747733a4", "metadata": {}, "source": [ "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n", "\n", "This dataset was prepared by @borisdayma in Json lines format." ] }, { "cell_type": "code", "execution_count": 1, "id": "3b59489e", "metadata": {}, "outputs": [], "source": [ "import io\n", "\n", "import requests\n", "from PIL import Image\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "import torch\n", "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from torchvision.transforms import InterpolationMode\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision.datasets.folder import default_loader\n", "\n", "import jax\n", "from jax import pmap" ] }, { "cell_type": "markdown", "id": "511c3b9e", "metadata": {}, "source": [ "## VQGAN-JAX model" ] }, { "cell_type": "markdown", "id": "bb408f6c", "metadata": {}, "source": [ "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities." ] }, { "cell_type": "code", "execution_count": 2, "id": "2ca50dc7", "metadata": {}, "outputs": [], "source": [ "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel" ] }, { "cell_type": "markdown", "id": "7b60da9a", "metadata": {}, "source": [ "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model." ] }, { "cell_type": "code", "execution_count": 4, "id": "29ce8b15", "metadata": {}, "outputs": [], "source": [ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")" ] }, { "cell_type": "markdown", "id": "c7c4c1e6", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "id": "fd4c608e", "metadata": {}, "source": [ "I splitted the files to do the process iteratively. Pandas struggles with memory and `datasets` has problems when filtering files, as described [in this issue](https://github.com/huggingface/datasets/issues/2644)." 
] }, { "cell_type": "code", "execution_count": 5, "id": "6c058636", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 6, "id": "81b19eca", "metadata": {}, "outputs": [], "source": [ "yfcc100m = Path('/sddata/dalle-mini/YFCC100M_OpenAI_subset')\n", "# Images are 'sharded' from the following directory\n", "yfcc100m_images = yfcc100m/'data'/'images'\n", "yfcc100m_metadata_splits = yfcc100m/'metadata_splitted'\n", "yfcc100m_output = yfcc100m/'metadata_encoded'" ] }, { "cell_type": "code", "execution_count": 7, "id": "40873de9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_17'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_10'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_22'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_28'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_09'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_03'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_07'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_26'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_14'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_19'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_13'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_21'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_00'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_02'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_08'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_11'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_29'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_23'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_24'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_16'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_05'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_01'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_12'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_18'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_20'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_27'),\n", " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_15'),\n", " 
PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_06')]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_splits = [x for x in yfcc100m_metadata_splits.iterdir() if x.is_file()]\n", "all_splits" ] }, { "cell_type": "markdown", "id": "f604e3c9", "metadata": {}, "source": [ "### Cleanup" ] }, { "cell_type": "code", "execution_count": 8, "id": "dea06b92", "metadata": {}, "outputs": [], "source": [ "def image_exists(root: str, name: str, ext: str):\n", "    image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(ext)\n", "    return image_path.exists()" ] }, { "cell_type": "code", "execution_count": 9, "id": "1d34d7aa", "metadata": {}, "outputs": [], "source": [ "class YFC100Dataset(Dataset):\n", "    def __init__(self, image_list: pd.DataFrame, images_root: str, image_size: int, max_items=None):\n", "        \"\"\"\n", "        :param image_list: DataFrame with clean entries - all images must exist.\n", "        :param images_root: Root directory containing the images\n", "        :param image_size: Image size. Source images will be resized and center-cropped.\n", "        :param max_items: Limit dataset size for debugging\n", "        \"\"\"\n", "        self.image_list = image_list\n", "        self.images_root = Path(images_root)\n", "        if max_items is not None: self.image_list = self.image_list[:max_items]\n", "        self.image_size = image_size\n", "\n", "    def __len__(self):\n", "        return len(self.image_list)\n", "\n", "    def _get_raw_image(self, i):\n", "        image_name = self.image_list.iloc[i].key\n", "        image_path = (self.images_root/image_name[0:3]/image_name[3:6]/image_name).with_suffix('.jpg')\n", "        return default_loader(image_path)\n", "\n", "    def resize_image(self, image):\n", "        s = min(image.size)\n", "        r = self.image_size / s\n", "        s = (round(r * image.size[1]), round(r * image.size[0]))\n", "        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n", "        image = TF.center_crop(image, output_size = 2 * [self.image_size])\n", "        # FIXME: np.array is necessary in my installation, but it should be automatic\n", "        image = torch.unsqueeze(T.ToTensor()(np.array(image)), 0)\n", "        image = image.permute(0, 2, 3, 1).numpy()\n", "        return image\n", "\n", "    def __getitem__(self, i):\n", "        image = self._get_raw_image(i)\n", "        image = self.resize_image(image)\n", "        # Just return the image, not the caption\n", "        return image" ] }, { "cell_type": "markdown", "id": "62ad01c3", "metadata": {}, "source": [ "## Encoding" ] }, { "cell_type": "code", "execution_count": 10, "id": "88f36d0b", "metadata": {}, "outputs": [], "source": [ "def encode(model, batch):\n", "    print(\"jitting encode function\")\n", "    _, indices = model.encode(batch)\n", "\n", "    # # FIXME: The model does not run on my computer (no cuDNN currently installed) - faking it\n", "    # indices = np.random.randint(0, 16384, (batch.shape[0], 256))\n", "    return indices" ] }, { "cell_type": "code", "execution_count": null, "id": "d1f45dd8", "metadata": {}, "outputs": [], "source": [ "#FIXME\n", "# import random\n", "# model = {}" ] }, { "cell_type": "code", "execution_count": 11, "id": "1f35f0cb", "metadata": {}, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "\n", "def superbatch_generator(dataloader):\n", "    iter_loader = iter(dataloader)\n", "    for batch in iter_loader:\n", "        batch = batch.squeeze(1)\n", "        # Skip incomplete last batch\n", "        if batch.shape[0] == dataloader.batch_size:\n", "            yield shard(batch)" ] }, { "cell_type": "code", "execution_count": 13, "id": "2210705b", "metadata": 
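{}, "outputs": [], "source": [ "# Quick smoke test of the pieces above before launching a full split - a rough\n", "# sketch only: `image_size=256`, reading the split with `lines=True`, and the\n", "# `key` / '.jpg' conventions are assumptions consistent with the Dataset class.\n", "sample_df = next(pd.read_json(all_splits[0], orient='records', lines=True, chunksize=1024))\n", "sample_df = sample_df[sample_df.apply(lambda row: image_exists(yfcc100m_images, row.key, '.jpg'), axis=1)]\n", "sample_dataset = YFC100Dataset(sample_df, yfcc100m_images, image_size=256, max_items=64)\n", "sample_loader = DataLoader(sample_dataset, batch_size=jax.device_count() * 8, num_workers=2)\n", "\n", "# Push one sharded superbatch through the pmapped encoder.\n", "superbatch = next(superbatch_generator(sample_loader))\n", "p_encode = pmap(lambda batch: encode(model, batch))\n", "indices = p_encode(superbatch.numpy())\n", "indices.shape  # (num_devices, batch_per_device, tokens_per_image)" ] }, { "cell_type": "markdown", "id": "5d2c8e90", "metadata": {}, "source": [ "The function below does the same thing for a whole dataset: it streams superbatches from a `DataLoader`, encodes them on all devices, and appends the `key`, `caption` and `encoding` records of each batch to a JSON Lines file." ] }, { "cell_type": "code", "execution_count": 13, "id": "9b41c5a7", "metadata": 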
{}, "outputs": [], "source": [ "import os\n", "import jax\n", "\n", "def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):\n", " if os.path.isfile(output_jsonl):\n", " print(f\"Destination file {output_jsonl} already exists, please move away.\")\n", " return\n", " \n", " num_tpus = jax.device_count()\n", " dataloader = DataLoader(dataset, batch_size=num_tpus*batch_size, num_workers=num_workers)\n", " superbatches = superbatch_generator(dataloader)\n", " \n", " p_encoder = pmap(lambda batch: encode(model, batch))\n", "\n", " # We save each superbatch to avoid reallocation of buffers as we process them.\n", " # We keep the file open to prevent excessive file seeks.\n", " with open(output_jsonl, \"w\") as file:\n", " iterations = len(dataset) // (batch_size * num_tpus)\n", " for n in tqdm(range(iterations)):\n", " superbatch = next(superbatches)\n", " encoded = p_encoder(superbatch.numpy())\n", " encoded = encoded.reshape(-1, encoded.shape[-1])\n", "\n", " # Extract fields from the dataset internal `image_list` property, and save to disk\n", " # We need to read from the df because the Dataset only returns images\n", " start_index = n * batch_size * num_tpus\n", " end_index = (n+1) * batch_size * num_tpus\n", " keys = dataset.image_list[\"key\"][start_index:end_index].values\n", " captions = dataset.image_list[\"caption\"][start_index:end_index].values\n", "# encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n", " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded})\n", " batch_df.to_json(file, orient='records', lines=True)" ] }, { "cell_type": "code", "execution_count": 14, "id": "7704863d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04\n", "54024 selected from 500000 total entries\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Starting the local TPU driver.\n", "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n", "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n", " 0%| | 0/31 [00:00