{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8", "metadata": {}, "outputs": [], "source": [ "import tempfile\n", "from functools import partial\n", "import random\n", "import numpy as np\n", "from PIL import Image\n", "from tqdm.notebook import tqdm\n", "import jax\n", "import jax.numpy as jnp\n", "from flax.training.common_utils import shard, shard_prng_key\n", "from flax.jax_utils import replicate\n", "import wandb\n", "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n", "from dalle_mini.text import TextNormalizer" ] }, { "cell_type": "code", "execution_count": null, "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270", "metadata": {}, "outputs": [], "source": [ "run_ids = [\"63otg87g\"]\n", "ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n", "VQGAN_REPO, VQGAN_COMMIT_ID = (\n", " \"dalle-mini/vqgan_imagenet_f16_16384\",\n", " \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n", ")\n", "latest_only = True # log only latest or all versions\n", "suffix = \"\" # mainly for duplicate inference runs with a deleted version\n", "add_clip_32 = False" ] }, { "cell_type": "code", "execution_count": null, "id": "71f27b96-7e6c-4472-a2e4-e99a8fb67a72", "metadata": {}, "outputs": [], "source": [ "# model.generate parameters - Not used yet\n", "gen_top_k = None\n", "gen_top_p = None\n", "temperature = None" ] }, { "cell_type": "code", "execution_count": null, "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3", "metadata": {}, "outputs": [], "source": [ "batch_size = 8\n", "num_images = 128\n", "top_k = 8\n", "text_normalizer = TextNormalizer()\n", "padding_item = \"NONE\"\n", "seed = random.randint(0, 2 ** 32 - 1)\n", "key = jax.random.PRNGKey(seed)\n", "api = wandb.Api()" ] }, { "cell_type": "code", "execution_count": null, "id": "c6a878fa-4bf5-4978-abb5-e235841d765b", "metadata": {}, "outputs": [], "source": [ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n", "vqgan_params = replicate(vqgan.params)\n", "\n", "clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n", "processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n", "clip16_params = replicate(clip16.params)\n", "\n", "if add_clip_32:\n", " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n", " clip32_params = replicate(clip32.params)" ] }, { "cell_type": "code", "execution_count": null, "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3", "metadata": {}, "outputs": [], "source": [ "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_decode(indices, params):\n", " return vqgan.decode_code(indices, params=params)\n", "\n", "\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_clip16(inputs, params):\n", " logits = clip16(params=params, **inputs).logits_per_image\n", " return logits\n", "\n", "\n", "if add_clip_32:\n", "\n", " @partial(jax.pmap, axis_name=\"batch\")\n", " def p_clip32(inputs, params):\n", " logits = clip32(params=params, **inputs).logits_per_image\n", " return logits" ] }, { "cell_type": "code", "execution_count": null, "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9", "metadata": {}, "outputs": [], "source": [ "with open(\"samples.txt\", encoding=\"utf8\") as f:\n", " samples = [l.strip() for l in f.readlines()]\n", " # make list multiple of batch_size by adding elements\n", " samples_to_add = [padding_item] * (-len(samples) % batch_size)\n", " samples.extend(samples_to_add)\n", " # reshape\n", " samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]" ] }, { "cell_type": "code", "execution_count": null, "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614", "metadata": {}, "outputs": [], "source": [ "def get_artifact_versions(run_id, latest_only=False):\n", " try:\n", " if latest_only:\n", " return [\n", " api.artifact(\n", " type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n", " )\n", " ]\n", " else:\n", " return api.artifact_versions(\n", " type_name=\"bart_model\",\n", " name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n", " per_page=10000,\n", " )\n", " except:\n", " return []" ] }, { "cell_type": "code", "execution_count": null, "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570", "metadata": {}, "outputs": [], "source": [ "def get_training_config(run_id):\n", " training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n", " config = training_run.config\n", " return config" ] }, { "cell_type": "code", "execution_count": null, "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f", "metadata": {}, "outputs": [], "source": [ "# retrieve inference run details\n", "def get_last_inference_version(run_id):\n", " try:\n", " inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n", " return inference_run.summary.get(\"version\", None)\n", " except:\n", " return None" ] }, { "cell_type": "code", "execution_count": null, "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac", "metadata": {}, "outputs": [], "source": [ "# compile functions - needed only once per run\n", "def pmap_model_function(model):\n", " @partial(jax.pmap, axis_name=\"batch\")\n", " def _generate(tokenized_prompt, key, params):\n", " return model.generate(\n", " **tokenized_prompt,\n", " do_sample=True,\n", " num_beams=1,\n", " prng_key=key,\n", " params=params,\n", " top_k=gen_top_k,\n", " top_p=gen_top_p\n", " )\n", "\n", " return _generate" ] }, { "cell_type": "code", "execution_count": null, "id": "23b2444c-67a9-44d7-abd1-187ed83a9431", "metadata": {}, "outputs": [], "source": [ "run_id = run_ids[0]\n", "# TODO: loop over runs" ] }, { "cell_type": "code", "execution_count": null, "id": "bba70f33-af8b-4eb3-9973-7be672301a0b", "metadata": {}, "outputs": [], "source": [ "artifact_versions = get_artifact_versions(run_id, latest_only)\n", "last_inference_version = get_last_inference_version(run_id)\n", "training_config = get_training_config(run_id)\n", "run = None\n", "p_generate = None\n", "model_files = [\n", " \"config.json\",\n", " \"flax_model.msgpack\",\n", " \"merges.txt\",\n", " \"special_tokens_map.json\",\n", " \"tokenizer.json\",\n", " \"tokenizer_config.json\",\n", " \"vocab.json\",\n", "]\n", "for artifact in artifact_versions:\n", " print(f\"Processing artifact: {artifact.name}\")\n", " version = int(artifact.version[1:])\n", " results16, results32 = [], []\n", " columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n", "\n", " if latest_only:\n", " assert last_inference_version is None or version > last_inference_version\n", " else:\n", " if last_inference_version is None:\n", " # we should start from v0\n", " assert version == 0\n", " elif version <= last_inference_version:\n", " print(\n", " f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n", " )\n", " else:\n", " # check we are logging the correct version\n", " assert version == last_inference_version + 1\n", "\n", " # start/resume corresponding run\n", " if run is None:\n", " run = wandb.init(\n", " job_type=\"inference\",\n", " entity=\"dalle-mini\",\n", " project=\"dalle-mini\",\n", " config=training_config,\n", " id=f\"{run_id}-clip16{suffix}\",\n", " resume=\"allow\",\n", " )\n", "\n", " # work in temporary directory\n", " with tempfile.TemporaryDirectory() as tmp:\n", "\n", " # download model files\n", " artifact = run.use_artifact(artifact)\n", " for f in model_files:\n", " artifact.get_path(f).download(tmp)\n", "\n", " # load tokenizer and model\n", " tokenizer = BartTokenizer.from_pretrained(tmp)\n", " model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)\n", " model_params = replicate(model.params)\n", "\n", " # pmap model function needs to happen only once per model config\n", " if p_generate is None:\n", " p_generate = pmap_model_function(model)\n", "\n", " # process one batch of captions\n", " for batch in tqdm(samples):\n", " processed_prompts = (\n", " [text_normalizer(x) for x in batch]\n", " if model.config.normalize_text\n", " else list(batch)\n", " )\n", "\n", " # repeat the prompts to distribute over each device and tokenize\n", " processed_prompts = processed_prompts * jax.device_count()\n", " tokenized_prompt = tokenizer(\n", " processed_prompts,\n", " return_tensors=\"jax\",\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=128,\n", " ).data\n", " tokenized_prompt = shard(tokenized_prompt)\n", "\n", " # generate images\n", " images = []\n", " pbar = tqdm(\n", " range(num_images // jax.device_count()),\n", " desc=\"Generating Images\",\n", " leave=True,\n", " )\n", " for i in pbar:\n", " key, subkey = jax.random.split(key)\n", " encoded_images = p_generate(\n", " tokenized_prompt, shard_prng_key(subkey), model_params\n", " )\n", " encoded_images = encoded_images.sequences[..., 1:]\n", " decoded_images = p_decode(encoded_images, vqgan_params)\n", " decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n", " (-1, 256, 256, 3)\n", " )\n", " for img in decoded_images:\n", " images.append(\n", " Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n", " )\n", "\n", " def add_clip_results(results, processor, p_clip, clip_params):\n", " clip_inputs = processor(\n", " text=batch,\n", " images=images,\n", " return_tensors=\"np\",\n", " padding=\"max_length\",\n", " max_length=77,\n", " truncation=True,\n", " ).data\n", " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n", " images_per_prompt_indices = np.asarray(\n", " range(0, len(images), batch_size)\n", " )\n", " clip_inputs[\"pixel_values\"] = jnp.concatenate(\n", " list(\n", " clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n", " for i in range(batch_size)\n", " )\n", " )\n", " clip_inputs = shard(clip_inputs)\n", " logits = p_clip(clip_inputs, clip_params)\n", " logits = logits.reshape(-1, num_images)\n", " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n", " logits = jax.device_get(logits)\n", " # add to results table\n", " for i, (idx, scores, sample) in enumerate(\n", " zip(top_scores, logits, batch)\n", " ):\n", " if sample == padding_item:\n", " continue\n", " cur_images = [images[x] for x in images_per_prompt_indices + i]\n", " top_images = [\n", " wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n", " for x in idx\n", " ]\n", " results.append([sample] + top_images)\n", "\n", " # get clip scores\n", " pbar.set_description(\"Calculating CLIP 16 scores\")\n", " add_clip_results(results16, processor16, p_clip16, clip16_params)\n", "\n", " # get clip 32 scores\n", " if add_clip_32:\n", " pbar.set_description(\"Calculating CLIP 32 scores\")\n", " add_clip_results(results32, processor32, p_clip32, clip32_params)\n", "\n", " pbar.close()\n", "\n", " # log results\n", " table = wandb.Table(columns=columns, data=results16)\n", " run.log({\"Samples\": table, \"version\": version})\n", " wandb.finish()\n", "\n", " if add_clip_32:\n", " run = wandb.init(\n", " job_type=\"inference\",\n", " entity=\"dalle-mini\",\n", " project=\"dalle-mini\",\n", " config=training_config,\n", " id=f\"{run_id}-clip32{suffix}\",\n", " resume=\"allow\",\n", " )\n", " table = wandb.Table(columns=columns, data=results32)\n", " run.log({\"Samples\": table, \"version\": version})\n", " wandb.finish()\n", " run = None # ensure we don't log on this run" ] }, { "cell_type": "code", "execution_count": null, "id": "415d3f54-7226-43de-9eea-4283a948dc93", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }