# vqgan-jax-encoding-yfcc100m

Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.

This dataset was prepared by @borisdayma in Json lines format.

In [1]:
import io

import requests
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader

import jax
from jax import pmap

## VQGAN-JAX model

`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities.

In [2]:
from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel

We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model.

In [4]:
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

## Dataset

I splitted the files to do the process iteratively. Pandas struggles with memory and `datasets` has problems when filtering files, as described [in this issue](https://github.com/huggingface/datasets/issues/2644).

In [5]:
import pandas as pd
from pathlib import Path

In [6]:
yfcc100m = Path('/sddata/dalle-mini/YFCC100M_OpenAI_subset')
# Images are 'sharded' from the following directory
yfcc100m_images = yfcc100m/'data'/'images'
yfcc100m_metadata_splits = yfcc100m/'metadata_splitted'
yfcc100m_output = yfcc100m/'metadata_encoded'

In [7]:
all_splits = [x for x in yfcc100m_metadata_splits.iterdir() if x.is_file()]
all_splits

[PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_17'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_10'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_22'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_28'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_09'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_03'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_07'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_26'),
 PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitte

### Cleanup

In [8]:
def image_exists(root: str, name: str, ext: str):
    image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(ext)
    return image_path.exists()

In [9]:
class YFC100Dataset(Dataset):
    def __init__(self, image_list: pd.DataFrame, images_root: str, image_size: int, max_items=None):
        """
        :param image_list: DataFrame with clean entries - all images must exist.
        :param images_root: Root directory containing the images
        :param image_size: Image size. Source images will be resized and center-cropped.
        :max_items: Limit dataset size for debugging
        """
        self.image_list = image_list
        self.images_root = Path(images_root)
        if max_items is not None: self.image_list = self.image_list[:max_items]
        self.image_size = image_size
        
    def __len__(self):
        return len(self.image_list)
    
    def _get_raw_image(self, i):
        image_name = self.image_list.iloc[0].key
        image_path = (self.images_root/image_name[0:3]/image_name[3:6]/image_name).with_suffix('.jpg')
        return default_loader(image_path)
    
    def resize_image(self, image):
        s = min(image.size)
        r = self.image_size / s
        s = (round(r * image.size[1]), round(r * image.size[0]))
        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
        image = TF.center_crop(image, output_size = 2 * [self.image_size])
        # FIXME: np.array is necessary in my installation, but it should be automatic
        image = torch.unsqueeze(T.ToTensor()(np.array(image)), 0)
        image = image.permute(0, 2, 3, 1).numpy()
        return image
        
    def __getitem__(self, i):
        image = self._get_raw_image(i)
        image = self.resize_image(image)
        # Just return the image, not the caption
        return image

## Encoding

In [10]:
def encode(model, batch):
    print("jitting encode function")
    _, indices = model.encode(batch)

#     # FIXME: The model does not run in my computer (no cudNN currently installed) - faking it
#     indices = np.random.randint(0, 16384, (batch.shape[0], 256))
    return indices

In [None]:
#FIXME
# import random
# model = {}

In [11]:
from flax.training.common_utils import shard

def superbatch_generator(dataloader):
    iter_loader = iter(dataloader)
    for batch in iter_loader:
        batch = batch.squeeze(1)
        # Skip incomplete last batch
        if batch.shape[0] == dataloader.batch_size:
            yield shard(batch)

In [13]:
import os
import jax

def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):
    if os.path.isfile(output_jsonl):
        print(f"Destination file {output_jsonl} already exists, please move away.")
        return
    
    num_tpus = jax.device_count()
    dataloader = DataLoader(dataset, batch_size=num_tpus*batch_size, num_workers=num_workers)
    superbatches = superbatch_generator(dataloader)
    
    p_encoder = pmap(lambda batch: encode(model, batch))

    # We save each superbatch to avoid reallocation of buffers as we process them.
    # We keep the file open to prevent excessive file seeks.
    with open(output_jsonl, "w") as file:
        iterations = len(dataset) // (batch_size * num_tpus)
        for n in tqdm(range(iterations)):
            superbatch = next(superbatches)
            encoded = p_encoder(superbatch.numpy())
            encoded = encoded.reshape(-1, encoded.shape[-1])

            # Extract fields from the dataset internal `image_list` property, and save to disk
            # We need to read from the df because the Dataset only returns images
            start_index = n * batch_size * num_tpus
            end_index = (n+1) * batch_size * num_tpus
            keys = dataset.image_list["key"][start_index:end_index].values
            captions = dataset.image_list["caption"][start_index:end_index].values
#             encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))
            batch_df = pd.DataFrame.from_dict({"key": keys, "caption": captions, "encoding": encoded})
            batch_df.to_json(file, orient='records', lines=True)

In [14]:
for split in all_splits:
    print(f"Processing {split}")
    df = pd.read_json(split, orient="records", lines=True)
    df['image_exists'] = df.apply(lambda row: image_exists(yfcc100m_images, row['key'], '.' + row['ext']), axis=1)
    print(f"{len(df[df.image_exists])} selected from {len(df)} total entries")
    
    df = df[df.image_exists]
    captions = df.apply(lambda row: ' '.join([row["title_clean"], row["description_clean"]]), axis=1)
    df["caption"] = captions.values
    
    dataset = YFC100Dataset(
        image_list = df,
        images_root = yfcc100m_images,
        image_size = 256,
#         max_items = 2000,
    )
    
    encode_captioned_dataset(dataset, yfcc100m_output/split.name, batch_size=64, num_workers=16)

Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04
54024 selected from 500000 total entries


INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
  0%|                                                                                        | 0/31 [00:00<?, ?it/s]

jitting encode function


100%|███████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 10.61it/s]


Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25
99530 selected from 500000 total entries


  3%|██▌                                                                             | 1/31 [00:01<00:53,  1.79s/it]

jitting encode function


100%|███████████████████████████████████████████████████████████████████████████████| 31/31 [00:03<00:00,  9.92it/s]


----