In [None]:
# One-time environment setup (kept commented out so a re-run does not reinstall).
# The clone is needed because a later cell changes into ~/vqgan-jax and imports
# `modeling_flax_vqgan` from there (presumably provided by this repository - confirm).
# !pip install flax transformers
# !git clone https://github.com/patil-suraj/vqgan-jax.git

In [305]:
%cd ~/vqgan-jax

import random


import jax
import flax.linen as nn
from flax.training.common_utils import shard
from flax.jax_utils import replicate, unreplicate

from transformers.models.bart.modeling_flax_bart import *
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

import io

import requests
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode


from modeling_flax_vqgan import VQModel

jax.devices()

/home/surajpatil/vqgan-jax


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [2]:
# TODO: set those args in a config file
BASE_MODEL = "facebook/bart-large"

# The bos token takes the first id after the 16384 encoded image tokens.
BOS_TOKEN_ID = 16384
OUTPUT_VOCAB_SIZE = BOS_TOKEN_ID + 1  # encoded image token space + 1 for bos
OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos

In [3]:
class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder/decoder pair whose decoder has its own embedding table and config.

    The encoder uses the ``shared`` embedding sized by ``config.vocab_size``. The decoder
    embeds ``OUTPUT_VOCAB_SIZE`` tokens and is configured for ``OUTPUT_LENGTH`` positions.
    """

    def setup(self):
        """Create the two embedding tables, the encoder and the decoder."""
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            OUTPUT_VOCAB_SIZE,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

        # the decoder has a different config
        # Copy the config through from_dict so every entry of the dict becomes a field of
        # the new config. Passing the dict positionally to BartConfig(...) would bind the
        # whole dict to the constructor's first parameter and leave all other fields at
        # their library defaults instead of the values in self.config.
        decoder_config = BartConfig.from_dict(self.config.to_dict())
        decoder_config.max_position_embeddings = OUTPUT_LENGTH
        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
        self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)

class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Conditional-generation module whose output head covers ``OUTPUT_VOCAB_SIZE`` tokens."""

    def setup(self):
        """Create the custom BART module, the LM head and the final logits bias."""
        head_width = OUTPUT_VOCAB_SIZE
        head_init = jax.nn.initializers.normal(self.config.init_std, self.dtype)

        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        # Bias-free projection onto the output vocabulary.
        self.lm_head = nn.Dense(head_width, use_bias=False, dtype=self.dtype, kernel_init=head_init)
        # One bias value per output token, stored as a (1, vocab) parameter.
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, head_width))

class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """FlaxBartForConditionalGeneration variant that uses the custom generation module."""

    # Swap in the module with the separate decoder vocabulary defined in this notebook.
    module_class = CustomFlaxBartForConditionalGenerationModule

In [None]:
# Download the trained BART checkpoint from Weights & Biases.
# NOTE(review): wandb.init() starts a run and may need credentials - confirm before sharing.
import wandb
run = wandb.init()
# The artifact name and version (v7) are pinned here; change this line to use another checkpoint.
artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:v7', type='bart_model')
# artifact_dir is read by the model-loading cell below.
artifact_dir = artifact.download()

In [164]:
# create our model and initialize it randomly
tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
model.config.force_bos_token_to_be_generated = False
model.config.forced_bos_token_id = None
model.config.forced_eos_token_id = None

# we verify that the shape has not been modified
model.params['final_logits_bias'].shape



(1, 16385)

In [6]:
# Load the VQGAN used by get_images() to decode generated code indices into images.
vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=304307206.0, style=ProgressStyle(descri…


Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


In [295]:
def custom_to_pil(x):
    """Turn an array of values in [0, 1] into an RGB PIL image.

    Values outside [0, 1] are clipped before being scaled to 0-255 and cast to uint8.

    Args:
        x: array accepted by ``np.clip``; assumed to be image-shaped (e.g. H x W x 3) -
            TODO confirm against the caller.

    Returns:
        A ``PIL.Image.Image`` whose mode is "RGB".
    """
    clipped = np.clip(x, 0., 1.)
    pixels = (255 * clipped).astype(np.uint8)
    picture = Image.fromarray(pixels)
    # Convert only when PIL picked a different mode for the array.
    if picture.mode != "RGB":
        picture = picture.convert("RGB")
    return picture

def generate(input, rng, params):
    """Sample image-token sequences from the BART model for a batch of prompts.

    Args:
        input: mapping of keyword arguments that is unpacked into ``model.generate``
            (the tokenizer output in this notebook).
        rng: PRNG key forwarded as ``prng_key``.
        params: model parameters forwarded as ``params``.

    Returns:
        The object returned by ``model.generate``.
    """
    # 50000 is larger than OUTPUT_VOCAB_SIZE, so it is not an id the decoder can emit.
    # NOTE(review): presumably chosen so sampling never stops early on eos - confirm.
    unreachable_token_id = 50000
    return model.generate(
        **input,
        # Same value as the previous literal 257: bos plus the encoded image tokens.
        max_length=OUTPUT_LENGTH,
        num_beams=1,
        do_sample=True,
        prng_key=rng,
        eos_token_id=unreachable_token_id,
        pad_token_id=unreachable_token_id,
        params=params,
    )

def get_images(indices, params):
    """Decode VQGAN code indices into images.

    Args:
        indices: code indices passed to ``vqgan.decode_code``.
        params: VQGAN parameters passed as ``params``.

    Returns:
        The result of ``vqgan.decode_code``.
    """
    decoded = vqgan.decode_code(indices, params=params)
    return decoded


def plot_images(images, columns=4, rows=2):
    """Show images in a ``rows`` x ``columns`` grid with the y axes hidden.

    Args:
        images: sequence of objects accepted by ``imshow``.
        columns: number of grid columns (default 4, as before).
        rows: number of grid rows (default 2, as before).

    At most ``rows * columns`` images are drawn. When fewer images are given, only
    those are drawn instead of raising ``IndexError``.
    """
    fig = plt.figure(figsize=(40, 20))
    fig.subplots_adjust(hspace=0, wspace=0)

    # Never index past the end of `images`, and never past the grid capacity.
    n_shown = min(len(images), columns * rows)
    for position in range(1, n_shown + 1):  # subplot positions are 1-based
        ax = fig.add_subplot(rows, columns, position)
        ax.imshow(images[position - 1])
        ax.get_yaxis().set_visible(False)
    plt.show()
 
def stack_reconstructions(images):
    """Paste the images side by side, left to right, onto one RGB canvas.

    The canvas is sized from the first image: its width times the number of images,
    and its height. Every image is pasted at a multiple of that width.

    Args:
        images: non-empty sequence of PIL images.

    Returns:
        A new RGB ``PIL.Image.Image`` holding all images in a single row.
    """
    first = images[0]
    tile_w, tile_h = first.size[0], first.size[1]
    canvas = Image.new("RGB", (len(images) * tile_w, tile_h))

    x_offset = 0
    for tile in images:
        canvas.paste(tile, (x_offset, 0))
        x_offset += tile_w
    return canvas

In [166]:
# Parallel versions of the two helpers; both map over an axis named "batch".
p_generate = jax.pmap(generate, axis_name="batch")
p_get_images = jax.pmap(get_images, axis_name="batch")

In [None]:
# Replicate both parameter sets with flax.jax_utils.replicate so they can be passed to
# the pmapped functions (presumably one copy per local device - confirm).
bart_params = replicate(model.params)
vqgan_params = replicate(vqgan.params)

In [328]:
prompts = [
    "man in blue jacket walking on pathway in between trees during daytime",
    "white snow covered mountain under blue sky during daytime",
    "white snow covered mountain under blue sky during night",
    "orange tabby cat on persons hand",
    "aerial view of beach during daytime",
    "chess pieces on chess board",
    "laptop on brown wooden table",
    "white bus on road near high rise buildings",
]

# Repeat the chosen prompt once per local device so that shard() can split the batch
# evenly. A hard-coded 8 only works on an 8-device host.
prompt = [prompts[-1]] * jax.local_device_count()
# Pad/truncate every prompt to 128 tokens and keep the plain dict of arrays.
inputs = tokenizer(prompt, return_tensors="jax", padding="max_length", truncation=True, max_length=128).data
inputs = shard(inputs)

In [None]:
%%time
# Run 8 sampling rounds; each round draws a fresh seed and shows one row of images.
for i in range(8):
    # randint needs integer bounds: the float 1e7 is deprecated as a bound since
    # Python 3.10 and raises TypeError on 3.12+.
    key = random.randint(0, 10**7)
    rng = jax.random.PRNGKey(key)
    # One PRNG key per local device for the pmapped generate().
    rngs = jax.random.split(rng, jax.local_device_count())
    indices = p_generate(inputs, rngs, bart_params).sequences
    # Drop the first generated position (the bos token) before decoding.
    indices = indices[:, :, 1:]

    images = p_get_images(indices, vqgan_params)
    # Remove the size-1 axis at position 1 so that iterating yields one image per entry.
    images = np.squeeze(np.asarray(images), 1)
    imges = [custom_to_pil(image) for image in images]

    plt.figure(figsize=(40, 20))
    plt.imshow(stack_reconstructions(imges))