# Data Pipeline

In [1]:
import ast
from dataclasses import dataclass, field
from pathlib import Path

import datasets
import jax
import jax.numpy as jnp
import numpy as np
from datasets import Dataset, load_dataset
from flax.training.common_utils import shard
from tqdm import tqdm
from transformers import BartTokenizer

File containing image paths, captions and VQGAN-encoded indices.

In [2]:
# Tab-separated file with one row per image; it is parsed with delimiter='\t' when loaded.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M

TODO: generate train/test splits if necessary.

In [3]:
# Parse the TSV with the generic 'csv' builder. No split is requested, so the
# result is a DatasetDict (a single 'train' split, per the output below).
dataset = load_dataset('csv', delimiter='\t', data_files=[datafile])

Using custom data configuration default-91833df78e844785
Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)


In [4]:
# Show the loaded DatasetDict: available splits, column names and row count.
dataset

DatasetDict({
 train: Dataset({
 features: ['image_file', 'caption', 'encoding'],
 num_rows: 9999
 })
})

In [5]:
# Unwrap the single "train" split. The isinstance guard keeps this cell
# idempotent: on a re-run `dataset` is already a Dataset, and indexing it with
# "train" would fail because no such column exists.
if isinstance(dataset, datasets.DatasetDict):
    dataset = dataset["train"]
dataset

Dataset({
 features: ['image_file', 'caption', 'encoding'],
 num_rows: 9999
})

We don't need the `image_file` field for training. We'll drop it during preprocessing, because a file path cannot be converted to a numeric `jnp.array`, which JAX requires.

## Preprocessing

The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions.

In [6]:
# Captions are tokenized with padding="max_length" (see preprocess_function)
# because jitted functions need fixed-length inputs.
max_length = 256 # Read from data_args.max_source_length
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
# Token id prepended to every label sequence to mark the start of an image.
image_bos = 16384 # Max token is 16383 in our VQGAN configuration

In [7]:
def preprocess_function(examples):
    """
    Tokenize a batch of captions and build the image-token label sequences.

    Args:
        examples: batch dict (as passed by `Dataset.map(..., batched=True)`)
            with a "caption" list of strings and an "encoding" list of strings,
            each holding the literal representation of a list of integer
            VQGAN indices.

    Returns:
        The tokenizer output for the captions (padded/truncated to `max_length`)
        with an extra "labels" entry: one list per example, consisting of
        `image_bos` followed by the decoded image indices.

    Raises:
        ValueError / SyntaxError: if an "encoding" string is not a valid
            Python literal.
    """
    captions = examples["caption"]
    # TODO: decide whether captions need a task prefix (`prefix + caption`) before tokenization.
    model_inputs = tokenizer(
        captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="np"
    )

    # The encodings come from a data file, so parse them with `ast.literal_eval`
    # rather than `eval`: it only accepts literals and cannot execute code.
    model_inputs["labels"] = [
        [image_bos] + list(ast.literal_eval(indices)) for indices in examples["encoding"]
    ]

    return model_inputs

In [8]:
num_workers = 48 # We have 96 processors in the TPU
# Remove every original column so that only the outputs of
# `preprocess_function` remain in `input_dataset`.
column_names = dataset.column_names
input_dataset = dataset.map(
    preprocess_function,
    remove_columns=column_names,
    batched=True,
    num_proc=num_workers,  # previously hard-coded to 48, leaving `num_workers` unused
)

In [9]:
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of size `batch_size` from `dataset`, sharded over all local devices.

    The dataset is truncated to a whole number of batches; the trailing
    incomplete batch is skipped. If `len(dataset) < batch_size` nothing is
    yielded.

    Args:
        rng: JAX PRNG key; only used to draw the permutation when `shuffle` is True.
        dataset: indexable dataset whose rows can be selected with an array of
            integer indices and which returns a dict of columns.
        batch_size: total number of examples per yielded batch (all devices combined).
        shuffle: if True, examples are visited in a random order derived from `rng`.

    Yields:
        dict mapping each column name to a NumPy array that `shard` has reshaped
        to split the batch axis across local devices.
    """
    steps_per_epoch = len(dataset) // batch_size

    # Indices are kept on the host as NumPy arrays: they are only used to index
    # the (host-side) dataset, so there is no reason to place them on a device.
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        # Build host arrays; creating `jnp` arrays here would first materialise
        # the whole batch on the default device before it is sharded.
        batch = {k: np.array(v) for k, v in batch.items()}
        yield shard(batch)

In [10]:
rng = jax.random.PRNGKey(23) # Use training_args.seed
batch_size = 64 # Per device
# Number of examples drawn per step for all devices together.
# NOTE(review): `jax.device_count()` is used here while the loader shards over
# local devices; presumably these match on a single host — confirm for multi-host runs.
super_batch_size = batch_size * jax.device_count()

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Host TPU Interpreter


In [11]:
# Create the batch generator. `shuffle` keeps its default (False), so batches
# follow dataset order and `rng` is not used to permute them.
loader = data_loader(rng, input_dataset, batch_size=super_batch_size)

In [12]:
# Pull one superbatch for inspection. This advances `loader`, so this
# batch will not be yielded again by the same generator.
superbatch = next(iter(loader))

In [13]:
# Columns left after preprocessing: the tokenizer outputs plus `labels`.
superbatch.keys()

dict_keys(['attention_mask', 'input_ids', 'labels'])

In [14]:
# Size of the leading axis introduced by `shard` (presumably one slice per local device).
len(superbatch["labels"])

8

In [15]:
# Expected layout: (device axis, per-device batch, label length), where the
# label length is 1 `image_bos` token plus the encoded image indices.
superbatch["labels"].shape

(8, 64, 257)

Any image sequence should begin with `image_bos`:

In [16]:
# Validate the first token of every label sequence in the superbatch, rather
# than spot-checking a single example.
assert (superbatch["labels"][..., 0] == image_bos).all(), "every label sequence must start with image_bos"