{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "118UKH5bWCGa"
},
"source": [
"# DALL·E mini - Inference pipeline\n",
"\n",
"*Generate images from a text prompt*\n",
"\n",
"\n",
"\n",
"This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
"\n",
"Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
"\n",
"For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dS8LbaonYm3a"
},
"source": [
"## 🛠️ Installation and set-up"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uzjAM2GBYpZX",
"outputId": "70550075-5204-4c56-dce4-4fff061a096c"
},
"outputs": [],
"source": [
"# Install required libraries\n",
"!pip install -q transformers\n",
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
"!pip install -q wandb"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ozHzTkyv8cqU"
},
"source": [
"We load required models:\n",
"* dalle·mini for text to encoded images\n",
"* VQGAN for decoding images\n",
"* CLIP for scoring predictions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "K6CxW2o42f-w"
},
"outputs": [],
"source": [
"# Model references\n",
"\n",
"# dalle-mini\n",
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
"DALLE_COMMIT_ID = None\n",
"\n",
"# VQGAN model\n",
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
"\n",
"# CLIP model\n",
"CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
"CLIP_COMMIT_ID = None"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [3]\u001b[0m, in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mjnp\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# check how many devices are available\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlocal_device_count\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:330\u001b[0m, in \u001b[0;36mlocal_device_count\u001b[0;34m(backend)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlocal_device_count\u001b[39m(backend: Optional[Union[\u001b[38;5;28mstr\u001b[39m, XlaBackend]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mint\u001b[39m:\n\u001b[1;32m 329\u001b[0m \u001b[38;5;124;03m\"\"\"Returns the number of devices addressable by this process.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 330\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mint\u001b[39m(\u001b[43mget_backend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlocal_device_count())\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:298\u001b[0m, in \u001b[0;36mget_backend\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(maxsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# don't use util.memoize because there is no X64 dependence.\u001b[39;00m\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_backend\u001b[39m(platform\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_get_backend_uncached\u001b[49m\u001b[43m(\u001b[49m\u001b[43mplatform\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:281\u001b[0m, in \u001b[0;36m_get_backend_uncached\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(platform, (\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m), \u001b[38;5;28mstr\u001b[39m)):\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m platform\n\u001b[0;32m--> 281\u001b[0m bs \u001b[38;5;241m=\u001b[39m \u001b[43mbackends\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 282\u001b[0m platform \u001b[38;5;241m=\u001b[39m (platform \u001b[38;5;129;01mor\u001b[39;00m FLAGS\u001b[38;5;241m.\u001b[39mjax_xla_backend \u001b[38;5;129;01mor\u001b[39;00m FLAGS\u001b[38;5;241m.\u001b[39mjax_platform_name\n\u001b[1;32m 283\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m platform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:231\u001b[0m, in \u001b[0;36mbackends\u001b[0;34m()\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m platform, priority \u001b[38;5;129;01min\u001b[39;00m platforms_and_priorites:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 231\u001b[0m backend \u001b[38;5;241m=\u001b[39m \u001b[43m_init_backend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mplatform\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 232\u001b[0m _backends[platform] \u001b[38;5;241m=\u001b[39m backend\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m priority \u001b[38;5;241m>\u001b[39m default_priority:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:260\u001b[0m, in \u001b[0;36m_init_backend\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown backend \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mplatform\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 259\u001b[0m logging\u001b[38;5;241m.\u001b[39mvlog(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInitializing backend \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m platform)\n\u001b[0;32m--> 260\u001b[0m backend \u001b[38;5;241m=\u001b[39m \u001b[43mfactory\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# TODO(skye): consider raising more descriptive errors directly from backend\u001b[39;00m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;66;03m# factories instead of returning None.\u001b[39;00m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m backend \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:170\u001b[0m, in \u001b[0;36mtpu_client_timer_callback\u001b[0;34m(timer_secs)\u001b[0m\n\u001b[1;32m 167\u001b[0m t\u001b[38;5;241m.\u001b[39mstart()\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 170\u001b[0m client \u001b[38;5;241m=\u001b[39m \u001b[43mxla_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake_tpu_client\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 172\u001b[0m t\u001b[38;5;241m.\u001b[39mcancel()\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jaxlib/xla_client.py:96\u001b[0m, in \u001b[0;36mmake_tpu_client\u001b[0;34m()\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmake_tpu_client\u001b[39m():\n\u001b[0;32m---> 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_xla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_tpu_client\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_inflight_computations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# check how many devices are available\n",
"jax.local_device_count()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# type used for computation - use bfloat16 on TPU's\n",
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
"\n",
"# TODO: fix issue with bfloat16\n",
"dtype = jnp.float32"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 374
},
"id": "92zYmvsQ38vL",
"outputId": "909b0a3c-14cb-4722-8eb2-f876ff50257c"
},
"outputs": [],
"source": [
"# Load models & tokenizer\n",
"from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
"import wandb\n",
"\n",
"# Load dalle-mini\n",
"model = DalleBart.from_pretrained(\n",
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
")\n",
"tokenizer = DalleBartTokenizer.from_pretrained(\n",
" DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
")\n",
"\n",
"# Load VQGAN\n",
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
"\n",
"# Load CLIP\n",
"clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
"processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o_vH2X1tDtzA"
},
"source": [
"Model parameters are replicated on each device for faster inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wtvLoM48EeVw"
},
"outputs": [],
"source": [
"from flax.jax_utils import replicate\n",
"\n",
"# convert model parameters for inference if requested\n",
"if dtype == jnp.bfloat16:\n",
" model.params = model.to_bf16(model.params)\n",
"\n",
"model_params = replicate(model.params)\n",
"vqgan_params = replicate(vqgan.params)\n",
"clip_params = replicate(clip.params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0A9AHQIgZ_qw"
},
"source": [
"Model functions are compiled and parallelized to take advantage of multiple devices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sOtoOmYsSYPz"
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"# model inference\n",
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
"def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
" return model.generate(\n",
" **tokenized_prompt,\n",
" do_sample=True,\n",
" num_beams=1,\n",
" prng_key=key,\n",
" params=params,\n",
" top_k=top_k,\n",
" top_p=top_p,\n",
" max_length=257\n",
" )\n",
"\n",
"\n",
"# decode images\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def p_decode(indices, params):\n",
" return vqgan.decode_code(indices, params=params)\n",
"\n",
"\n",
"# score images\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def p_clip(inputs, params):\n",
" logits = clip(params=params, **inputs).logits_per_image\n",
" return logits"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HmVN6IBwapBA"
},
"source": [
"Keys are passed to the model on each device to generate unique inference per device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4CTXmlUkThhX"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"# create a random key\n",
"seed = random.randint(0, 2**32 - 1)\n",
"key = jax.random.PRNGKey(seed)"
]
},
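{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, `shard_prng_key` (used in the generation loop below) splits a single key into one sub-key per local device. Here is a minimal sketch to inspect this, assuming the standard `uint32` key representation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from flax.training.common_utils import shard_prng_key\n",
"\n",
"# one sub-key per local device, each a pair of uint32 values,\n",
"# so the result has shape (local_device_count, 2)\n",
"shard_prng_key(key).shape"
]
},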
{
"cell_type": "markdown",
"metadata": {
"id": "BrnVyCo81pij"
},
"source": [
"## 🖍 Text Prompt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rsmj0Aj5OQox"
},
"source": [
"Our model may require to normalize the prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YjjhUychOVxm"
},
"outputs": [],
"source": [
"from dalle_mini.text import TextNormalizer\n",
"\n",
"text_normalizer = TextNormalizer() if model.config.normalize_text else None"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BQ7fymSPyvF_"
},
"source": [
"Let's define a text prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x_0vI9ge1oKr"
},
"outputs": [],
"source": [
"prompt = \"a waterfall under the sunset\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VKjEZGjtO49k"
},
"outputs": [],
"source": [
"processed_prompt = text_normalizer(prompt) if model.config.normalize_text else prompt\n",
"processed_prompt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We tokenize the prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenized_prompt = tokenizer(\n",
" processed_prompt,\n",
" return_tensors=\"jax\",\n",
" padding=\"max_length\",\n",
" truncation=True,\n",
" max_length=128,\n",
").data\n",
"tokenized_prompt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Y5dqFj7prMQ"
},
"source": [
"Notes:\n",
"\n",
"* `0`: BOS, special token representing the beginning of a sequence\n",
"* `2`: EOS, special token representing the end of a sequence\n",
"* `1`: special token representing the padding of a sequence when requesting a specific length"
]
},
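{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of these ids, assuming the standard 🤗 tokenizer attributes are exposed by `DalleBartTokenizer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the special token ids described above\n",
"print(f\"BOS: {tokenizer.bos_token_id}, PAD: {tokenizer.pad_token_id}, EOS: {tokenizer.eos_token_id}\")"
]
},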
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we replicate it onto each device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenized_prompt = replicate(tokenized_prompt)"
]
},
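{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each replicated array gains a leading device axis, e.g. `input_ids` goes from `(1, 128)` to `(num_devices, 1, 128)`. A minimal sketch to verify:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# every array now has a leading axis of size jax.local_device_count()\n",
"jax.tree_map(lambda x: x.shape, tokenized_prompt)"
]
},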
{
"cell_type": "markdown",
"metadata": {
"id": "phQ9bhjRkgAZ"
},
"source": [
"## 🎨 Generate images\n",
"\n",
"We generate images using dalle-mini model and decode them with the VQGAN."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d0wVkXpKqnHA"
},
"outputs": [],
"source": [
"# number of predictions\n",
"n_predictions = 32\n",
"\n",
"# We can customize top_k/top_p used for generating samples\n",
"gen_top_k = None\n",
"gen_top_p = None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SDjEx9JxR3v8"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard_prng_key\n",
"import numpy as np\n",
"from PIL import Image\n",
"from tqdm.notebook import trange\n",
"\n",
"# generate images\n",
"images = []\n",
"for i in trange(n_predictions // jax.device_count()):\n",
" # get a new key\n",
" key, subkey = jax.random.split(key)\n",
" # generate images\n",
" encoded_images = p_generate(\n",
" tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
" )\n",
" # remove BOS\n",
" encoded_images = encoded_images.sequences[..., 1:]\n",
" # decode images\n",
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
" for img in decoded_images:\n",
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tw02wG9zGmyB"
},
"source": [
"Let's calculate their score with CLIP."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FoLXpjCmGpju"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard\n",
"\n",
"# get clip scores\n",
"clip_inputs = processor(\n",
" text=[prompt] * jax.device_count(),\n",
" images=images,\n",
" return_tensors=\"np\",\n",
" padding=\"max_length\",\n",
" max_length=77,\n",
" truncation=True,\n",
").data\n",
"logits = p_clip(shard(clip_inputs), clip_params)\n",
"logits = logits.squeeze().flatten()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4AAWRm70LgED"
},
"source": [
"Let's display images ranked by CLIP score."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zsgxxubLLkIu"
},
"outputs": [],
"source": [
"print(f\"Prompt: {prompt}\\n\")\n",
"for idx in logits.argsort()[::-1]:\n",
" display(images[idx])\n",
" print(f\"Score: {logits[idx]:.2f}\\n\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "Copy of DALL·E mini - Inference pipeline.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}