{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "from transformers import AutoTokenizer\n", "from tqdm import tqdm\n", "import json\n", "\n", "def load_and_process_token_file(input_path, tokenizer_name=\"answerdotai/ModernBERT-base\"):\n", " captions_dict = defaultdict(list)\n", " tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n", " max_length = 0 # Initialize max length counter\n", "\n", " # Read and process the token file with tokenization\n", " with open(input_path, 'r') as file:\n", " for line in tqdm(file, desc=\"Processing Captions\"):\n", " image_id, caption = line.strip().split('\\t')\n", " jpg_number = image_id.split('.')[0]\n", " \n", " # Tokenize without padding and truncation to calculate the true length\n", " tokens = tokenizer(caption, return_tensors=\"pt\", padding=False, truncation=False)\n", " token_ids = tokens['input_ids'].squeeze(0).tolist()\n", " \n", " # Update max_length based on this tokenized sequence length\n", " max_length = max(max_length, len(token_ids))\n", " \n", " # Tokenize with padding and attention mask (padded to 93 tokens)\n", " tokens_padded = tokenizer(caption, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=2**7) # 93 < 2**7\n", " token_ids_padded = tokens_padded['input_ids'].squeeze(0).tolist()\n", " attention_mask = tokens_padded['attention_mask'].squeeze(0).tolist()\n", "\n", " # Save both raw caption, tokenized version, and attention mask\n", " captions_dict[jpg_number].append({\n", " \"text\": caption,\n", " \"tokenized\": token_ids_padded,\n", " \"attention_mask\": attention_mask\n", " })\n", "\n", " print(f\"Maximum sequence length (before padding): {max_length}\")\n", " return captions_dict, max_length\n", "\n", "# Define the input path and process the file\n", "input_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/results_20130124.token'\n", "captions_dict, max_length = load_and_process_token_file(input_path)\n", "\n", "# Save the modified dictionary with tokenized captions and attention masks to a JSON file\n", "output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'\n", "with open(output_path, 'w') as json_file:\n", " json.dump(captions_dict, json_file)\n", "\n", "# Display the maximum token length\n", "print(f\"Final maximum token length across dataset: {max_length}\")\n", "\n", "# Display the first few entries to verify the content\n", "for jpg, captions in list(captions_dict.items())[:5]:\n", " print(f\"{jpg}: {captions}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# Save the dictionary to a JSON file\n", "output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_dict.json'\n", "with open(output_path, 'w') as json_file:\n", " json.dump(captions_dict, json_file)\n", "\n", "print(f\"Captions dictionary saved to {output_path}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "import os\n", "import json\n", "import numpy as np\n", "import random\n", "\n", "\n", "# Vision Caption Dataset\n", "class VisionCaptionDataset(torch.utils.data.Dataset):\n", " def __init__(self, captions_path, embeddings_dir, normalize=True):\n", " with open(captions_path, 'r') as f:\n", " self.captions_dict = 
json.load(f)\n", "\n", " self.embeddings_dir = embeddings_dir\n", " self.image_ids = list(self.captions_dict.keys())\n", " self.normalize = normalize\n", "\n", " def __len__(self):\n", " return len(self.image_ids)\n", "\n", " def __getitem__(self, idx):\n", " image_id = self.image_ids[idx]\n", " \n", " # Randomly select a caption and load the tokenized version\n", " caption_entry = random.choice(self.captions_dict[image_id])\n", " tokenized_caption = caption_entry[\"tokenized\"]\n", " attention_mask = caption_entry[\"attention_mask\"]\n", "\n", " # Load vision embedding\n", " embedding_path = os.path.join(self.embeddings_dir, f\"{image_id}.npy\")\n", " embedding = np.load(embedding_path)\n", "\n", " # Convert vision embedding and tokenized caption to tensors\n", " embedding = torch.tensor(embedding, dtype=torch.float32)\n", " tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)\n", " attention_mask = torch.tensor(attention_mask, dtype=torch.long)\n", "\n", " return embedding, tokenized_caption, attention_mask\n", "\n", "# Example usage\n", "# Paths for dataset\n", "captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'\n", "embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'\n", "\n", "# Initialize the dataset and split it into train/validation sets\n", "full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)\n", "\n", "# Initialize the DataLoaders with `num_workers` and `pin_memory`\n", "train_dataloader = DataLoader(full_dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Verify a batch\n", "for batch in train_dataloader:\n", " embeddings, captions, attn_mask = batch\n", " print(embeddings.shape, len(captions))\n", " \n", "\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "hf-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }