In [1]:
import os
import yaml
import time
import torch
import matplotlib.pyplot as plt
from vqvae import VQBASE
import webdataset as wds
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import make_grid, save_image
device = "cuda"

# Load configuration for the model
with open("make_a_scene/img_config.yaml", 'r') as file:
 params = yaml.safe_load(file)["model"]
 del params["_target_"]


# Initialize and load the second model in bfloat16
vq_vae = VQBASE(**params).to(device)
vq_vae.load_state_dict(torch.load("make_a_scene/checkpoint_63.0.pt", map_location=device)["model"])
vq_vae = vq_vae.to(dtype=torch.bfloat16)
vq_vae.eval().requires_grad_(False)


VQBASE(
 (encoder): Encoder(
 (model): Sequential(
 (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ResnetBlock(
 (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
 (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
 (dropout): Dropout(p=0.0, inplace=False)
 (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (2): ResnetBlock(
 (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
 (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
 (dropout): Dropout(p=0.0, inplace=False)
 (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (3): Downsample(
 (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
 )
 (4): ResnetBlock(
 (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
 (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(

In [2]:
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

In [3]:
llama_tokenizer = PreTrainedTokenizerFast.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Number of to-add vocabs = number of codes in codebook of vq-vae
vq_vae.quantize.embedding

Embedding(8192, 256)

In [5]:
llama_tokenizer.add_tokens("<|img_start|>",special_tokens=True)
llama_tokenizer.add_tokens("<|img_end|>",special_tokens=True)
for img_token in range(0, 8192):
 padded_token = f"" # This pads the img_token with zeros to ensure it is 4 digits long.
 llama_tokenizer.add_tokens(padded_token)
llama_tokenizer.get_added_vocab()

{'<|begin_of_text|>': 128000,
 '<|end_of_text|>': 128001,
 '<|reserved_special_token_0|>': 128002,
 '<|reserved_special_token_1|>': 128003,
 '<|reserved_special_token_2|>': 128004,
 '<|reserved_special_token_3|>': 128005,
 '<|start_header_id|>': 128006,
 '<|end_header_id|>': 128007,
 '<|reserved_special_token_4|>': 128008,
 '<|eot_id|>': 128009,
 '<|reserved_special_token_5|>': 128010,
 '<|reserved_special_token_6|>': 128011,
 '<|reserved_special_token_7|>': 128012,
 '<|reserved_special_token_8|>': 128013,
 '<|reserved_special_token_9|>': 128014,
 '<|reserved_special_token_10|>': 128015,
 '<|reserved_special_token_11|>': 128016,
 '<|reserved_special_token_12|>': 128017,
 '<|reserved_special_token_13|>': 128018,
 '<|reserved_special_token_14|>': 128019,
 '<|reserved_special_token_15|>': 128020,
 '<|reserved_special_token_16|>': 128021,
 '<|reserved_special_token_17|>': 128022,
 '<|reserved_special_token_18|>': 128023,
 '<|reserved_special_token_19|>': 128024,
 '<|reserved_special_token_

In [8]:
llama_tokenizer.regist

AssertionError: Key <|img_start|> is not a special token

In [None]:
llama_tokenizer.add_special_tokens({"image start":"<|img_start|>"})
llama_tokenizer.add_special_tokens("<|img_end|>")

In [55]:
# We need to pad the beginning position of vqvae since discrete token is range(0,8192)
pad_idx_vqvae = llama_tokenizer.vocab['']

## Process some sample dataset from LAION

Our target here is to generate a simple text only dataset that is suitable to be used by traditional tokenizers, but also suitable for representing the images

In [56]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import webdataset as wds

def process_data(data):
 pretransforms = A.Compose([
 A.SmallestMaxSize(512),
 A.CenterCrop(512, 512, always_apply=True),
 ToTensorV2()
 ])
 data["jpg"] = pretransforms(image=data["jpg"])["image"]
 # Convert image to bfloat16
 data["jpg"] = data["jpg"].to(torch.bfloat16)
 return data

url = "file:make_a_scene/00000.tar"
dataset = wds.WebDataset(url).decode("rgb").map(process_data).to_tuple("jpg", "txt")

def collate(batch):
 images = torch.stack([i[0] for i in batch], dim=0)
 captions = [i[1] for i in batch]
 return [images, captions]

loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collate)


In [67]:
image_text_dataset = [] 

In [69]:
counter = 0
for image,description in dataset:
 counter += 1
 if counter > 10:
 break
 tokenized_image_text = []

 # Get the tokens for the image part, do not forget to pad the position of codebook
 discrete_tokens_padded = list(vq_vae.encode(image.to(device).unsqueeze(0))[2] + pad_idx_vqvae)

 # Get the tokens of the image description
 describe_text = f"""DESCRIPTION:
 {description}
 IMAGE:
 """

 describe_text_tokens = llama_tokenizer.encode(describe_text)

 pos_img_start = llama_tokenizer.vocab['<|img_start|>']
 pos_img_end = llama_tokenizer.vocab['<|img_end|>']
 # Combine the tokens of image and text
 tokenized_image_text = describe_text_tokens + [pos_img_start] + discrete_tokens_padded + [pos_img_end]

 # Reconstruct the text
 recontructed_text = llama_tokenizer.decode(tokenized_image_text) 
 image_text_dataset.append(recontructed_text)

In [74]:
print(image_text_dataset[0])

<|begin_of_text|>DESCRIPTION:
 No Choc Easter Gifts for Babies First Easter Shoes
 IMAGE:
 <|img_start|><|img_end|>


In [75]:
llama_tokenizer.tokenize(image_text_dataset[0])

['<|begin_of_text|>',
 'DESCRIPTION',
 ':Ċ',
 'ĠĠĠ',
 'ĠNo',
 'ĠCh',
 'oc',
 'ĠEaster',
 'ĠGifts',
 'Ġfor',
 'ĠBabies',
 'ĠFirst',
 'ĠEaster',
 'ĠShoes',
 'Ċ',
 'ĠĠĠ',
 'ĠIMAGE',
 ':Ċ',
 'ĠĠĠĠ',
 '<|img_start|>',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '