{ "cells": [ { "cell_type": "markdown", "id": "d0b72877", "metadata": {}, "source": [ "# vqgan-jax-encoding-with-captions" ] }, { "cell_type": "markdown", "id": "875c82b3", "metadata": {}, "source": [ "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n", "\n", "We process a `tsv` file with `image_file` and `caption` columns, and add a `vqgan_indices` column containing the codebook indices obtained by encoding each image with a VQGAN-JAX model." ] }, { "cell_type": "code", "execution_count": 1, "id": "3b59489e", "metadata": {}, "outputs": [], "source": [ "import io\n", "\n", "import requests\n", "from PIL import Image\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "import torch\n", "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from torchvision.transforms import InterpolationMode\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import jax\n", "from jax import pmap" ] }, { "cell_type": "markdown", "id": "511c3b9e", "metadata": {}, "source": [ "## VQGAN-JAX model" ] }, { "cell_type": "markdown", "id": "bb408f6c", "metadata": {}, "source": [ "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities." ] }, { "cell_type": "code", "execution_count": 2, "id": "2ca50dc7", "metadata": {}, "outputs": [], "source": [ "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel" ] }, { "cell_type": "markdown", "id": "7b60da9a", "metadata": {}, "source": [ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model." ] }, { "cell_type": "code", "execution_count": 3, "id": "29ce8b15", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "db406bdfc5d5428eaeae1631a04989dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/433 [00:00