In [None]:
from collections import defaultdict
from transformers import AutoTokenizer
from tqdm import tqdm
import json

def load_and_process_token_file(input_path, tokenizer_name="answerdotai/ModernBERT-base", pad_length=2**7):
    """Tokenize every caption of a tab-separated caption file.

    Each non-blank line of ``input_path`` must consist of an image id, a tab
    character, and the caption text. Captions are grouped under the part of
    the image id that precedes its first ``.``.

    Args:
        input_path: Path of the caption file; it is read as UTF-8.
        tokenizer_name: Identifier handed to ``AutoTokenizer.from_pretrained``.
        pad_length: Fixed length every stored sequence is padded or truncated
            to. Defaults to 128, the value that used to be hard-coded.

    Returns:
        A ``(captions_dict, max_length)`` tuple. ``captions_dict`` is a
        ``defaultdict(list)`` mapping each image key to a list of dicts with
        the keys ``"text"``, ``"tokenized"`` and ``"attention_mask"``.
        ``max_length`` is the largest untruncated, unpadded token count seen
        (0 when the file holds no captions).

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    captions_dict = defaultdict(list)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    max_length = 0  # longest untruncated sequence seen so far

    # Explicit encoding so the result does not depend on the machine's locale.
    with open(input_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(tqdm(file, desc="Processing Captions"), start=1):
            record = line.strip()
            if not record:
                # Blank lines (e.g. a trailing newline at EOF) carry no caption.
                continue

            # Split on the FIRST tab only, so a caption that itself contains a
            # tab is kept whole instead of breaking the unpacking.
            image_id, separator, caption = record.partition('\t')
            if not separator:
                raise ValueError(
                    f"{input_path}:{line_number}: expected '<image_id>\\t<caption>', got {record!r}"
                )
            jpg_number = image_id.split('.')[0]

            # Untruncated, unpadded pass: only its length is needed. Plain
            # Python lists are requested (no return_tensors) because nothing
            # here needs tensors and the result is JSON-serialised later.
            token_ids = tokenizer(caption, padding=False, truncation=False)['input_ids']
            max_length = max(max_length, len(token_ids))

            # Fixed-length pass: pad or truncate to exactly `pad_length` tokens.
            tokens_padded = tokenizer(
                caption,
                padding="max_length",
                truncation=True,
                max_length=pad_length,
            )

            # Keep the raw caption next to its token ids and attention mask.
            captions_dict[jpg_number].append({
                "text": caption,
                "tokenized": list(tokens_padded['input_ids']),
                "attention_mask": list(tokens_padded['attention_mask']),
            })

    print(f"Maximum sequence length (before padding): {max_length}")
    return captions_dict, max_length

# Run the tokenization over the caption file at this location
input_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/results_20130124.token'
captions_dict, max_length = load_and_process_token_file(input_path)

# Persist the per-image caption entries (text, token ids, attention mask) as JSON
output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
with open(output_path, mode='w') as sink:
    json.dump(captions_dict, fp=sink)

# Report the longest untruncated caption, as returned by the loader
print(f"Final maximum token length across dataset: {max_length}")

# Sanity check: print at most the first five images together with their entries
for position, (jpg, captions) in enumerate(captions_dict.items(), start=1):
    print(f"{jpg}: {captions}")
    if position == 5:
        break

In [None]:

# Write whatever `captions_dict` currently holds in the kernel to a second JSON file.
# NOTE(review): when this runs right after the tokenization cell, it stores the
# same tokenized entries that cell wrote to captions_tokenized.json — confirm
# that this duplicate file is intended.
output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_dict.json'
with open(output_path, mode='w') as sink:
    json.dump(captions_dict, fp=sink)

print(f"Captions dictionary saved to {output_path}")

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import random


# Vision Caption Dataset
class VisionCaptionDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(embedding, token_ids, attention_mask)`` triples.

    One item corresponds to one key of the captions JSON file. The embedding
    is loaded from ``<embeddings_dir>/<key>.npy`` and the caption is drawn at
    random from the entries stored under that key, so repeated reads of the
    same index may return different captions.
    """

    def __init__(self, captions_path, embeddings_dir, normalize=True):
        """Load the captions JSON and remember where the embeddings live.

        Args:
            captions_path: JSON file mapping each image key to a list of
                entries that have ``"tokenized"`` and ``"attention_mask"`` keys.
            embeddings_dir: Directory containing one ``<key>.npy`` per image.
            normalize: Stored on the instance as ``self.normalize``; no method
                of this class reads it.
        """
        self.embeddings_dir = embeddings_dir
        self.normalize = normalize

        with open(captions_path, 'r') as handle:
            self.captions_dict = json.load(handle)

        # Freeze the key order once so an index always maps to the same image.
        self.image_ids = [key for key in self.captions_dict.keys()]

    def __len__(self):
        """Return the number of image keys."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return the embedding plus one randomly picked caption for ``idx``.

        Returns:
            A tuple of three tensors: the embedding as ``float32``, the
            caption token ids as ``long`` and the attention mask as ``long``.
        """
        key = self.image_ids[idx]

        # Pick one of the stored captions for this image at random.
        chosen = random.choice(self.captions_dict[key])
        token_ids = chosen["tokenized"]
        mask = chosen["attention_mask"]

        # The vision embedding lives in a .npy file named after the image key.
        vision_array = np.load(os.path.join(self.embeddings_dir, f"{key}.npy"))

        return (
            torch.tensor(vision_array, dtype=torch.float32),
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(mask, dtype=torch.long),
        )

# Example usage
# Locations of the tokenized captions JSON and of the per-image .npy embeddings
captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'

# Build a single dataset over every image in the captions file
# (no train/validation split is performed here)
full_dataset = VisionCaptionDataset(
    captions_path=captions_path,
    embeddings_dir=embeddings_dir,
)

# Shuffled loader over the full dataset, using worker processes and pinned memory
train_dataloader = DataLoader(
    dataset=full_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)


In [None]:
# Smoke-test the loader: pull one batch, report its embedding shape and caption count
for batch in train_dataloader:
    embeddings, captions, attn_mask = batch
    caption_rows = len(captions)
    print(embeddings.shape, caption_rows)
    break  # a single batch is enough for this check