{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "6eb74941-bb4d-4d7e-97f1-d5a3a07672bf", "metadata": {}, "outputs": [], "source": [ "# !pip install flax transformers\n", "# !git clone https://github.com/patil-suraj/vqgan-jax.git" ] }, { "cell_type": "code", "execution_count": 305, "id": "41db7534-f589-4b63-9165-9c9799e1b06e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/surajpatil/vqgan-jax\n" ] }, { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 305, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%cd ~/vqgan-jax\n", "\n", "import random\n", "\n", "\n", "import jax\n", "import flax.linen as nn\n", "from flax.training.common_utils import shard\n", "from flax.jax_utils import replicate, unreplicate\n", "\n", "from transformers.models.bart.modeling_flax_bart import *\n", "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n", "\n", "import io\n", "\n", "import requests\n", "from PIL import Image\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from torchvision.transforms import InterpolationMode\n", "\n", "\n", "from modeling_flax_vqgan import VQModel\n", "\n", "jax.devices()" ] }, { "cell_type": "code", "execution_count": 2, "id": 
"b6a3462a-9004-4121-b365-3ae3aaf94dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: set those args in a config file\n",
    "IMAGE_VOCAB_SIZE = 16384  # size of the encoded image token space (VQGAN codes)\n",
    "OUTPUT_VOCAB_SIZE = IMAGE_VOCAB_SIZE + 1  # encoded image token space + 1 for bos\n",
    "OUTPUT_LENGTH = 256 + 1  # number of encoded tokens (16x16 grid) + 1 for bos\n",
    "BOS_TOKEN_ID = IMAGE_VOCAB_SIZE  # bos is the first id after the image tokens\n",
    "BASE_MODEL = 'facebook/bart-large'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Model: BART with an image-token decoder\n",
    "\n",
    "The encoder keeps BART's text vocabulary. The decoder gets its own embedding table, position table and LM head, sized for image tokens (`OUTPUT_VOCAB_SIZE`, `OUTPUT_LENGTH`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bbef1afb-0b36-44a5-83f7-643d7e2c0e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomFlaxBartModule(FlaxBartModule):\n",
    "    \"\"\"BART encoder-decoder whose decoder works over the image-token vocabulary.\n",
    "\n",
    "    The encoder reads text with BART's vocabulary (through `shared`); the decoder\n",
    "    uses a separate embedding of size OUTPUT_VOCAB_SIZE and positions up to OUTPUT_LENGTH.\n",
    "    \"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        # we keep shared to easily load pre-trained weights\n",
    "        self.shared = nn.Embed(\n",
    "            self.config.vocab_size,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        # a separate embedding is used for the decoder\n",
    "        self.decoder_embed = nn.Embed(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
    "\n",
    "        # the decoder has a different config: a copy of the model config with the\n",
    "        # output vocabulary and length overridden.\n",
    "        # NOTE: `BartConfig(config_dict)` binds the whole dict to the positional\n",
    "        # `vocab_size` argument and silently uses BartConfig defaults for every other\n",
    "        # field; `from_dict` copies the actual settings. Checkpoints trained with the\n",
    "        # old call still load as long as their decoder sizes equal BartConfig's defaults\n",
    "        # (the bart-large sizes) - TODO(review): confirm for older checkpoints.\n",
    "        decoder_config = BartConfig.from_dict(self.config.to_dict())\n",
    "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
    "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
    "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
    "    \"\"\"Conditional-generation module whose LM head outputs OUTPUT_VOCAB_SIZE logits.\"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
    "        self.lm_head = nn.Dense(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            use_bias=False,\n",
    "            dtype=self.dtype,\n",
    "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "        )\n",
    "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
    "    \"\"\"Model wrapper (`from_pretrained`, `generate`, ...) around the custom module.\"\"\"\n",
    "\n",
    "    module_class = CustomFlaxBartForConditionalGenerationModule"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load",
   "metadata": {},
   "source": [
    "## Load the trained checkpoint and the VQGAN\n",
    "\n",
    "The BART weights come from a W&B artifact; the VQGAN decodes image tokens back into pixels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "879320b7-eaa0-4dc9-bbf2-c81efc53301d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "# W&B artifact holding the trained BART checkpoint\n",
    "ARTIFACT_NAME = 'wandb/hf-flax-dalle-mini/model-3h3x3565:v7'\n",
    "\n",
    "run = wandb.init()\n",
    "artifact = run.use_artifact(ARTIFACT_NAME, type='bart_model')\n",
    "artifact_dir = artifact.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "id": "e8bcff33-e95b-4c01-b162-ee857a55c3e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/surajpatil/transformers/src/transformers/models/bart/configuration_bart.py:177: UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in future versions.The config can simply be saved and uploaded again to be fixed.\n",
      " warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1, 16385)"
      ]
     },
     "execution_count": 164,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load our trained model from the downloaded artifact (not a random initialization)\n",
    "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)\n",
    "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)\n",
    "# turn off BART's forced BOS/EOS tokens so generated positions are not overridden\n",
    "model.config.force_bos_token_to_be_generated = False\n",
    "model.config.forced_bos_token_id = None\n",
    "model.config.forced_eos_token_id = None\n",
    "\n",
    "# we verify that the shape has not been modified\n",
    "assert model.params['final_logits_bias'].shape == (1, OUTPUT_VOCAB_SIZE)\n",
    "model.params['final_logits_bias'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "load-vqgan",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9b979a72ab9e449387a89bf9b3012af5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
