{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f6d33374",
   "metadata": {},
   "source": [
    "# Test notebook with CLIP scoring\n",
    "\n",
    "Sample images for a text prompt with the BART model downloaded from a W&B artifact, decode the generated tokens with VQGAN, then use CLIP to keep the candidates that best match the prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install flax transformers\n",
    "# !git clone https://github.com/patil-suraj/vqgan-jax.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41db7534-f589-4b63-9165-9c9799e1b06e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import jax\n",
    "import flax.linen as nn\n",
    "from flax.training.common_utils import shard\n",
    "from flax.jax_utils import replicate, unreplicate\n",
    "\n",
    "# the star import provides the BART building blocks used below\n",
    "# (FlaxBartModule, FlaxBartEncoder, FlaxBartDecoder, BartConfig,\n",
    "# FlaxBartForConditionalGenerationModule)\n",
    "from transformers.models.bart.modeling_flax_bart import *\n",
    "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
    "\n",
    "import io\n",
    "\n",
    "import requests\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "import torchvision.transforms.functional as TF\n",
    "from torchvision.transforms import InterpolationMode\n",
    "\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09295910",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vqgan_jax.modeling_flax_vqgan import VQModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a3462a-9004-4121-b365-3ae3aaf94dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: set these args in a config file\n",
    "OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n",
    "OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n",
    "BOS_TOKEN_ID = 16384\n",
    "BASE_MODEL = 'facebook/bart-large-cnn'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomFlaxBartModule(FlaxBartModule):\n",
    "    \"\"\"BART module whose decoder works on the image-token vocabulary.\n",
    "\n",
    "    The encoder keeps the text vocabulary of the base model, while the decoder\n",
    "    gets its own embedding table and a config sized for the encoded image tokens.\n",
    "    \"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        # we keep shared to easily load pre-trained weights\n",
    "        self.shared = nn.Embed(\n",
    "            self.config.vocab_size,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        # a separate embedding is used for the decoder\n",
    "        self.decoder_embed = nn.Embed(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
    "\n",
    "        # the decoder has a different config: start from a full copy of the\n",
    "        # model config and override only what differs. The `**` matters:\n",
    "        # BartConfig's first positional parameter is `vocab_size`, so passing the\n",
    "        # dict positionally would silently reset every other field to its default.\n",
    "        decoder_config = BartConfig(**self.config.to_dict())\n",
    "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
    "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
    "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
    "    \"\"\"Seq2seq module pairing CustomFlaxBartModule with an LM head over image tokens.\"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
    "        # the LM head and its bias produce OUTPUT_VOCAB_SIZE logits (image tokens + bos)\n",
    "        self.lm_head = nn.Dense(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            use_bias=False,\n",
    "            dtype=self.dtype,\n",
    "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "        )\n",
    "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
    "    \"\"\"Pretrained-model wrapper that instantiates CustomFlaxBartForConditionalGenerationModule.\"\"\"\n",
    "\n",
    "    module_class = CustomFlaxBartForConditionalGenerationModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "run = wandb.init()\n",
    "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:latest', type='bart_model')\n",
    "artifact_dir = artifact.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create our model\n",
    "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
    "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
    "# disable the forced BOS/EOS generation settings inherited from the text model config\n",
    "model.config.force_bos_token_to_be_generated = False\n",
    "model.config.forced_bos_token_id = None\n",
    "model.config.forced_eos_token_id = None\n",
    "\n",
    "# we verify that the shape has not been modified\n",
    "assert model.params['final_logits_bias'].shape == (1, OUTPUT_VOCAB_SIZE)\n",
    "model.params['final_logits_bias'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d5e0f14-2502-470e-9553-daee6748601f",
   "metadata": {},
   "outputs": [],
   "source": [
    "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_to_pil(x):\n",
    "    \"\"\"Convert an image array with values in [0, 1] to an RGB PIL image (values are clipped).\"\"\"\n",
    "    x = np.clip(x, 0., 1.)\n",
    "    x = (255 * x).astype(np.uint8)\n",
    "    x = Image.fromarray(x)\n",
    "    if x.mode != \"RGB\":\n",
    "        x = x.convert(\"RGB\")\n",
    "    return x\n",
    "\n",
    "\n",
    "def generate(input, rng, params):\n",
    "    \"\"\"Sample one sequence of OUTPUT_LENGTH image tokens per input row (meant to be pmapped).\n",
    "\n",
    "    `eos_token_id` and `pad_token_id` are set to 50000, an id outside the\n",
    "    OUTPUT_VOCAB_SIZE logits of the LM head, so it can never be sampled and\n",
    "    generation does not stop early.\n",
    "    \"\"\"\n",
    "    return model.generate(\n",
    "        **input,\n",
    "        max_length=OUTPUT_LENGTH,\n",
    "        num_beams=1,\n",
    "        do_sample=True,\n",
    "        prng_key=rng,\n",
    "        eos_token_id=50000,\n",
    "        pad_token_id=50000,\n",
    "        params=params,\n",
    "    )\n",
    "\n",
    "\n",
    "def get_images(indices, params):\n",
    "    \"\"\"Decode VQGAN code indices into images (meant to be pmapped).\"\"\"\n",
    "    return vqgan.decode_code(indices, params=params)\n",
    "\n",
    "\n",
    "def plot_images(images, columns=4, rows=2):\n",
    "    \"\"\"Show up to `columns * rows` images in a grid; extra images are ignored.\"\"\"\n",
    "    fig = plt.figure(figsize=(40, 20))\n",
    "    plt.subplots_adjust(hspace=0, wspace=0)\n",
    "\n",
    "    for i, image in enumerate(images[:columns * rows], start=1):\n",
    "        fig.add_subplot(rows, columns, i)\n",
    "        plt.imshow(image)\n",
    "        plt.gca().axes.get_yaxis().set_visible(False)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def stack_reconstructions(images):\n",
    "    \"\"\"Paste PIL images side by side into one RGB image sized from the first image.\"\"\"\n",
    "    w, h = images[0].size\n",
    "    img = Image.new(\"RGB\", (len(images) * w, h))\n",
    "    for i, img_ in enumerate(images):\n",
    "        img.paste(img_, (i * w, 0))\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_generate = jax.pmap(generate, \"batch\")\n",
    "p_get_images = jax.pmap(get_images, \"batch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "bart_params = replicate(model.params)\n",
    "vqgan_params = replicate(vqgan.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b268d8-6992-422a-8373-95651474ae70",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = [\n",
    "    \"man in blue jacket walking on pathway in between trees during daytime\",\n",
    "    'white snow covered mountain under blue sky during daytime',\n",
    "    'white snow covered mountain under blue sky during night',\n",
    "    \"orange tabby cat on persons hand\",\n",
    "    \"aerial view of beach during daytime\",\n",
    "    \"chess pieces on chess board\",\n",
    "    \"laptop on brown wooden table\",\n",
    "    \"white bus on road near high rise buildings\",\n",
    "]\n",
    "\n",
    "# `shard` splits the batch across the local devices: one prompt copy per local device\n",
    "prompt = [prompts[1]] * jax.local_device_count()\n",
    "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
    "inputs = shard(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "for i in range(8):\n",
    "    # integer bounds: random.randint rejects float arguments on Python >= 3.12\n",
    "    key = random.randint(0, 10**7)\n",
    "    rng = jax.random.PRNGKey(key)\n",
    "    rngs = jax.random.split(rng, jax.local_device_count())\n",
    "    indices = p_generate(inputs, rngs, bart_params).sequences\n",
    "    # drop the leading bos token, keeping the encoded image tokens\n",
    "    indices = indices[:, :, 1:]\n",
    "\n",
    "    images = p_get_images(indices, vqgan_params)\n",
    "    # remove the per-device batch axis (one prompt copy per device)\n",
    "    images = np.squeeze(np.asarray(images), 1)\n",
    "    pil_images = [custom_to_pil(image) for image in images]\n",
    "\n",
    "    plt.figure(figsize=(40, 20))\n",
    "    plt.imshow(stack_reconstructions(pil_images))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6e1060f",
   "metadata": {},
   "source": [
    "## CLIP Scoring\n",
    "\n",
    "Sample many candidates per prompt, score each one against the prompt text with CLIP, and keep the `k` best."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c68724bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPProcessor, FlaxCLIPModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17158e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1b37b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hallucinate(prompt, num_images=64):\n",
    "    \"\"\"Sample candidate images for `prompt` and return them as PIL images.\n",
    "\n",
    "    Each pmapped step produces one image per local device, and\n",
    "    `num_images // jax.local_device_count()` steps are run, so the number of\n",
    "    returned images is `num_images` rounded down to a multiple of the local\n",
    "    device count.\n",
    "    \"\"\"\n",
    "    n_devices = jax.local_device_count()\n",
    "    # `shard` splits the batch across the local devices: one prompt copy per device\n",
    "    prompt_batch = [prompt] * n_devices\n",
    "    inputs = tokenizer(prompt_batch, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
    "    inputs = shard(inputs)\n",
    "\n",
    "    all_images = []\n",
    "    for _ in range(num_images // n_devices):\n",
    "        # integer bounds: random.randint rejects float arguments on Python >= 3.12\n",
    "        rng = jax.random.PRNGKey(random.randint(0, 10**7))\n",
    "        rngs = jax.random.split(rng, n_devices)\n",
    "        indices = p_generate(inputs, rngs, bart_params).sequences\n",
    "        indices = indices[:, :, 1:]  # drop the leading bos token\n",
    "\n",
    "        images = p_get_images(indices, vqgan_params)\n",
    "        images = np.squeeze(np.asarray(images), 1)\n",
    "        all_images.extend(custom_to_pil(image) for image in images)\n",
    "    return all_images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831c715f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clip_top_k(prompt, images, k=8):\n",
    "    \"\"\"Return the `k` images that CLIP scores highest for `prompt`, best first.\n",
    "\n",
    "    Returns an empty list for `k == 0` and every image (best first) when `k`\n",
    "    exceeds the number of images.\n",
    "    \"\"\"\n",
    "    inputs = processor(text=prompt, images=images, return_tensors=\"np\", padding=True)\n",
    "    outputs = clip(**inputs)\n",
    "    # one row per text: with a single prompt, row 0 scores every image\n",
    "    scores = np.asarray(outputs.logits_per_text[0])\n",
    "    # reverse before slicing so that k == 0 does not select everything\n",
    "    top_indices = scores.argsort()[::-1][:k]\n",
    "    return [images[i] for i in top_indices]\n",
    "\n",
    "\n",
    "def show_clip_top_k(prompt, num_images=64, k=8):\n",
    "    \"\"\"Sample `num_images` candidates for `prompt` and return the CLIP top-`k` side by side.\"\"\"\n",
    "    images = hallucinate(prompt, num_images=num_images)\n",
    "    selected = clip_top_k(prompt, images, k=k)\n",
    "    return stack_reconstructions(selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00605e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_clip_top_k(\"white snow covered mountain under blue sky during daytime\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc745da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_clip_top_k(\"aerial view of beach at night\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9cc0b1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_clip_top_k(\"an armchair in the shape of an avocado\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574e9433",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_clip_top_k(\"young woman riding her bike into a forest\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4762c91e",
   "metadata": {},
   "source": [
    "The word `forest` seems to dominate the generations for this prompt. There was an interesting cubist interpretation in the fourth image (sampling is not seeded, so the selected images differ between runs)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}