{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "118UKH5bWCGa" }, "source": [ "# DALL·E mini - Inference pipeline\n", "\n", "*Generate images from a text prompt*\n", "\n", "\n", "\n", "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n", "\n", "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n", "\n", "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)." ] }, { "cell_type": "markdown", "metadata": { "id": "dS8LbaonYm3a" }, "source": [ "## 🛠️ Installation and set-up" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uzjAM2GBYpZX" }, "outputs": [], "source": [ "# Install required libraries\n", "!pip install -q transformers\n", "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n", "!pip install -q git+https://github.com/borisdayma/dalle-mini.git" ] }, { "cell_type": "markdown", "metadata": { "id": "ozHzTkyv8cqU" }, "source": [ "We load required models:\n", "* dalle·mini for text to encoded images\n", "* VQGAN for decoding images\n", "* CLIP for scoring predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K6CxW2o42f-w" }, "outputs": [], "source": [ "# Model references\n", "\n", "# dalle-mini\n", "DALLE_MODEL = \"dalle-mini/dalle-mini/model-1reghx5l:latest\" # can be wandb artifact or 🤗 Hub or local folder\n", "DALLE_COMMIT_ID = None\n", "\n", "# VQGAN model\n", "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n", "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n", "\n", "# CLIP model\n", "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n", "CLIP_COMMIT_ID = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yv-aR3t4Oe5v" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# check how many devices are available\n", "jax.local_device_count()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HWnQrQuXOe5w" }, "outputs": [], "source": [ "# type used for computation - use bfloat16 on TPU's\n", "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n", "\n", "# TODO: fix issue with bfloat16\n", "dtype = jnp.float32" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "92zYmvsQ38vL" }, "outputs": [], "source": [ "# Load models & tokenizer\n", "from dalle_mini.model import DalleBart, DalleBartTokenizer\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import CLIPProcessor, FlaxCLIPModel\n", "import wandb\n", "\n", "# Load dalle-mini\n", "model = DalleBart.from_pretrained(\n", " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n", ")\n", "tokenizer = DalleBartTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n", "\n", "# Load VQGAN\n", "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n", "\n", "# Load CLIP\n", "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n", "processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)" ] }, { "cell_type": "markdown", "metadata": { "id": "o_vH2X1tDtzA" }, "source": [ "Model parameters are replicated on each device for faster inference." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wtvLoM48EeVw" }, "outputs": [], "source": [ "from flax.jax_utils import replicate\n", "\n", "# convert model parameters for inference if requested\n", "if dtype == jnp.bfloat16:\n", " model.params = model.to_bf16(model.params)\n", "\n", "model_params = replicate(model.params)\n", "vqgan_params = replicate(vqgan.params)\n", "clip_params = replicate(clip.params)" ] }, { "cell_type": "markdown", "metadata": { "id": "0A9AHQIgZ_qw" }, "source": [ "Model functions are compiled and parallelized to take advantage of multiple devices." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sOtoOmYsSYPz" }, "outputs": [], "source": [ "from functools import partial\n", "\n", "# model inference\n", "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n", "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n", " return model.generate(\n", " **tokenized_prompt,\n", " do_sample=True,\n", " num_beams=1,\n", " prng_key=key,\n", " params=params,\n", " top_k=top_k,\n", " top_p=top_p,\n", " max_length=257\n", " )\n", "\n", "\n", "# decode images\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_decode(indices, params):\n", " return vqgan.decode_code(indices, params=params)\n", "\n", "\n", "# score images\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_clip(inputs, params):\n", " logits = clip(params=params, **inputs).logits_per_image\n", " return logits" ] }, { "cell_type": "markdown", "metadata": { "id": "HmVN6IBwapBA" }, "source": [ "Keys are passed to the model on each device to generate unique inference per device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4CTXmlUkThhX" }, "outputs": [], "source": [ "import random\n", "\n", "# create a random key\n", "seed = random.randint(0, 2**32 - 1)\n", "key = jax.random.PRNGKey(seed)" ] }, { "cell_type": "markdown", "metadata": { "id": "BrnVyCo81pij" }, "source": [ "## 🖍 Text Prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "rsmj0Aj5OQox" }, "source": [ "Our model may require to normalize the prompt." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YjjhUychOVxm" }, "outputs": [], "source": [ "from dalle_mini.text import TextNormalizer\n", "\n", "text_normalizer = TextNormalizer() if model.config.normalize_text else None" ] }, { "cell_type": "markdown", "metadata": { "id": "BQ7fymSPyvF_" }, "source": [ "Let's define a text prompt." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_0vI9ge1oKr" }, "outputs": [], "source": [ "prompt = \"a blue table\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VKjEZGjtO49k" }, "outputs": [], "source": [ "processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n", "processed_prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "QUzYACWxOe5z" }, "source": [ "We tokenize the prompt." 
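, { "cell_type": "markdown", "metadata": {}, "source": [ "As an optional check, we can map the token ids back to token strings to see these special tokens. This is only a sketch: the exact ids and token strings depend on the tokenizer version." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sketch: map the first few token ids back to token strings.\n", "# Exact ids and token strings depend on the tokenizer version.\n", "ids = tokenized_prompt[\"input_ids\"][0]\n", "print(ids[:8])\n", "print(tokenizer.convert_ids_to_tokens(ids[:8].tolist()))" ] }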
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "n8e7MvGwOe5z" }, "outputs": [], "source": [ "tokenized_prompt = tokenizer(\n", " processed_prompt,\n", " return_tensors=\"jax\",\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=128,\n", ").data\n", "tokenized_prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "_Y5dqFj7prMQ" }, "source": [ "Notes:\n", "\n", "* `0`: BOS, special token representing the beginning of a sequence\n", "* `2`: EOS, special token representing the end of a sequence\n", "* `1`: special token representing the padding of a sequence when requesting a specific length" ] }, { "cell_type": "markdown", "metadata": { "id": "-CEJBnuJOe5z" }, "source": [ "Finally we replicate it onto each device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lQePgju5Oe5z" }, "outputs": [], "source": [ "tokenized_prompt = replicate(tokenized_prompt)" ] }, { "cell_type": "markdown", "metadata": { "id": "phQ9bhjRkgAZ" }, "source": [ "## 🎨 Generate images\n", "\n", "We generate images using dalle-mini model and decode them with the VQGAN." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d0wVkXpKqnHA" }, "outputs": [], "source": [ "# number of predictions\n", "n_predictions = 32\n", "\n", "# We can customize top_k/top_p used for generating samples\n", "gen_top_k = None\n", "gen_top_p = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SDjEx9JxR3v8" }, "outputs": [], "source": [ "from flax.training.common_utils import shard_prng_key\n", "import numpy as np\n", "from PIL import Image\n", "from tqdm.notebook import trange\n", "\n", "# generate images\n", "images = []\n", "for i in trange(n_predictions // jax.device_count()):\n", " # get a new key\n", " key, subkey = jax.random.split(key)\n", " # generate images\n", " encoded_images = p_generate(\n", " tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n", " )\n", " # remove BOS\n", " encoded_images = encoded_images.sequences[..., 1:]\n", " # decode images\n", " decoded_images = p_decode(encoded_images, vqgan_params)\n", " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n", " for img in decoded_images:\n", " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))" ] }, { "cell_type": "markdown", "metadata": { "id": "tw02wG9zGmyB" }, "source": [ "Let's calculate their score with CLIP." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FoLXpjCmGpju" }, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "\n", "# get clip scores\n", "clip_inputs = processor(\n", " text=[prompt] * jax.device_count(),\n", " images=images,\n", " return_tensors=\"np\",\n", " padding=\"max_length\",\n", " max_length=77,\n", " truncation=True,\n", ").data\n", "logits = p_clip(shard(clip_inputs), clip_params)\n", "logits = logits.squeeze().flatten()" ] }, { "cell_type": "markdown", "metadata": { "id": "4AAWRm70LgED" }, "source": [ "Let's display images ranked by CLIP score." 
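, { "cell_type": "markdown", "metadata": {}, "source": [ "If we only need the single best candidate, a minimal sketch is to keep the image with the highest CLIP logit:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sketch: keep only the image with the highest CLIP score.\n", "best_image = images[int(logits.argmax())]\n", "best_image" ] }, { "cell_type": "markdown", "metadata": { "id": "4AAWRm70LgED" }, "source": [ "Let's display the images ranked by CLIP score."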
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zsgxxubLLkIu" }, "outputs": [], "source": [ "print(f\"Prompt: {prompt}\\n\")\n", "for idx in logits.argsort()[::-1]:\n", " display(images[idx])\n", " print(f\"Score: {logits[idx]:.2f}\\n\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "include_colab_link": true, "machine_shape": "hm", "name": "DALL·E mini - Inference pipeline.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 0 }