"text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01730e0e9d02428ca9dad680f9fdda42",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=304307206.0, style=ProgressStyle(descri…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
     ]
    }
   ],
   "source": [
    "# VQGAN used to decode image-token ids back into pixels\n",
    "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-helpers",
   "metadata": {},
   "source": [
    "## Helpers: sampling, decoding and display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "id": "6cca395a-93c2-49bc-a3be-98287e4403d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_to_pil(x):\n",
    "    \"\"\"Convert a float image array with values in [0, 1] to an RGB PIL image.\n",
    "\n",
    "    Values are clipped to [0, 1] and scaled to 0..255 (uint8) first.\n",
    "    \"\"\"\n",
    "    x = np.clip(x, 0., 1.)\n",
    "    x = (255 * x).astype(np.uint8)\n",
    "    x = Image.fromarray(x)\n",
    "    if x.mode != \"RGB\":\n",
    "        x = x.convert(\"RGB\")\n",
    "    return x\n",
    "\n",
    "\n",
    "# An id outside the decoder vocabulary (the LM head has OUTPUT_VOCAB_SIZE outputs), so\n",
    "# it is never sampled: using it as eos/pad makes `generate` always run to `max_length`.\n",
    "NEVER_GENERATED_TOKEN_ID = 50000\n",
    "\n",
    "\n",
    "def generate(input, rng, params):\n",
    "    \"\"\"Sample OUTPUT_LENGTH tokens (start token + image tokens) for tokenized prompts.\"\"\"\n",
    "    return model.generate(\n",
    "        **input,\n",
    "        max_length=OUTPUT_LENGTH,\n",
    "        num_beams=1,\n",
    "        do_sample=True,\n",
    "        prng_key=rng,\n",
    "        eos_token_id=NEVER_GENERATED_TOKEN_ID,\n",
    "        pad_token_id=NEVER_GENERATED_TOKEN_ID,\n",
    "        params=params\n",
    "    )\n",
    "\n",
    "\n",
    "def get_images(indices, params):\n",
    "    \"\"\"Decode image-token ids into images with the VQGAN.\"\"\"\n",
    "    return vqgan.decode_code(indices, params=params)\n",
    "\n",
    "\n",
    "def plot_images(images, columns=4, rows=2):\n",
    "    \"\"\"Show up to `columns * rows` images in a grid with every y axis hidden.\"\"\"\n",
    "    fig = plt.figure(figsize=(40, 20))\n",
    "    plt.subplots_adjust(hspace=0, wspace=0)\n",
    "\n",
    "    # slicing lets fewer images than grid cells be shown without an IndexError\n",
    "    for i, image in enumerate(images[:columns * rows], start=1):\n",
    "        ax = fig.add_subplot(rows, columns, i)\n",
    "        ax.imshow(image)\n",
    "        # hide the y axis of every subplot, not only the last one\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def stack_reconstructions(images):\n",
    "    \"\"\"Paste PIL images side by side into one RGB image (cells sized like images[0]).\"\"\"\n",
    "    w, h = images[0].size\n",
    "    img = Image.new(\"RGB\", (len(images) * w, h))\n",
    "    for i, img_ in enumerate(images):\n",
    "        img.paste(img_, (i * w, 0))\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "b1bec3d2-ef17-4feb-aa0d-b51ed2fdcd3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run sampling and decoding on every local device in parallel\n",
    "p_generate = jax.pmap(generate, \"batch\")\n",
    "p_get_images = jax.pmap(get_images, \"batch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a539823a-a775-4d92-96a5-dc8b1eef69c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the pmapped functions take one copy of the parameters per device\n",
    "bart_params = replicate(model.params)\n",
    "vqgan_params = replicate(vqgan.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-generate",
   "metadata": {},
   "source": [
    "## Generate images from a prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "id": "e8b268d8-6992-422a-8373-95651474ae70",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = [\n",
    "    \"man in blue jacket walking on pathway in between trees during daytime\",\n",
    "    'white snow covered mountain under blue sky during daytime',\n",
    "    'white snow covered mountain under blue sky during night',\n",
    "    \"orange tabby cat on persons hand\",\n",
    "    \"aerial view of beach during daytime\",\n",
    "    \"chess pieces on chess board\",\n",
    "    \"laptop on brown wooden table\",\n",
    "    \"white bus on road near high rise buildings\",\n",
    "]\n",
    "\n",
    "# one copy of the prompt per local device: `shard` splits the batch across devices\n",
    "prompt = [prompts[-1]] * jax.local_device_count()\n",
    "inputs = tokenizer(prompt, return_tensors='jax', padding=\"max_length\", truncation=True, max_length=128).data\n",
    "inputs = shard(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68638cfa-9a4d-4e6a-8630-91aefb627bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# draw 8 rounds of samples for the prompt, each with a fresh PRNG key\n",
    "for _ in range(8):\n",
    "    key = random.randint(0, 10**7)  # int bounds: float bounds are rejected by Python >= 3.12\n",
    "    rng = jax.random.PRNGKey(key)\n",
    "    rngs = jax.random.split(rng, jax.local_device_count())\n",
    "    indices = p_generate(inputs, rngs, bart_params).sequences\n",
    "    # drop the first (start / bos) token: only the remaining ids are decoded by the VQGAN\n",
    "    indices = indices[:, :, 1:]\n",
    "\n",
    "    images = np.asarray(p_get_images(indices, vqgan_params))\n",
    "    # merge the (device, per-device batch) axes into a single image axis\n",
    "    images = images.reshape((-1,) + images.shape[2:])\n",
    "    pil_images = [custom_to_pil(image) for image in images]\n",
    "\n",
    "    plt.figure(figsize=(40, 20))\n",
    "    plt.imshow(stack_reconstructions(pil_images))\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
}, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }