{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "118UKH5bWCGa" }, "source": [ "# DALL·E mini - Inference pipeline\n", "\n", "*Generate images from a text prompt*\n", "\n", "\n", "\n", "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n", "\n", "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n", "\n", "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)." ] }, { "cell_type": "markdown", "metadata": { "id": "dS8LbaonYm3a" }, "source": [ "## 🛠️ Installation and set-up" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uzjAM2GBYpZX", "outputId": "70550075-5204-4c56-dce4-4fff061a096c" }, "outputs": [], "source": [ "# Install required libraries\n", "!pip install -q transformers\n", "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n", "!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n", "!pip install -q wandb" ] }, { "cell_type": "markdown", "metadata": { "id": "ozHzTkyv8cqU" }, "source": [ "We load required models:\n", "* dalle·mini for text to encoded images\n", "* VQGAN for decoding images\n", "* CLIP for scoring predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K6CxW2o42f-w" }, "outputs": [], "source": [ "# Model references\n", "\n", "# dalle-mini\n", "DALLE_MODEL = 'dalle-mini/dalle-mini/model-3bqwu04f:latest' # can be wandb artifact or 🤗 Hub or local folder\n", "DALLE_COMMIT_ID = None # used only with 🤗 hub\n", "\n", "# VQGAN model\n", "VQGAN_REPO = 'dalle-mini/vqgan_imagenet_f16_16384'\n", "VQGAN_COMMIT_ID = 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n", "\n", "# CLIP model\n", "CLIP_REPO = 'openai/clip-vit-base-patch16'\n", "CLIP_COMMIT_ID = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# type used for computation - use bfloat16 on TPU's\n", "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n", "\n", "# TODO:\n", "# - we currently have an issue with model.generate() in bfloat16\n", "# - https://github.com/google/jax/pull/9089 should fix it\n", "# - remove below line and test on TPU with next release of JAX\n", "dtype = jnp.float32" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 374 }, "id": "92zYmvsQ38vL", "outputId": "909b0a3c-14cb-4722-8eb2-f876ff50257c" }, "outputs": [], "source": [ "# Load models & tokenizer\n", "from dalle_mini.model import DalleBart\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import AutoTokenizer, CLIPProcessor, FlaxCLIPModel\n", "import wandb\n", "\n", "# Load dalle-mini\n", "if ':' in DALLE_MODEL:\n", " # wandb artifact\n", " artifact = wandb.Api().artifact(DALLE_MODEL)\n", " # we only download required files (no need for opt_state which is large)\n", " model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n", " for f in model_files:\n", " artifact.get_path(f).download('model')\n", " model = DalleBart.from_pretrained('model', dtype=dtype, abstract_init=True)\n", " tokenizer = AutoTokenizer.from_pretrained('model')\n", "else:\n", " # local folder or 🤗 Hub\n", " model = 
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 374 }, "id": "92zYmvsQ38vL", "outputId": "909b0a3c-14cb-4722-8eb2-f876ff50257c" }, "outputs": [], "source": [ "# Load models & tokenizer\n", "from dalle_mini.model import DalleBart\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import AutoTokenizer, CLIPProcessor, FlaxCLIPModel\n", "import wandb\n", "\n", "# Load dalle-mini\n", "if ':' in DALLE_MODEL:\n", "    # wandb artifact\n", "    artifact = wandb.Api().artifact(DALLE_MODEL)\n", "    # we only download the required files (no need for the opt_state, which is large)\n", "    model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n", "    for f in model_files:\n", "        artifact.get_path(f).download('model')\n", "    model = DalleBart.from_pretrained('model', dtype=dtype, abstract_init=True)\n", "    tokenizer = AutoTokenizer.from_pretrained('model')\n", "else:\n", "    # local folder or 🤗 Hub\n", "    model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True)\n", "    tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n", "\n", "# Load VQGAN\n", "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n", "\n", "# Load CLIP\n", "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n", "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)" ] }, { "cell_type": "markdown", "metadata": { "id": "o_vH2X1tDtzA" }, "source": [ "Model parameters are replicated on each device for faster inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wtvLoM48EeVw" }, "outputs": [], "source": [ "from flax.jax_utils import replicate\n", "\n", "# convert model parameters for inference if requested\n", "if dtype == jnp.bfloat16:\n", "    model.params = model.to_bf16(model.params)\n", "\n", "model_params = replicate(model.params)\n", "vqgan_params = replicate(vqgan.params)\n", "clip_params = replicate(clip.params)" ] }, { "cell_type": "markdown", "metadata": { "id": "0A9AHQIgZ_qw" }, "source": [ "Model functions are compiled and parallelized to take advantage of multiple devices." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sOtoOmYsSYPz" }, "outputs": [], "source": [ "from functools import partial\n", "\n", "# model inference\n", "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n", "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n", "    return model.generate(\n", "        **tokenized_prompt,\n", "        do_sample=True,\n", "        num_beams=1,\n", "        prng_key=key,\n", "        params=params,\n", "        top_k=top_k,\n", "        top_p=top_p\n", "    )\n", "\n", "# decode images\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_decode(indices, params):\n", "    return vqgan.decode_code(indices, params=params)\n", "\n", "# score images\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_clip(inputs, params):\n", "    logits = clip(params=params, **inputs).logits_per_image\n", "    return logits" ] }, { "cell_type": "markdown", "metadata": { "id": "HmVN6IBwapBA" }, "source": [ "Random keys are passed to the model on each device so that each device generates different images." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4CTXmlUkThhX" }, "outputs": [], "source": [ "import random\n", "\n", "# create a random key\n", "seed = random.randint(0, 2**32 - 1)\n", "key = jax.random.PRNGKey(seed)" ] }, { "cell_type": "markdown", "metadata": { "id": "BrnVyCo81pij" }, "source": [ "## 🖍 Text Prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "rsmj0Aj5OQox" }, "source": [ "Our model may require the prompt to be normalized." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YjjhUychOVxm" }, "outputs": [], "source": [ "from dalle_mini.text import TextNormalizer\n", "\n", "text_normalizer = TextNormalizer() if model.config.normalize_text else None" ] }, { "cell_type": "markdown", "metadata": { "id": "BQ7fymSPyvF_" }, "source": [ "Let's define a text prompt." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_0vI9ge1oKr" }, "outputs": [], "source": [ "prompt = 'a red T-shirt'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VKjEZGjtO49k" }, "outputs": [], "source": [ "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n", "processed_prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "iFVOyYboP0L-" }, "source": [ "We repeat the prompt on each device and tokenize it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Rii_FJ7POw1y" }, "outputs": [], "source": [ "# repeat the prompt on each device\n", "repeated_prompts = [processed_prompt] * jax.device_count()\n", "\n", "# tokenize\n", "tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n", "tokenized_prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "_Y5dqFj7prMQ" }, "source": [ "Notes:\n", "\n", "* `0`: BOS, a special token representing the beginning of a sequence\n", "* `1`: PAD, a special token used to pad a sequence to a requested length\n", "* `2`: EOS, a special token representing the end of a sequence\n", "\n", "You can verify these ids on the tokenizer, as shown below." ] },
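{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick check (not part of the original notebook), Hugging Face tokenizers expose these special token ids as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sanity check: the special token ids used by the tokenizer\n", "print('BOS:', tokenizer.bos_token_id)\n", "print('PAD:', tokenizer.pad_token_id)\n", "print('EOS:', tokenizer.eos_token_id)" ] },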
{ "cell_type": "markdown", "metadata": { "id": "2wiDtG3_SH2u" }, "source": [ "Finally, we distribute the tokenized prompt across the devices." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AImyrxHtR9TG" }, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "\n", "tokenized_prompt = shard(tokenized_prompt)" ] }, { "cell_type": "markdown", "metadata": { "id": "phQ9bhjRkgAZ" }, "source": [ "## 🎨 Generate images\n", "\n", "We generate encoded images with the dalle-mini model and decode them with the VQGAN." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d0wVkXpKqnHA" }, "outputs": [], "source": [ "# number of predictions\n", "n_predictions = 32\n", "\n", "# We can customize top_k/top_p used for generating samples\n", "gen_top_k = None\n", "gen_top_p = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SDjEx9JxR3v8" }, "outputs": [], "source": [ "from flax.training.common_utils import shard_prng_key\n", "import numpy as np\n", "from PIL import Image\n", "from tqdm.notebook import trange\n", "\n", "# generate images\n", "images = []\n", "for i in trange(n_predictions // jax.device_count()):\n", "    # get a new key\n", "    key, subkey = jax.random.split(key)\n", "    # generate images\n", "    encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p)\n", "    # remove BOS\n", "    encoded_images = encoded_images.sequences[..., 1:]\n", "    # decode images\n", "    decoded_images = p_decode(encoded_images, vqgan_params)\n", "    decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n", "    for img in decoded_images:\n", "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))" ] },
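{ "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, we can preview the raw generations before ranking them. The `image_grid` helper below is a minimal sketch, not part of the original pipeline; it assumes all images share the same size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: preview all generations in a simple grid\n", "from PIL import Image\n", "\n", "def image_grid(imgs, cols=4):\n", "    rows = (len(imgs) + cols - 1) // cols\n", "    w, h = imgs[0].size\n", "    grid = Image.new('RGB', (cols * w, rows * h))\n", "    for i, img in enumerate(imgs):\n", "        grid.paste(img, ((i % cols) * w, (i // cols) * h))\n", "    return grid\n", "\n", "image_grid(images)" ] }, { "cell_type": "markdown", "metadata": { "id": "tw02wG9zGmyB" }, "source": [ "Let's calculate their score with CLIP."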
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FoLXpjCmGpju" }, "outputs": [], "source": [ "# get clip scores\n", "clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n", "logits = p_clip(shard(clip_inputs), clip_params)\n", "logits = logits.squeeze().flatten()" ] }, { "cell_type": "markdown", "metadata": { "id": "4AAWRm70LgED" }, "source": [ "Let's display images ranked by CLIP score." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zsgxxubLLkIu" }, "outputs": [], "source": [ "print(f'Prompt: {prompt}\\n')\n", "for idx in logits.argsort()[::-1]:\n", " display(images[idx])\n", " print(f'Score: {logits[idx]:.2f}\\n')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "machine_shape": "hm", "name": "Copy of DALL·E mini - Inference pipeline.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }