{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "118UKH5bWCGa" }, "source": [ "# DALL·E mini - Inference pipeline\n", "\n", "*Generate images from a text prompt*\n", "\n", "This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n", "\n", "Just want to play? Use [the app](https://www.craiyon.com/) directly.\n", "\n", "To learn more about the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)." ] }, { "cell_type": "markdown", "metadata": { "id": "dS8LbaonYm3a" }, "source": [ "## 🛠️ Installation and set-up" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uzjAM2GBYpZX" }, "outputs": [], "source": [ "# Required only for colab environments + GPU\n", "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", "\n", "# Install required libraries\n", "!pip install -q dalle-mini\n", "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git" ] }, { "cell_type": "markdown", "metadata": { "id": "ozHzTkyv8cqU" }, "source": [ "We load the required models:\n", "* DALL·E mini to generate encoded images from text\n", "* VQGAN to decode the encoded images\n", "* CLIP to score the predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K6CxW2o42f-w" }, "outputs": [], "source": [ "# Model references\n", "\n", "# dalle-mega\n", "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be a wandb artifact, 🤗 Hub repo, local folder or Google bucket\n", "DALLE_COMMIT_ID = None\n", "\n", "# if the notebook crashes too often, you can use dalle-mini instead by uncommenting the line below\n", "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n", "\n", "# VQGAN model\n", "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n", "VQGAN_COMMIT_ID = 
\"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yv-aR3t4Oe5v" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# check how many devices are available\n", "jax.local_device_count()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "92zYmvsQ38vL" }, "outputs": [], "source": [ "# Load models & tokenizer\n", "from dalle_mini import DalleBart, DalleBartProcessor\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import CLIPProcessor, FlaxCLIPModel\n", "\n", "# Load dalle-mini\n", "model, params = DalleBart.from_pretrained(\n", " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", ")\n", "\n", "# Load VQGAN\n", "vqgan, vqgan_params = VQModel.from_pretrained(\n", " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "o_vH2X1tDtzA" }, "source": [ "Model parameters are replicated on each device for faster inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wtvLoM48EeVw" }, "outputs": [], "source": [ "from flax.jax_utils import replicate\n", "\n", "params = replicate(params)\n", "vqgan_params = replicate(vqgan_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "0A9AHQIgZ_qw" }, "source": [ "Model functions are compiled and parallelized to take advantage of multiple devices." 
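] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a toy illustration of this pattern (a sketch with made-up arrays, not part of the DALL·E mini pipeline; the `scale` function and the `toy_*` names are hypothetical), `replicate` gives every parameter a leading device axis, `shard` splits a batch along a new leading device axis, and `jax.pmap` runs one copy of the function per device. In `p_generate`, the arguments listed in `static_broadcasted_argnums` (the sampling settings) are instead passed unchanged to every device and treated as compile-time constants, so changing them triggers a recompilation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from flax.jax_utils import replicate\n", "from flax.training.common_utils import shard\n", "\n", "n_devices = jax.local_device_count()\n", "\n", "# toy parameters: replicate() stacks one copy per local device\n", "toy_params = replicate({\"w\": jnp.arange(3.0)})\n", "print(toy_params[\"w\"].shape)  # (n_devices, 3)\n", "\n", "# toy batch with 2 rows per device: shard() reshapes (n_devices * 2, 3) to (n_devices, 2, 3)\n", "toy_batch = shard(jnp.ones((n_devices * 2, 3)))\n", "\n", "\n", "# hypothetical function: pmap maps it over the leading device axis of both inputs\n", "@jax.pmap\n", "def scale(x, params):\n", "    return x * params[\"w\"]\n", "\n", "\n", "print(scale(toy_batch, toy_params).shape)  # (n_devices, 2, 3)"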
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sOtoOmYsSYPz" }, "outputs": [], "source": [ "from functools import partial\n", "\n", "\n", "# model inference\n", "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n", "def p_generate(\n", "    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n", "):\n", "    return model.generate(\n", "        **tokenized_prompt,\n", "        prng_key=key,\n", "        params=params,\n", "        top_k=top_k,\n", "        top_p=top_p,\n", "        temperature=temperature,\n", "        condition_scale=condition_scale,\n", "    )\n", "\n", "\n", "# decode image\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_decode(indices, params):\n", "    return vqgan.decode_code(indices, params=params)" ] }, { "cell_type": "markdown", "metadata": { "id": "HmVN6IBwapBA" }, "source": [ "We create a random key. During generation, it is split into one key per device so that each device generates different images." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4CTXmlUkThhX" }, "outputs": [], "source": [ "import random\n", "\n", "# create a random key\n", "seed = random.randint(0, 2**32 - 1)\n", "key = jax.random.PRNGKey(seed)" ] }, { "cell_type": "markdown", "metadata": { "id": "BrnVyCo81pij" }, "source": [ "## 🖍 Text Prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "rsmj0Aj5OQox" }, "source": [ "Our model requires the prompts to be processed (tokenized) first." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YjjhUychOVxm" }, "outputs": [], "source": [ "from dalle_mini import DalleBartProcessor\n", "\n", "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)" ] }, { "cell_type": "markdown", "metadata": { "id": "BQ7fymSPyvF_" }, "source": [ "Let's define some text prompts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_0vI9ge1oKr" }, "outputs": [], "source": [ "prompts = [\n", "    \"sunset over a lake in the mountains\",\n", "    \"the Eiffel tower landing on the moon\",\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "XlZUG3SCLnGE" }, "source": [ "Note: all prompts in the list are generated together in each batch, so repeating the same prompt several times yields more images of it per batch, for faster inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VKjEZGjtO49k" }, "outputs": [], "source": [ "tokenized_prompts = processor(prompts)" ] }, { "cell_type": "markdown", "metadata": { "id": "-CEJBnuJOe5z" }, "source": [ "Finally, we replicate the tokenized prompts onto each device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lQePgju5Oe5z" }, "outputs": [], "source": [ "tokenized_prompt = replicate(tokenized_prompts)" ] }, { "cell_type": "markdown", "metadata": { "id": "phQ9bhjRkgAZ" }, "source": [ "## 🎨 Generate images\n", "\n", "We generate images with the DALL·E mini model and decode them with VQGAN." 
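] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The generation loop below draws a fresh subkey at every step and splits it into one key per device with `shard_prng_key`, so each device and each step samples different images. The loop runs `max(n_predictions // jax.device_count(), 1)` steps and every step produces one image per prompt on each device, so the number of images per prompt is a multiple of the device count and equals `n_predictions` only when `n_predictions` is itself a multiple of it. A toy sketch of this bookkeeping (the `toy_*` names and the value 8 are hypothetical, chosen for illustration):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "from flax.training.common_utils import shard_prng_key\n", "\n", "toy_n_predictions = 8\n", "n_steps = max(toy_n_predictions // jax.device_count(), 1)\n", "print(f\"{n_steps} steps x {jax.device_count()} devices = {n_steps * jax.device_count()} images per prompt\")\n", "\n", "toy_key = jax.random.PRNGKey(0)\n", "for step in range(n_steps):\n", "    # a new subkey at each step keeps successive batches different\n", "    toy_key, toy_subkey = jax.random.split(toy_key)\n", "    # one key per local device, stacked along a leading device axis\n", "    device_keys = shard_prng_key(toy_subkey)\n", "    print(step, device_keys.shape)  # (n_devices, 2) with the default raw uint32 keys"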
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d0wVkXpKqnHA" }, "outputs": [], "source": [ "# number of predictions per prompt\n", "n_predictions = 8\n", "\n", "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n", "gen_top_k = None\n", "gen_top_p = None\n", "temperature = None\n", "cond_scale = 10.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SDjEx9JxR3v8" }, "outputs": [], "source": [ "from flax.training.common_utils import shard_prng_key\n", "import numpy as np\n", "from PIL import Image\n", "from tqdm.notebook import trange\n", "\n", "print(f\"Prompts: {prompts}\\n\")\n", "# generate images\n", "images = []\n", "for i in trange(max(n_predictions // jax.device_count(), 1)):\n", "    # get a new key\n", "    key, subkey = jax.random.split(key)\n", "    # generate images\n", "    encoded_images = p_generate(\n", "        tokenized_prompt,\n", "        shard_prng_key(subkey),\n", "        params,\n", "        gen_top_k,\n", "        gen_top_p,\n", "        temperature,\n", "        cond_scale,\n", "    )\n", "    # remove the BOS (beginning-of-sequence) token\n", "    encoded_images = encoded_images.sequences[..., 1:]\n", "    # decode images\n", "    decoded_images = p_decode(encoded_images, vqgan_params)\n", "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n", "    for decoded_img in decoded_images:\n", "        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n", "        images.append(img)\n", "        display(img)\n", "        print()" ] }, { "cell_type": "markdown", "metadata": { "id": "tw02wG9zGmyB" }, "source": [ "## 🏅 Optional: Rank images by CLIP score\n", "\n", "We can rank images by their CLIP score, which measures how well each image matches its prompt.\n", "\n", "**Note: your session may crash if you don't have a subscription to Colab Pro.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RGjlIW_f6GA0" }, "outputs": [], "source": [ "# CLIP model\n", "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n", "CLIP_COMMIT_ID = None\n", "\n", "# Load CLIP\n", "clip, clip_params = 
FlaxCLIPModel.from_pretrained(\n", "    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", ")\n", "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n", "clip_params = replicate(clip_params)\n", "\n", "\n", "# score images\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_clip(inputs, params):\n", "    logits = clip(params=params, **inputs).logits_per_image\n", "    return logits" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FoLXpjCmGpju" }, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "\n", "# get clip scores\n", "clip_inputs = clip_processor(\n", "    text=prompts * jax.device_count(),\n", "    images=images,\n", "    return_tensors=\"np\",\n", "    padding=\"max_length\",\n", "    max_length=77,\n", "    truncation=True,\n", ").data\n", "logits = p_clip(shard(clip_inputs), clip_params)\n", "\n", "# organize scores per prompt: row i holds the scores of all images of prompt i\n", "# (reshape instead of squeeze keeps a 2D array with a single prompt or several devices)\n", "p = len(prompts)\n", "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).reshape(p, -1)" ] }, { "cell_type": "markdown", "metadata": { "id": "4AAWRm70LgED" }, "source": [ "Let's now display images ranked by CLIP score." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zsgxxubLLkIu" }, "outputs": [], "source": [ "for i, prompt in enumerate(prompts):\n", "    print(f\"Prompt: {prompt}\\n\")\n", "    for idx in logits[i].argsort()[::-1]:\n", "        display(images[idx * p + i])\n", "        print(f\"Score: {float(logits[i][idx]):.2f}\\n\")\n", "    print()" ] }, { "cell_type": "markdown", "metadata": { "id": "oZT9i3jCjir0" }, "source": [ "## 🪄 Optional: Save your Generated Images as W&B Tables\n", "\n", "W&B Tables is an interactive 2D grid with support for rich media logging. Use it to save the generated images to your W&B dashboard and share them with the world." 
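] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Within each loop step, the generation cell appends the decoded images device by device and, for each device, in prompt order. With `p` prompts, the images of prompt `i` therefore sit at positions `i, i + p, i + 2 * p, ...` of `images`, which is why the CLIP and W&B cells index them with `images[idx * p + i]` and `images[i :: len(prompts)]`. A toy sketch of this layout, using hypothetical string labels in place of PIL images:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "toy_prompts = [\"sunset\", \"eiffel\"]\n", "n_prompts = len(toy_prompts)\n", "\n", "# hypothetical labels in the order the generation loop appends images: 3 batches of 2 prompts\n", "toy_images = [f\"{prompt}_{k}\" for k in range(3) for prompt in toy_prompts]\n", "print(toy_images)  # ['sunset_0', 'eiffel_0', 'sunset_1', 'eiffel_1', 'sunset_2', 'eiffel_2']\n", "\n", "for j, toy_prompt in enumerate(toy_prompts):\n", "    # strided slice: every n_prompts-th image, starting at position j\n", "    print(toy_prompt, toy_images[j::n_prompts])\n", "    # the k-th image of prompt j is at position k * n_prompts + j\n", "    assert toy_images[1 * n_prompts + j] == f\"{toy_prompt}_1\""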
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-pSiv6Vwjkn0" }, "outputs": [], "source": [ "import wandb\n", "\n", "# Initialize a W&B run.\n", "project = \"dalle-mini-tables-colab\"\n", "run = wandb.init(project=project)\n", "\n", "# CLIP scores only exist if the optional CLIP cells were run.\n", "logits = globals().get(\"logits\")\n", "\n", "# Initialize an empty W&B Table with one column per image of a prompt.\n", "n_images = len(images) // len(prompts)\n", "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_images)]\n", "gen_table = wandb.Table(columns=columns)\n", "\n", "# Add data to the table.\n", "for i, prompt in enumerate(prompts):\n", "    # If CLIP scores exist, sort the images by decreasing score\n", "    if logits is not None:\n", "        idxs = logits[i].argsort()[::-1]\n", "        tmp_imgs = images[i :: len(prompts)]\n", "        tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n", "    else:\n", "        tmp_imgs = images[i :: len(prompts)]\n", "\n", "    # Add the data to the table.\n", "    gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n", "\n", "# Log the Table to the W&B dashboard.\n", "wandb.log({\"Generated Images\": gen_table})\n", "\n", "# Close the W&B run.\n", "run.finish()" ] }, { "cell_type": "markdown", "metadata": { "id": "Ck2ZnHwVjnRd" }, "source": [ "Click on the link above to check out your generated images." ] } ], "metadata": { "accelerator": "GPU", "colab": { "machine_shape": "hm", "name": "DALL·E mini - Inference pipeline.ipynb", "provenance": [], "gpuType": "A100", "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 0 }