{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf8fb38a",
   "metadata": {},
   "source": [
    "# Data Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b83dcb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "from dataclasses import dataclass, field\n",
    "from pathlib import Path\n",
    "\n",
    "import datasets\n",
    "from datasets import Dataset, load_dataset\n",
    "import numpy as np\n",
    "\n",
    "from transformers import BartTokenizer\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from flax.training.common_utils import shard"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a661a89e",
   "metadata": {},
   "source": [
    "File containing image paths, captions and VQGAN-encoded indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0e84e889",
   "metadata": {},
   "outputs": [],
   "source": [
    "datafile = '/data/CC12M/images-encoded-10000.tsv'   # 9999 encoded images from CC12M"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fdc640b",
   "metadata": {},
   "source": [
    "TODO: generate train/test splits if necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc6789b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-91833df78e844785\n",
      "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
     ]
    }
   ],
   "source": [
    "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f3ed4919",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['image_file', 'caption', 'encoding'],\n",
       "        num_rows: 9999\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a70c7354",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['image_file', 'caption', 'encoding'],\n",
       "    num_rows: 9999\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = dataset[\"train\"]\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a73454cf",
   "metadata": {},
   "source": [
    "We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c0fa992",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0e36582",
   "metadata": {},
   "source": [
    "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d46f6ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
    "max_length = 256   # Read from data_args.max_source_length\n",
    "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
    "image_bos = 16384  # Max token is 16383 in our VQGAN configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cac6643",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    inputs = examples[\"caption\"]\n",
    "#     inputs = [prefix + inp for inp in inputs]  # Do we need this?\n",
    "    model_inputs = tokenizer(\n",
    "        inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
    "    )\n",
    "\n",
    "    # `encoding` holds a string representation of a list of integer indices.\n",
    "    # Parse it with ast.literal_eval, which only accepts Python literals,\n",
    "    # rather than eval, which would execute arbitrary code from the data file.\n",
    "    model_inputs[\"labels\"] = [[image_bos] + ast.literal_eval(indices) for indices in examples['encoding']]\n",
    "\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e6a4cb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_workers = 48   # We have 96 processors in the TPU\n",
    "column_names = dataset.column_names\n",
    "input_dataset = dataset.map(preprocess_function,\n",
    "                            remove_columns=column_names,\n",
    "                            batched=True,\n",
    "                            num_proc=num_workers\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a9b1b467",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
    "    \"\"\"\n",
    "    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
    "    Shuffle batches if `shuffle` is `True`.\n",
    "    \"\"\"\n",
    "    steps_per_epoch = len(dataset) // batch_size\n",
    "\n",
    "    if shuffle:\n",
    "        batch_idx = jax.random.permutation(rng, len(dataset))\n",
    "    else:\n",
    "        batch_idx = jnp.arange(len(dataset))\n",
    "\n",
    "    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.\n",
    "    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
    "\n",
    "    for idx in batch_idx:\n",
    "        batch = dataset[idx]\n",
    "        batch = {k: jnp.array(v) for k, v in batch.items()}\n",
    "        batch = shard(batch)\n",
    "        yield batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0a628505",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:Starting the local TPU driver.\n",
      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
      "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
     ]
    }
   ],
   "source": [
    "rng = jax.random.PRNGKey(23)   # Use training_args.seed\n",
    "batch_size = 64                # Per device\n",
    "super_batch_size = batch_size * jax.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b3a5ce7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "67aa8f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "superbatch = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7cd99402",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['attention_mask', 'input_ids', 'labels'])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "superbatch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "652a4a9e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(superbatch[\"labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "de7de4e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8, 64, 257)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "superbatch[\"labels\"].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6800153b",
   "metadata": {},
   "source": [
    "Any image sequence should begin with `image_bos`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cfe23a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert superbatch[\"labels\"][1][5][0].item() == image_bos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb899b4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}