{ "cells": [
{ "cell_type": "code", "execution_count": null, "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8", "metadata": {}, "outputs": [], "source": [ "import csv\n", "import tempfile\n", "from functools import partial\n", "import random\n", "import numpy as np\n", "from PIL import Image\n", "import jax\n", "import jax.numpy as jnp\n", "from flax.training.common_utils import shard, shard_prng_key\n", "from flax.jax_utils import replicate\n", "import wandb\n", "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n", "from dalle_mini.text import TextNormalizer" ] },
{ "cell_type": "code", "execution_count": null, "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270", "metadata": {}, "outputs": [], "source": [ "# wandb ids of the training runs whose model artifacts are looked up below\n", "wandb_runs = ['rjf3rycy']\n", "# VQGAN checkpoint; the commit id is passed as `revision` to from_pretrained\n", "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n", "# when True, captions go through TextNormalizer before tokenization\n", "normalize_text = True" ] },
{ "cell_type": "code", "execution_count": null, "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3", "metadata": {}, "outputs": [], "source": [ "batch_size = 8  # rows of samples.csv per batch\n", "num_images = 128  # number of logits columns per prompt after reshaping\n", "top_k = 8  # best-scoring images kept per prompt\n", "text_normalizer = TextNormalizer() if normalize_text else None" ] },
{ "cell_type": "code", "execution_count": null, "id": "6a045827-3461-4499-8959-38d173bc4e5e", "metadata": {}, "outputs": [], "source": [ "# random seed for the JAX PRNG key used during sampling\n", "seed = random.randint(0, 2**32-1)\n", "key = jax.random.PRNGKey(seed)" ] },
{ "cell_type": "code", "execution_count": null, "id": "c6a878fa-4bf5-4978-abb5-e235841d765b", "metadata": {}, "outputs": [], "source": [ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n", "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "4927529a-8828-4150-bc76-e1b60d8dee62", "metadata": {}, "outputs": [], "source": [ "clip_params = replicate(clip.params)\n", "vqgan_params = replicate(vqgan.params)" ] },
{ "cell_type": "code", "execution_count": null, "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9", "metadata": {}, "outputs": [], "source": [ "# read every row of samples.csv as a dict keyed by the header names\n", "with open('samples.csv', newline='', encoding='utf8') as f:\n", "    samples = list(csv.DictReader(f))\n", "\n", "# make list multiple of batch_size by adding \"empty\" rows;\n", "# each padding row is its own dict so editing one never affects the others\n", "padding_count = -len(samples) % batch_size\n", "samples.extend({'Caption': 'empty', 'Theme': 'empty'} for _ in range(padding_count))\n", "\n", "# reshape into consecutive batches of batch_size rows\n", "samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]" ] },
{ "cell_type": "code", "execution_count": null, "id": "f75b2869-fc25-4f56-b937-e97bbb712ede", "metadata": {}, "outputs": [], "source": [ "len(samples)" ] },
{ "cell_type": "code", "execution_count": null, "id": "c48525c9-447a-4430-81d7-4b699f545638", "metadata": {}, "outputs": [], "source": [ "samples[-1]" ] },
{ "cell_type": "code", "execution_count": null, "id": "a2c629e9-1a82-40c6-a260-ca1780c19a2e", "metadata": {}, "outputs": [], "source": [ "api = wandb.Api()" ] },
{ "cell_type": "code", "execution_count": null, "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08", "metadata": {}, "outputs": [], "source": [ "# TODO: iterate on runs\n", "wandb_run = wandb_runs[0]\n", "# guard flag read by the cell that defines the pmapped functions\n", "functions_pmapped = False" ] },
{ "cell_type": "code", "execution_count": null, "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614", "metadata": {}, "outputs": [], "source": [ "try:\n", "    versions = api.artifact_versions(type_name='bart_model', name=f'dalle-mini/dalle-mini/model-{wandb_run}', per_page=10000)\n", "except Exception:\n", "    # best effort: a failed lookup means there is nothing to process;\n", "    # unlike a bare except, KeyboardInterrupt/SystemExit still propagate\n", "    versions = []" ] },
{ "cell_type": "code", "execution_count": null, "id": "e8026e63-9e73-472c-9440-5e742c614901", "metadata": {}, "outputs": [], "source": [ "versions, len(versions)" ] },
{ "cell_type": "code", "execution_count": null, "id": "ead44aee-52d5-4ca2-8984-c4d267d9e72a", "metadata": {}, "outputs": [], "source": [
"versions[0].version" ] },
{ "cell_type": "code", "execution_count": null, "id": "cfd48de9-6022-444f-8b12-05cba8fad071", "metadata": {}, "outputs": [], "source": [ "# work on the first entry returned by the artifact listing\n", "artifact = versions[0]" ] },
{ "cell_type": "code", "execution_count": null, "id": "4db848c1-2bb5-432c-a732-1c6d0636e172", "metadata": {}, "outputs": [], "source": [ "# drop the leading character of the version string and parse the rest as an int\n", "version = int(artifact.version[1:])" ] },
{ "cell_type": "code", "execution_count": null, "id": "25fac577-146d-4e62-a3ea-f0baea79ef83", "metadata": {}, "outputs": [], "source": [ "version" ] },
{ "cell_type": "code", "execution_count": null, "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570", "metadata": {}, "outputs": [], "source": [ "# retrieve training run\n", "training_run = api.run(f'dalle-mini/dalle-mini/{wandb_run}')\n", "config = training_run.config" ] },
{ "cell_type": "code", "execution_count": null, "id": "9b9393c6-0a3c-46a8-ba27-ba37982b0009", "metadata": {}, "outputs": [], "source": [ "# see summary metrics\n", "training_run.summary" ] },
{ "cell_type": "code", "execution_count": null, "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f", "metadata": {}, "outputs": [], "source": [ "# retrieve inference run details\n", "def get_last_version_inference(run_id):\n", "    \"\"\"Return the `_step` summary value of the run `inference-{run_id}`.\n", "\n", "    Returns None when the run cannot be fetched or its summary has no\n", "    `_step` entry.\n", "    \"\"\"\n", "    try:\n", "        inference_run = api.run(f'dalle-mini/dalle-mini/inference-{run_id}')\n", "        return inference_run.summary.get('_step', None)\n", "    except Exception:\n", "        # best effort: any lookup failure is treated as \"nothing logged yet\";\n", "        # unlike a bare except, KeyboardInterrupt/SystemExit still propagate\n", "        return None" ] },
{ "cell_type": "code", "execution_count": null, "id": "93b8d869-1658-4fa4-a401-2b91f8ac7a11", "metadata": {}, "outputs": [], "source": [ "last_version_inference = get_last_version_inference(wandb_run)" ] },
{ "cell_type": "code", "execution_count": null, "id": "8324835e-fd94-408e-b106-138be308480b", "metadata": {}, "outputs": [], "source": [ "# versions are expected to be processed in order, one at a time\n", "if last_version_inference is None:\n", "    assert version == 0, f'no inference step found but artifact version is {version}'\n", "elif last_version_inference >= version:\n", "    print(f'Version {version} has already been logged')\n", "else:\n", "    assert version == last_version_inference + 1, (\n", "        f'expected version {last_version_inference + 1}, got {version}'\n", "    )" ] },
{ "cell_type": "code", "execution_count":
null, "id": "8ce9d2d3-aea3-4d5e-834a-c5caf85dd117", "metadata": {}, "outputs": [], "source": [ "# start (or resume) the inference run paired with this training run\n", "run = wandb.init(job_type='inference', config=config, id=f'inference-{wandb_run}', resume='allow')" ] },
{ "cell_type": "code", "execution_count": null, "id": "ffe392c9-36d2-4aaa-a1b3-a827e348c1ef", "metadata": {}, "outputs": [], "source": [ "# scratch directory for the checkpoint files; it is removed a few cells\n", "# below, once the tokenizer and model have been loaded from it\n", "tmp_f = tempfile.TemporaryDirectory()\n", "tmp = tmp_f.name\n", "#TODO: use context manager" ] },
{ "cell_type": "code", "execution_count": null, "id": "299db1bb-fbe6-4d79-a48f-89893f8ed809", "metadata": {}, "outputs": [], "source": [ "# declare the artifact as an input of this run\n", "artifact = run.use_artifact(artifact)" ] },
{ "cell_type": "code", "execution_count": null, "id": "d71481bf-98aa-42cb-b7e2-545d13ae4309", "metadata": {}, "outputs": [], "source": [ "# only download required files\n", "for f in ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']:\n", "    artifact.get_path(f).download(tmp)" ] },
{ "cell_type": "code", "execution_count": null, "id": "6f8ad8dd-da8f-40f9-b438-e43b779d637c", "metadata": {}, "outputs": [], "source": [ "# we verify all the files are present\n", "from pathlib import Path\n", "list(Path(tmp).glob('*'))" ] },
{ "cell_type": "code", "execution_count": null, "id": "5b715c32-e757-4cb0-9912-ff90238b9f10", "metadata": {}, "outputs": [], "source": [ "tokenizer = BartTokenizer.from_pretrained(tmp)\n", "model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)" ] },
{ "cell_type": "code", "execution_count": null, "id": "562036ed-dc86-48af-90b1-9c18383b3552", "metadata": {}, "outputs": [], "source": [ "# remove tmp: must run only after the tokenizer and model were loaded from it\n", "tmp_f.cleanup()" ] },
{ "cell_type": "code", "execution_count": null, "id": "320823c9-124a-4fc3-a12c-8c015a128285", "metadata": {}, "outputs": [], "source": [ "model_params = replicate(model.params)" ] },
{ "cell_type": "code", "execution_count": null, "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac", "metadata": {}, "outputs": [],
"source": [ "# function to generate encoded images\n", "# we should generate this function only once per run\n", "if not functions_pmapped:\n", "    @partial(jax.pmap, axis_name=\"batch\")\n", "    def p_generate(tokenized_prompt, key, params):\n", "        \"\"\"Sample output sequences from the model for one device's shard of prompts.\"\"\"\n", "        return model.generate(\n", "            **tokenized_prompt,\n", "            do_sample=True,\n", "            num_beams=1,\n", "            prng_key=key,\n", "            params=params\n", "        )\n", "\n", "    @partial(jax.pmap, axis_name=\"batch\")\n", "    def p_decode(indices, params):\n", "        \"\"\"Decode code indices with the VQGAN for one device's shard.\"\"\"\n", "        return vqgan.decode_code(indices, params=params)\n", "\n", "    @partial(jax.pmap, axis_name=\"batch\")\n", "    def p_clip(inputs):\n", "        \"\"\"Return CLIP `logits_per_image` (raw, not softmaxed) for one device's shard.\"\"\"\n", "        return clip(**inputs).logits_per_image\n", "\n", "    # record that the pmapped functions exist, so re-running this cell\n", "    # does not define them again\n", "    functions_pmapped = True" ] },
{ "cell_type": "code", "execution_count": null, "id": "7a24b903-777b-4e3d-817c-00ed613a7021", "metadata": {}, "outputs": [], "source": [ "# TODO: loop over samples\n", "batch = samples[0]\n", "prompts = [x['Caption'] for x in batch]\n", "# normalized captions feed the generator; the raw `prompts` are kept as well\n", "processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts" ] },
{ "cell_type": "code", "execution_count": null, "id": "d77aa785-dc05-4070-aba2-aa007524d20b", "metadata": {}, "outputs": [], "source": [ "processed_prompts" ] },
{ "cell_type": "code", "execution_count": null, "id": "95db38fb-8948-4814-98ae-c172ca7c6d0a", "metadata": {}, "outputs": [], "source": [ "# one full copy of the prompt list per device\n", "repeated_prompts = processed_prompts * jax.device_count()" ] },
{ "cell_type": "code", "execution_count": null, "id": "e948ba9e-3700-4e87-926f-580a10f3e5cd", "metadata": {}, "outputs": [], "source": [ "tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n", "# split the leading axis across devices for pmap\n", "tokenized_prompt = shard(tokenized_prompt)" ] },
{ "cell_type": "code", "execution_count": null, "id": "30d96812-fc17-4acf-bb64-5fdb8d0cd313", "metadata": {}, "outputs": [], "source": [ "tokenized_prompt['input_ids'].shape" ] },
{ "cell_type": "code",
"execution_count": null, "id": "92ea034b-2649-4d18-ab6d-877ed04ae5c4", "metadata": {}, "outputs": [], "source": [ "images = []\n", "num_rounds = num_images // jax.device_count()\n", "for step in range(num_rounds):\n", "    # derive a fresh subkey for this round and carry the new key forward\n", "    key, subkey = jax.random.split(key, 2)\n", "\n", "    generated = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n", "    # drop the first position of every generated sequence\n", "    encoded_images = generated.sequences[..., 1:]\n", "\n", "    decoded_images = p_decode(encoded_images, vqgan_params)\n", "    # clamp to [0, 1] and collapse the leading axes into a flat stack of 256x256x3 arrays\n", "    decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n", "\n", "    # convert each array to an 8-bit PIL image\n", "    images.extend(\n", "        Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in decoded_images\n", "    )" ] },
{ "cell_type": "code", "execution_count": null, "id": "84d52f30-44c9-4a74-9992-fb2578f19b90", "metadata": {}, "outputs": [], "source": [ "len(images)" ] },
{ "cell_type": "code", "execution_count": null, "id": "beb594f9-5b91-47fe-98bd-41e68c6b1d73", "metadata": {}, "outputs": [], "source": [ "images[0]" ] },
{ "cell_type": "code", "execution_count": null, "id": "bb135190-64e5-44af-b416-e688b034da44", "metadata": {}, "outputs": [], "source": [ "images[1]" ] },
{ "cell_type": "code", "execution_count": null, "id": "d78a0d92-72c2-4f82-a6ab-b3f5865dd863", "metadata": {}, "outputs": [], "source": [ "# CLIP receives the raw captions together with every generated image\n", "clip_inputs = processor(\n", "    text=prompts,\n", "    images=images,\n", "    return_tensors='np',\n", "    padding='max_length',\n", "    max_length=77,\n", "    truncation=True,\n", ").data" ] },
{ "cell_type": "code", "execution_count": null, "id": "89ff78a6-bfa4-44d9-ad66-07a4a68b4352", "metadata": {}, "outputs": [], "source": [ "# after sharding, every device is meant to hold a single prompt\n", "clip_inputs['input_ids'].shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "2cda8984-049c-4c87-96ad-7b0412750656", "metadata": {}, "outputs": [], "source": [ "# every device must hold the images that belong to its prompt\n", "clip_inputs['pixel_values'].shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "0a044e8f-be29-404b-b6c7-8f2395c5efc6", "metadata": {}, "outputs": [],
"source": [ "# positions 0, batch_size, 2 * batch_size, ... within `images`\n", "prompt_stride = range(0, len(images), batch_size)\n", "images_per_prompt_indices = np.asarray(prompt_stride)\n", "images_per_prompt_indices" ] },
{ "cell_type": "code", "execution_count": null, "id": "7a6c61b3-12e0-45d8-b39a-830288324d3d", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "id": "7318e67e-4214-46f9-bf60-6d139d4bd00f", "metadata": {}, "outputs": [], "source": [ "# regroup pixel values prompt by prompt so that sharding hands each\n", "# device the images matching its prompt\n", "pixel_values = clip_inputs['pixel_values']\n", "clip_inputs['pixel_values'] = jnp.concatenate(\n", "    [pixel_values[images_per_prompt_indices + offset] for offset in range(batch_size)]\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "id": "90c949a2-8e2a-4905-b6d4-92038f1704b8", "metadata": {}, "outputs": [], "source": [ "clip_inputs = shard(clip_inputs)" ] },
{ "cell_type": "code", "execution_count": null, "id": "58fa836e-5ebb-45e7-af77-ab10646dfbfb", "metadata": {}, "outputs": [], "source": [ "logits = p_clip(clip_inputs)" ] },
{ "cell_type": "code", "execution_count": null, "id": "fd7a3f91-3a1f-4a0a-8b3e-3c926cd367fb", "metadata": {}, "outputs": [], "source": [ "logits.shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "fa406db7-0a21-4e4b-9890-4c7aece4280c", "metadata": {}, "outputs": [], "source": [ "# one row per prompt, num_images columns\n", "logits = logits.reshape(-1, num_images)" ] },
{ "cell_type": "code", "execution_count": null, "id": "9c359a8c-2c27-4e68-8775-371857397723", "metadata": {}, "outputs": [], "source": [ "logits.shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "a56b9f28-dd91-4382-bc47-11e89fda1254", "metadata": {}, "outputs": [], "source": [ "logits" ] },
{ "cell_type": "code", "execution_count": null, "id": "0bed8167-0a6d-46c1-badf-8bdc20b93c31", "metadata": {}, "outputs": [], "source": [ "# sort each row, keep its last top_k positions, then reverse them\n", "ranking = logits.argsort()\n", "top_idx = ranking[:, -top_k:][..., ::-1]" ] },
{ "cell_type": "code", "execution_count": null, "id": "188c5333-6b8c-4a17-8cc8-15651c77ef99", "metadata": {}, "outputs": [], "source": [ "len(images)" ] },
{ "cell_type": "code", "execution_count": null,
"id": "babd22b3-e773-467d-8bbb-f0323f57a44b", "metadata": {}, "outputs": [], "source": [ "# one row per prompt: caption, theme, top_k images, then their top_k scores\n", "results = []\n", "columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]" ] },
{ "cell_type": "code", "execution_count": null, "id": "75976c9f-dea5-48e3-8920-55a1bbfd91c2", "metadata": {}, "outputs": [], "source": [ "for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n", "    # images of prompt i sit at positions i, i + batch_size, i + 2 * batch_size, ...\n", "    cur_images = [images[x] for x in images_per_prompt_indices + i]\n", "    top_images = [wandb.Image(cur_images[x]) for x in idx]\n", "    # `scores` is this prompt's row of logits; `idx` indexes into that row,\n", "    # not into the rows of the full `logits` matrix\n", "    top_scores = [scores[x] for x in idx]\n", "    results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)" ] },
{ "cell_type": "code", "execution_count": null, "id": "e1c04761-1016-47e9-925c-3a9ec6fec95a", "metadata": {}, "outputs": [], "source": [ "wandb.finish()" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }