{ "cells": [ { "cell_type": "markdown", "id": "d0b72877", "metadata": {}, "source": [ "# vqgan-jax-encoding-yfcc100m" ] }, { "cell_type": "markdown", "id": "ba7b31e6", "metadata": {}, "source": [ "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n", "\n", "This dataset was prepared by @borisdayma in Json lines format." ] }, { "cell_type": "code", "execution_count": 1, "id": "3b59489e", "metadata": {}, "outputs": [], "source": [ "import io\n", "\n", "import requests\n", "from PIL import Image\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "import torch\n", "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from torchvision.transforms import InterpolationMode\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision.datasets.folder import default_loader\n", "\n", "import jax\n", "from jax import pmap" ] }, { "cell_type": "markdown", "id": "511c3b9e", "metadata": {}, "source": [ "## VQGAN-JAX model" ] }, { "cell_type": "markdown", "id": "bb408f6c", "metadata": {}, "source": [ "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities." ] }, { "cell_type": "code", "execution_count": 2, "id": "2ca50dc7", "metadata": {}, "outputs": [], "source": [ "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel" ] }, { "cell_type": "markdown", "id": "7b60da9a", "metadata": {}, "source": [ "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model." ] }, { "cell_type": "markdown", "id": "ad05a1bd", "metadata": {}, "source": [ "**Disabling** Does not work in my local system right now." ] }, { "cell_type": "code", "execution_count": 3, "id": "29ce8b15", "metadata": {}, "outputs": [], "source": [ "#model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")" ] }, { "cell_type": "markdown", "id": "c7c4c1e6", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": 79, "id": "33861477", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 80, "id": "81b19eca", "metadata": {}, "outputs": [], "source": [ "yfcc100m = Path('/sddata/dalle-mini/YFCC100M_OpenAI_subset')\n", "# Images are 'sharded' from the following directory\n", "yfcc100m_images = yfcc100m/'data'/'images'\n", "yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'\n", "yfcc100m_output = yfcc100m/'metadata_encoded.jsonl'" ] }, { "cell_type": "markdown", "id": "1c58bb4a", "metadata": {}, "source": [ "### Cleanup" ] }, { "cell_type": "markdown", "id": "1a14ae3d", "metadata": {}, "source": [ "We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas." 
] }, { "cell_type": "code", "execution_count": 81, "id": "7811648c", "metadata": {}, "outputs": [], "source": [ "import datasets\n", "from datasets import Dataset, load_dataset" ] }, { "cell_type": "code", "execution_count": 82, "id": "753659fe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default-57592e8ed16d752b\n", "Reusing dataset json (/home/pedro/.cache/huggingface/datasets/json/default-57592e8ed16d752b/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9)\n" ] } ], "source": [ "dataset = load_dataset(\"json\", data_files=[str(yfcc100m_metadata)])" ] }, { "cell_type": "code", "execution_count": 83, "id": "9343df1b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['photoid', 'uid', 'unickname', 'datetaken', 'dateuploaded', 'capturedevice', 'title', 'description', 'usertags', 'machinetags', 'longitude', 'latitude', 'accuracy', 'pageurl', 'downloadurl', 'licensename', 'licenseurl', 'serverid', 'farmid', 'secret', 'secretoriginal', 'ext', 'marker', 'key', 'title_clean', 'description_clean'],\n", " num_rows: 14825233\n", "})" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = dataset['train']\n", "dataset" ] }, { "cell_type": "code", "execution_count": 84, "id": "c4794c29", "metadata": {}, "outputs": [], "source": [ "def image_exists(root: str, name: str, ext: str):\n", " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(ext)\n", " return image_path.exists()" ] }, { "cell_type": "code", "execution_count": 90, "id": "1b500078", "metadata": {}, "outputs": [], "source": [ "def select_existing_rows(examples):\n", " # Select lists we want to keep\n", " keys = examples['key']\n", " titles_clean = examples['title_clean']\n", " descriptions_clean = examples.get('description_clean', '')\n", " exts = examples['ext']\n", " \n", " result = {'key': [], 'title_clean': [], 'description_clean': [], 'ext': []}\n", " for i, image_name in enumerate(keys):\n", " print(i, image_name)\n", " if image_exists(root=str(yfcc100m_images), name=image_name, ext='.' 
+ exts[i]):\n", " result[\"key\"].append(image_name)\n", " result[\"title_clean\"].append(titles_clean[i])\n", " result[\"description_clean\"].append(descriptions_clean[i])\n", " result[\"ext\"].append(exts[i])\n", " print(f'returning {len(result[\"key\"])}')\n", " return result" ] }, { "cell_type": "code", "execution_count": 91, "id": "467378c1", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b72e866c3f174e9e9aa2430e204f2baf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Selecting rows with images that exist: 0%| | 0/14826 [00:00\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m filtered_dataset = dataset.map(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mselect_existing_rows\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mremove_columns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumn_names\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mbatched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mnum_proc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/code/hf_jax/datasets/src/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mmap\u001b[0;34m(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m 1655\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1656\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnum_proc\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mnum_proc\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1657\u001b[0;31m return self._map_single(\n\u001b[0m\u001b[1;32m 1658\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1659\u001b[0m \u001b[0mwith_indices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwith_indices\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/code/hf_jax/datasets/src/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m }\n\u001b[1;32m 184\u001b[0m \u001b[0;31m# apply actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"DatasetDict\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;31m# re-apply format to the output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/code/hf_jax/datasets/src/datasets/fingerprint.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[0;31m# Call actual function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 397\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 398\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 399\u001b[0m \u001b[0;31m# Update fingerprint of in-place transforms + update in-place history of transforms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/code/hf_jax/datasets/src/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m_map_single\u001b[0;34m(self, function, with_indices, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, desc)\u001b[0m\n\u001b[1;32m 2022\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2023\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcast_to_python_objects\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2024\u001b[0;31m \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2025\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mupdate_data\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mwriter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2026\u001b[0m \u001b[0mwriter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinalize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# close_stream=bool(buf_writer is None)) # We only close if we are writing in a file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/code/hf_jax/datasets/src/datasets/arrow_writer.py\u001b[0m in \u001b[0;36mwrite_batch\u001b[0;34m(self, batch_examples, writer_batch_size)\u001b[0m\n\u001b[1;32m 386\u001b[0m \u001b[0mtyped_sequence\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mOptimizedTypedSequence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_examples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcol\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcol_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtry_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcol_try_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0mtyped_sequence_examples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcol\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtyped_sequence\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 388\u001b[0;31m \u001b[0mpa_table\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpa\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pydict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtyped_sequence_examples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 389\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpa_table\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwriter_batch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/pyarrow/table.pxi\u001b[0m in \u001b[0;36mpyarrow.lib.Table.from_pydict\u001b[0;34m()\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/pyarrow/array.pxi\u001b[0m in \u001b[0;36mpyarrow.lib.asarray\u001b[0;34m()\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/pyarrow/array.pxi\u001b[0m in \u001b[0;36mpyarrow.lib.array\u001b[0;34m()\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/pyarrow/array.pxi\u001b[0m in \u001b[0;36mpyarrow.lib._handle_arrow_array_protocol\u001b[0;34m()\u001b[0m\n", "\u001b[0;32m~/code/hf_jax/datasets/src/datasets/arrow_writer.py\u001b[0m in \u001b[0;36m__arrow_array__\u001b[0;34m(self, type)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpa\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mtrying_type\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_py\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 101\u001b[0m raise TypeError(\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\"Specified try_type alters data. 
Please check that the type/feature that you provided match the type/features of the data.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/pyarrow/array.pxi\u001b[0m in \u001b[0;36mpyarrow.lib.Array.__getitem__\u001b[0;34m()\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/pyarrow/array.pxi\u001b[0m in \u001b[0;36mpyarrow.lib._normalize_index\u001b[0;34m()\u001b[0m\n", "\u001b[0;31mIndexError\u001b[0m: index out of bounds" ] } ], "source": [ "filtered_dataset = dataset.map(\n", "    select_existing_rows,\n", "    remove_columns = dataset.column_names,\n", "    batched = True,\n", "    num_proc = 1,\n", "    desc = \"Selecting rows with images that exist\"\n", ")" ] }, { "cell_type": "code", "execution_count": 109, "id": "7060ff8f", "metadata": {}, "outputs": [], "source": [ "# df['image_exists'] = df.apply(lambda row: image_exists(row['key']), axis=1)" ] }, { "cell_type": "code", "execution_count": 113, "id": "fecc9a00", "metadata": {}, "outputs": [], "source": [ "image_size = 256\n", "def image_transform(image):\n", "    s = min(image.size)\n", "    r = image_size / s\n", "    s = (round(r * image.size[1]), round(r * image.size[0]))\n", "    image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n", "    image = TF.center_crop(image, output_size = 2 * [image_size])\n", "    image = torch.unsqueeze(T.ToTensor()(image), 0)\n", "    image = image.permute(0, 2, 3, 1).numpy()\n", "    return image" ] },
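{ "cell_type": "markdown", "id": "5b9d0c21", "metadata": {}, "source": [ "Quick sanity check of the transform (a hypothetical example, not executed here): a 640x480 source is resized so that its short side measures 256 pixels, center-cropped to 256x256, and returned as a `(1, 256, 256, 3)` array in NHWC order, the layout we feed to the JAX model." ] }, { "cell_type": "code", "execution_count": null, "id": "5b9d0c22", "metadata": {}, "outputs": [], "source": [ "# Hypothetical check (not executed): the transform yields batched NHWC arrays.\n", "# dummy = Image.new('RGB', (640, 480))\n", "# assert image_transform(dummy).shape == (1, 256, 256, 3)" ] },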
{ "cell_type": "code", "execution_count": 98, "id": "1a065700", "metadata": {}, "outputs": [], "source": [ "class YFCC100Dataset(Dataset):\n", "    def __init__(self, image_list_path: str, images_root: str, image_size: int, max_items=None):\n", "        \"\"\"\n", "        :param image_list_path: Path to a file containing a list of all images, in jsonl format.\n", "        :param images_root: Root directory containing the images\n", "        :param image_size: Image size. Source images will be resized and center-cropped.\n", "        :param max_items: Limit dataset size for debugging\n", "        \"\"\"\n", "        self.image_list = pd.read_json(image_list_path, orient=\"records\", lines=True)\n", "        self.images_root = Path(images_root)\n", "        if max_items is not None: self.image_list = self.image_list[:max_items]\n", "        self.image_size = image_size\n", "    \n", "    def __len__(self):\n", "        return len(self.image_list)\n", "    \n", "    def _get_raw_image(self, i):\n", "        image_name = self.image_list.iloc[i].key\n", "        image_path = (self.images_root/image_name[0:3]/image_name[3:6]/image_name).with_suffix('.jpg')\n", "        return default_loader(image_path) if image_path.exists() else None\n", "    \n", "    # TODO: we could maybe use jax resizing / scaling functions\n", "    def resize_image(self, image):\n", "        s = min(image.size)\n", "        r = self.image_size / s\n", "        s = (round(r * image.size[1]), round(r * image.size[0]))\n", "        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n", "        image = TF.center_crop(image, output_size = 2 * [self.image_size])\n", "        image = torch.unsqueeze(T.ToTensor()(image), 0)\n", "        image = image.permute(0, 2, 3, 1).numpy()\n", "        return image\n", "    \n", "    def _get_caption(self, i):\n", "        # We are currently concatenating title and description. Should we use another separator?\n", "        row = self.image_list.iloc[i]\n", "        return ' '.join([row.title_clean, row.description_clean])\n", "    \n", "    def __getitem__(self, i):\n", "        image = self._get_raw_image(i)\n", "        if image is None: return None\n", "        image = self.resize_image(image)\n", "        caption = self._get_caption(i)\n", "        return {'image': image, 'text': caption}" ] }, { "cell_type": "code", "execution_count": 99, "id": "4ce2211f", "metadata": {}, "outputs": [], "source": [ "dataset = YFCC100Dataset(\n", "    image_list_path = yfcc100m_metadata,\n", "    images_root = yfcc100m_images,\n", "    image_size = 256,\n", ")" ] }, { "cell_type": "code", "execution_count": 100, "id": "cc922704", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5000" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(dataset)" ] }, { "cell_type": "code", "execution_count": 102, "id": "6e47ba46", "metadata": {}, "outputs": [], "source": [ "dataloader = DataLoader(dataset, batch_size=32, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 103, "id": "c8a130eb", "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "Caught TypeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py\", line 287, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\", line 47, in fetch\n return self.collate_fn(data)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py\", line 86, in default_collate\n raise TypeError(default_collate_err_msg_format.format(elem_type))\nTypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>\n", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_320049/1409168804.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m
\u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1201\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1202\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_task_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1203\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1204\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_data\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 1227\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1228\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1229\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1230\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/_utils.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[0;31m# have message field\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 425\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 426\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: Caught TypeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File 
\"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py\", line 287, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\", line 47, in fetch\n return self.collate_fn(data)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py\", line 86, in default_collate\n raise TypeError(default_collate_err_msg_format.format(elem_type))\nTypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found \n" ] } ], "source": [ "next(iter(dataloader))" ] }, { "cell_type": "markdown", "id": "62ad01c3", "metadata": {}, "source": [ "## Encoding" ] }, { "cell_type": "code", "execution_count": 89, "id": "88f36d0b", "metadata": {}, "outputs": [], "source": [ "def encode(model, batch):\n", " print(\"jitting encode function\")\n", "# _, indices = model.encode(batch)\n", "\n", " # The model does not run in my computer (no cudNN currently installed) - faking it\n", " indices = [random.randint(0, 16384) for _ in range(256)]\n", " return indices" ] }, { "cell_type": "code", "execution_count": 90, "id": "1f35f0cb", "metadata": {}, "outputs": [], "source": [ "def superbatch_generator(dataloader, num_tpus):\n", " iter_loader = iter(dataloader)\n", " for batch in iter_loader:\n", " superbatch = [batch.squeeze(1)]\n", " try:\n", " for b in range(num_tpus-1):\n", " batch = next(iter_loader)\n", " if batch is None:\n", " break\n", " # Skip incomplete last batch\n", " if batch.shape[0] == dataloader.batch_size:\n", " superbatch.append(batch.squeeze(1))\n", " except StopIteration:\n", " pass\n", " superbatch = torch.stack(superbatch, axis=0)\n", " yield superbatch" ] }, { "cell_type": "code", "execution_count": 93, "id": "2210705b", "metadata": {}, "outputs": [], "source": [ "import os\n", "import jax\n", "\n", "def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):\n", " if os.path.isfile(output_jsonl):\n", " print(f\"Destination file {output_jsonl} already exists, please move away.\")\n", " return\n", " \n", " num_tpus = jax.device_count()\n", " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n", " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n", " \n", " p_encoder = pmap(lambda batch: encode(model, batch))\n", "\n", " # We save each superbatch to avoid reallocation of buffers as we process them.\n", " # We keep the file open to prevent excessive file seeks.\n", " with open(output_jsonl, \"w\") as file:\n", " iterations = len(dataset) // (batch_size * num_tpus)\n", " for n in tqdm(range(iterations)):\n", " superbatch = next(superbatches)\n", " encoded = p_encoder(superbatch.numpy())\n", " encoded = encoded.reshape(-1, encoded.shape[-1])\n", "\n", " # Extract fields from the dataset internal `captions` property, and save to disk\n", " start_index = n * batch_size * num_tpus\n", " end_index = (n+1) * batch_size * num_tpus\n", " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n", " captions = dataset.captions[\"caption\"][start_index:end_index].values\n", " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n", " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n", " batch_df = batch_df.dropna()\n", " batch_df.to_json(file, 
{ "cell_type": "code", "execution_count": 93, "id": "2210705b", "metadata": {}, "outputs": [], "source": [ "import os\n", "import jax\n", "\n", "def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):\n", "    if os.path.isfile(output_jsonl):\n", "        print(f\"Destination file {output_jsonl} already exists, please move it away.\")\n", "        return\n", "    \n", "    num_tpus = jax.device_count()\n", "    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n", "    superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n", "    \n", "    p_encoder = pmap(lambda batch: encode(model, batch))\n", "\n", "    # We save each superbatch as we process it, to avoid reallocation of buffers.\n", "    # We keep the file open to prevent excessive file seeks.\n", "    with open(output_jsonl, \"w\") as file:\n", "        iterations = len(dataset) // (batch_size * num_tpus)\n", "        for n in tqdm(range(iterations)):\n", "            superbatch = next(superbatches)\n", "            encoded = p_encoder(superbatch.numpy())\n", "            encoded = encoded.reshape(-1, encoded.shape[-1])\n", "\n", "            # Extract fields from the dataset's internal image_list DataFrame, and save to disk\n", "            start_index = n * batch_size * num_tpus\n", "            end_index = (n+1) * batch_size * num_tpus\n", "            rows = dataset.image_list.iloc[start_index:end_index]\n", "            paths = rows[\"key\"].values\n", "            captions = (rows[\"title_clean\"] + ' ' + rows[\"description_clean\"]).values\n", "            encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n", "            batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n", "            batch_df = batch_df.dropna()\n", "            batch_df.to_json(file, orient='records', lines=True)" ] }, { "cell_type": "code", "execution_count": 94, "id": "7704863d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/78 [00:00<?, ?it/s]" ] }, { "ename": "TypeError", "evalue": "Caught TypeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py\", line 287, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\", line 47, in fetch\n return self.collate_fn(data)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py\", line 86, in default_collate\n raise TypeError(default_collate_err_msg_format.format(elem_type))\nTypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>\n", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_320049/140243368.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mencode_captioned_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myfc100m_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m64\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m16\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/tmp/ipykernel_320049/2954345319.py\u001b[0m in \u001b[0;36mencode_captioned_dataset\u001b[0;34m(dataset, output_jsonl, batch_size, num_workers)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0miterations\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnum_tpus\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterations\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0msuperbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msuperbatches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0mencoded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mp_encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msuperbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mencoded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mencoded\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoded\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipykernel_320049/4148450576.py\u001b[0m in \u001b[0;36msuperbatch_generator\u001b[0;34m(dataloader, num_tpus)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msuperbatch_generator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_tpus\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0miter_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m
\u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miter_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0msuperbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1201\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1202\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_task_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1203\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1204\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_data\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 1227\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1228\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1229\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1230\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/_utils.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[0;31m# have message field\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 425\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 426\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: Caught TypeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py\", line 287, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\", line 47, in fetch\n return self.collate_fn(data)\n File \"/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py\", line 86, in default_collate\n raise TypeError(default_collate_err_msg_format.format(elem_type))\nTypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>\n" ] } ], "source": [ "encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)" ] }, { "cell_type": "markdown", "id": "8953dd84", "metadata": {}, "source": [ "----" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }