# !pip install flax transformers
# !git clone https://github.com/patil-suraj/vqgan-jax.git
# %cd ~/vqgan-jax

import random


import jax
import flax.linen as nn
from flax.training.common_utils import shard
from flax.jax_utils import replicate, unreplicate

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

import io

import requests
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode


from modeling_flax_vqgan import VQModel

jax.devices()

# TODO: set those args in a config file
OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
BOS_TOKEN_ID = 16384
BASE_MODEL = 'facebook/bart-large'


class CustomFlaxBartModule(FlaxBartModule):
    """BART module whose decoder has its own embedding table and config.

    The encoder keeps the regular ``shared`` embedding sized by
    ``config.vocab_size``; the decoder embeds ``OUTPUT_VOCAB_SIZE`` tokens and
    is limited to ``OUTPUT_LENGTH`` positions.
    """

    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # the decoder has a different config
        # Build the copy with from_dict: passing the dict positionally would
        # bind it to the first constructor argument (vocab_size) and leave
        # every other field at the library default instead of the model's.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)


class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation head that projects onto ``OUTPUT_VOCAB_SIZE`` logits."""

    def setup(self):
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        # Bias-free projection from the decoder hidden size to the output vocabulary.
        self.lm_head = nn.Dense(
            OUTPUT_VOCAB_SIZE,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))


class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Pretrained-model wrapper that instantiates the custom generation module."""

    module_class = CustomFlaxBartForConditionalGenerationModule


# Download the trained checkpoint from Weights & Biases.
import wandb

run = wandb.init()
artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:v7', type='bart_model')
artifact_dir = artifact.download()

# Recorded notebook output of the model-loading cell (kept for reference):
#   /home/surajpatil/transformers/src/transformers/models/bart/configuration_bart.py:177:
#   UserWarning: Please make sure the config includes `forced_bos_token_id=16384` in
#   future versions.The config can simply be saved and uploaded again to be fixed.
#   (1, 16385)

# create our model from the downloaded checkpoint
# Tokenizer comes from the base checkpoint; the model weights are read from
# the downloaded artifact directory.
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)

# Switch off every forced BOS/EOS setting on the loaded config.
for _option, _value in (
    ("force_bos_token_to_be_generated", False),
    ("forced_bos_token_id", None),
    ("forced_eos_token_id", None),
):
    setattr(model.config, _option, _value)

# we verify that the shape has not been modified
model.params['final_logits_bias'].shape

vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

# NOTE(review): the notebook extract is truncated at this point. The beginning
# of the cell that defines `generate` (and presumably `custom_to_pil`, which is
# called further down) is missing; only the tail of a call survives. It is kept
# verbatim here instead of being guessed at -- restore the full definitions
# from the original notebook before running this file.
#
#         max_length=257,
#         num_beams=1,
#         do_sample=True,
#         prng_key=rng,
#         eos_token_id=50000,
#         pad_token_id=50000,
#         params=params
#     )


def get_images(indices, params):
    """Decode code indices into images with the VQGAN's ``decode_code``."""
    return vqgan.decode_code(indices, params=params)


def plot_images(images):
    """Show up to eight images on a 2 x 4 grid with the y-axis hidden.

    Fewer than eight images are accepted; only the available ones are drawn.
    """
    fig = plt.figure(figsize=(40, 20))
    columns = 4
    rows = 2
    plt.subplots_adjust(hspace=0, wspace=0)

    # Slicing caps the grid at rows * columns cells and avoids indexing past
    # the end when fewer images are supplied.
    for position, image in enumerate(images[:columns * rows], start=1):
        fig.add_subplot(rows, columns, position)
        plt.imshow(image)
        plt.gca().axes.get_yaxis().set_visible(False)
    plt.show()


def stack_reconstructions(images):
    """Paste the images side by side into one RGB image and return it.

    The canvas is sized from the first image, so all images are assumed to
    share its width and height.
    """
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w, h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i * w, 0))
    return img


# Parallel (per-device) versions of the generation and decoding functions.
p_generate = jax.pmap(generate, "batch")
p_get_images = jax.pmap(get_images, "batch")

# Copy the parameters of both models for use with the pmapped functions.
bart_params = replicate(model.params)
vqgan_params = replicate(vqgan.params)

prompts = [
    "man in blue jacket walking on pathway in between trees during daytime",
    'white snow covered mountain under blue sky during daytime',
    'white snow covered mountain under blue sky during night',
    "orange tabby cat on persons hand",
    "aerial view of beach during daytime",
    "chess pieces on chess board",
    "laptop on brown wooden table",
    "white bus on road near high rise buildings",
]


# Repeat the last prompt eight times and shard the tokenized batch.
prompt = [prompts[-1]] * 8
inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
inputs = shard(inputs)

# The original notebook cell was timed with the `%%time` cell magic, which is
# not valid syntax in a plain Python file.
for _ in range(8):
    # Integer bound: a float such as 1e7 is rejected by random.randint on
    # Python 3.12+ (and deprecated since 3.10).
    key = random.randint(0, 10**7)
    rng = jax.random.PRNGKey(key)
    rngs = jax.random.split(rng, jax.local_device_count())
    indices = p_generate(inputs, rngs, bart_params).sequences
    # Drop the first token of every generated sequence.
    indices = indices[:, :, 1:]

    images = p_get_images(indices, vqgan_params)
    images = np.squeeze(np.asarray(images), 1)
    pil_images = [custom_to_pil(image) for image in images]

    plt.figure(figsize=(40, 20))
    plt.imshow(stack_reconstructions(pil_images))

# Trailing notebook metadata (truncated in this extract): one empty code cell,
# then kernelspec display_name "Python 3", language "python".