{ "cells": [ { "cell_type": "markdown", "id": "d0b72877", "metadata": {}, "source": [ "# vqgan-jax-encoding-with-captions" ] }, { "cell_type": "markdown", "id": "875c82b3", "metadata": {}, "source": [ "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n", "\n", "We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model." ] }, { "cell_type": "code", "execution_count": 1, "id": "3b59489e", "metadata": {}, "outputs": [], "source": [ "import io\n", "\n", "import requests\n", "from PIL import Image\n", "import numpy as np\n", "from tqdm import tqdm\n", "\n", "import torch\n", "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from torchvision.transforms import InterpolationMode\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import jax\n", "from jax import pmap" ] }, { "cell_type": "markdown", "id": "511c3b9e", "metadata": {}, "source": [ "## VQGAN-JAX model" ] }, { "cell_type": "code", "execution_count": 2, "id": "2ca50dc7", "metadata": {}, "outputs": [], "source": [ "from vqgan_jax.modeling_flax_vqgan import VQModel" ] }, { "cell_type": "markdown", "id": "7b60da9a", "metadata": {}, "source": [ "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model." ] }, { "cell_type": "code", "execution_count": 3, "id": "29ce8b15", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "db406bdfc5d5428eaeae1631a04989dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/433 [00:00