{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ewer-Q-0w2xA"
   },
   "source": [
    "# Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NpsF9ipLLl2s",
    "outputId": "10bf54aa-b89d-4e42-9777-bc97b00a5f32"
   },
   "outputs": [],
   "source": [
    "#!pip install git+https://github.com/huggingface/transformers/\n",
    "#!pip install git+https://github.com/google/flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M1wVkrpjU6zO"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ../../vqgan-jax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t47CH1H_IOT8"
   },
   "source": [
    "# Custom BART Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9jQnM6S2vCpn"
   },
   "outputs": [],
   "source": [
    "# TODO: set those args in a config file\n",
    "OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos\n",
    "OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos\n",
    "BOS_TOKEN_ID = 16384\n",
    "BASE_MODEL = 'facebook/bart-large'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_eEaJVxAKpV5"
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.linen as nn\n",
    "\n",
    "# explicit imports: every name used below is listed here instead of relying on\n",
    "# whatever a star import happens to re-export\n",
    "from transformers import BartTokenizer, FlaxBartForConditionalGeneration\n",
    "from transformers.models.bart.modeling_flax_bart import (\n",
    "    FlaxBartDecoder,\n",
    "    FlaxBartEncoder,\n",
    "    FlaxBartForConditionalGenerationModule,\n",
    "    FlaxBartModule,\n",
    ")\n",
    "\n",
    "\n",
    "class CustomFlaxBartModule(FlaxBartModule):\n",
    "    \"\"\"BART module with a text encoder and a decoder over OUTPUT_VOCAB_SIZE tokens.\"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        # we keep shared to easily load pre-trained weights\n",
    "        self.shared = nn.Embed(\n",
    "            self.config.vocab_size,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        # a separate embedding is used for the decoder\n",
    "        self.decoder_embed = nn.Embed(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            self.config.d_model,\n",
    "            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "            dtype=self.dtype,\n",
    "        )\n",
    "        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)\n",
    "\n",
    "        # the decoder has a different config: start from a copy of the model config\n",
    "        # and override only the fields that differ. (Passing `to_dict()` positionally\n",
    "        # to `BartConfig` would bind the dict to `vocab_size` and leave every other\n",
    "        # field at its library default.)\n",
    "        # NOTE(review): checkpoints trained with the previous construction assume\n",
    "        # default decoder hyper-parameters; presumably these match the artifact's\n",
    "        # config (bart-large based) - confirm.\n",
    "        decoder_config = copy.deepcopy(self.config)\n",
    "        decoder_config.max_position_embeddings = OUTPUT_LENGTH\n",
    "        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE\n",
    "        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):\n",
    "    \"\"\"Conditional-generation head projecting onto OUTPUT_VOCAB_SIZE logits.\"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)\n",
    "        self.lm_head = nn.Dense(\n",
    "            OUTPUT_VOCAB_SIZE,\n",
    "            use_bias=False,\n",
    "            dtype=self.dtype,\n",
    "            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),\n",
    "        )\n",
    "        self.final_logits_bias = self.param(\"final_logits_bias\", self.bias_init, (1, OUTPUT_VOCAB_SIZE))\n",
    "\n",
    "\n",
    "class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):\n",
    "    \"\"\"Pretrained-model wrapper that builds the custom module defined above.\"\"\"\n",
    "\n",
    "    module_class = CustomFlaxBartForConditionalGenerationModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import wandb\n",
    "run = wandb.init()\n",
    "artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-1ef8yxby:latest', type='bart_model')\n",
    "artifact_dir = artifact.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_6-XKK40oEfP",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# create our model and load the weights stored in the downloaded artifact\n",
    "model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# do not force a specific token as the first generated token\n",
    "model.config.forced_bos_token_id = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Jz032w73nHEf",
    "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49"
   },
   "outputs": [],
   "source": [
    "# we verify that the shape has not been modified\n",
    "assert model.params['final_logits_bias'].shape == (1, OUTPUT_VOCAB_SIZE)\n",
    "model.params['final_logits_bias'].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zLl24Ez5t7x1"
   },
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XLLA2NK3uDQr"
   },
   "outputs": [],
   "source": [
    "tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_text = ['I enjoy walking with my cute dog']*8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P32mJJSbrU1F"
   },
   "outputs": [],
   "source": [
    "input_ids_test = tokenizer(input_text, return_tensors='jax')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C7cHbIHruELT"
   },
   "outputs": [],
   "source": [
    "# OUTPUT_LENGTH = 256 image tokens + 1 bos token (see the config cell)\n",
    "greedy_output = model.generate(input_ids_test['input_ids'], max_length=OUTPUT_LENGTH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "greedy_output[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "jYugh9cOuwc9",
    "outputId": "19c3a2ee-e7bc-4f1f-9c86-06bd7337b537"
   },
   "outputs": [],
   "source": [
    "greedy_output[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "greedy_output[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VQGAN Jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "import requests\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "import torchvision.transforms.functional as TF\n",
    "from torchvision.transforms import InterpolationMode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modeling_flax_vqgan import VQModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_to_pil(x):\n",
    "    \"\"\"Convert an array of values in [0, 1] to an RGB PIL image.\n",
    "\n",
    "    Values outside [0, 1] are clipped before scaling to uint8.\n",
    "    \"\"\"\n",
    "    x = np.clip(x, 0., 1.)\n",
    "    x = (255*x).astype(np.uint8)\n",
    "    x = Image.fromarray(x)\n",
    "    if x.mode != \"RGB\":\n",
    "        x = x.convert(\"RGB\")\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Jz032w73nHEf",
    "outputId": "994d8e85-bff7-480b-8b69-f69dedc15c49",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# use a dedicated name so the BART `model` above is not overwritten\n",
    "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_images(indices, model):\n",
    "    \"\"\"Decode generated token sequences with `model.decode_code`.\n",
    "\n",
    "    The first token of every sequence is dropped before decoding\n",
    "    (the config cell reserves one position of OUTPUT_LENGTH for bos).\n",
    "    \"\"\"\n",
    "    indices = indices[:, 1:]\n",
    "    print(indices.shape)\n",
    "    img = model.decode_code(indices)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# decode the first generated sequence (with a batch axis added) and display it\n",
    "decoded_images = get_images(jnp.expand_dims(greedy_output[0][0], 0), vqgan)\n",
    "custom_to_pil(np.asarray(decoded_images[0]))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "CustomBARTv4b-model-generate.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}