File size: 67,053 Bytes
d2b2a20 |
1 |
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!kaggle datasets download -d mnkbiswas/pixmo-points-50k\n!unzip -q pixmo-points-50k.zip\n!rm pixmo-points-50k.zip","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport torchvision.transforms as transforms\nfrom PIL import Image\nimport pickle\nfrom tqdm.auto import tqdm\nfrom torch.utils.data import Dataset, DataLoader\nimport numpy as np\nfrom transformers import AutoTokenizer\nfrom transformers.image_utils import load_image\nimport os\nimport wandb\nfrom datetime import datetime\nfrom torch.optim.lr_scheduler import OneCycleLR\n\n\n# --- Constants (Copied from your original code for completeness) ---\n# Model architecture constants\nIMAGE_SIZE = 512\nPATCH_SIZE = 16\nHIDDEN_DIM = 256 # smolvlm has 576 -> Keep this consistent for text and projected image\nCONTEXT_LENGTH = 1536 # Max combined length: (512/16)**2 + 1 (CLS) + 512 = 1024 + 1 + 512 = 1537. Adjust if needed. 
Let's keep 1536 for now.\nTEXT_LENGTH = 512 # Max *text* length\nDROPOUT = 0.1 # Slightly increased dropout might help regularization\nNUM_HEADS = 8 # Reduced heads for smaller HIDDEN_DIM might be more stable\nNUM_LAYERS = 12 # Reduced layers might be easier to train initially\n\n# Training constants\nBATCH_SIZE = 4 # Increased batch size slightly if GPU memory allows\nLEARNING_RATE = 3e-4 # Common starting point for transformers\nDTYPE = torch.float32 #torch.bfloat16\nGRAD_ACCUMULATION_STEPS = 8 # Adjusted accumulation\n# Image normalization constants\nIMAGE_MEAN = [0.485, 0.456, 0.406]\nIMAGE_STD = [0.229, 0.224, 0.225]\n\nIMAGE_SIZE = 512\nPATCH_SIZE = 16\nHIDDEN_DIM = 256\nCONTEXT_LENGTH = 1536\nTEXT_LENGTH = 512 # Max length for *target* sequence (coords)\nPROMPT_LENGTH = 64 # Max length for *prompt* sequence (description) - Adjust as needed\nDROPOUT = 0.1\nNUM_HEADS = 8\nNUM_LAYERS = 12 # Keep moderate layers\nBATCH_SIZE = 2\nLEARNING_RATE = 1e-3 # Lower LR might be needed with contrastive loss\nDTYPE = torch.float32\nGRAD_ACCUMULATION_STEPS = 8\nIMAGE_MEAN = [0.485, 0.456, 0.406]\nIMAGE_STD = [0.229, 0.224, 0.225]\nDEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\nIMAGE_LOCATION = \"./images/\"\nNUM_BINS = 32\nSHARED_EMBED_DIM = 256 # Dimension for contrastive space\nLAMBDA_CONTRASTIVE = 2 # Weight for contrastive loss - TUNE THIS\nLAMBDA_REGRESSION = 2\n\n# Device configuration\nDEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n# Image storage location\nIMAGE_LOCATION = \"./images/\"\nNUM_BINS = 32\n\n# --- Tokenizer and Data Loading (Assume unchanged from your code) ---\ndef get_tokenizer():\n tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n point_tokens = [f\"coord_bin_{i}\" for i in range(0, NUM_BINS)]\n new_tokens = [\n \"<point_start>\", \"<point_end>\", \"<result_start>\",\n \"<result_end>\", \"<pointx_start>\", \"<pointx_end>\",\n \"<pointy_start>\", \"<pointy_end>\", \"<img_embed>\", # We might not explicitly use <img_embed> 
token if we prepend\n *point_tokens\n ]\n tokenizer.add_tokens(new_tokens)\n # Ensure pad token is set (GPT2 usually doesn't have one by default)\n if tokenizer.pad_token is None:\n tokenizer.add_special_tokens({'pad_token': '[PAD]'}) # Or use eos_token if preferred\n # tokenizer.pad_token_id = tokenizer.eos_token_id # Alternative if you want padding to be EOS\n\n # Resize model embeddings if pad token was added\n # This step should be done *after* defining the model that uses the tokenizer\n # model.resize_token_embeddings(len(tokenizer))\n\n print(f\"Tokenizer pad token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}\")\n print(f\"Tokenizer EOS token: {tokenizer.eos_token}, ID: {tokenizer.eos_token_id}\")\n\n # Check if pad token ID is valid\n if tokenizer.pad_token_id is None:\n raise ValueError(\"Tokenizer pad token ID is not set!\")\n\n return tokenizer, len(tokenizer)\n\ndef image_to_tensor(image, image_size=IMAGE_SIZE):\n if image.mode != 'RGB':\n image = image.convert('RGB')\n transform = transforms.Compose([\n transforms.Resize((image_size, image_size)),\n transforms.ToTensor(),\n transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD)\n ])\n return transform(image)\n\ndef tensor_to_image(tensor):\n tensor = tensor.clone().detach()\n if tensor.is_cuda:\n tensor = tensor.cpu()\n mean = torch.tensor(IMAGE_MEAN).view(3, 1, 1)\n std = torch.tensor(IMAGE_STD).view(3, 1, 1)\n tensor = tensor * std + mean\n tensor = torch.clamp(tensor, 0, 1)\n image_np = tensor.numpy().transpose(1, 2, 0)\n image_np = (image_np * 255).astype(np.uint8)\n return Image.fromarray(image_np)\n","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def format_point_text(points):\n # ... 
(returns text string like <result_start>...<eos>) ...\n text = \"<result_start>\"\n for point in points:\n px = min(int(point['x'] * IMAGE_SIZE / 100), IMAGE_SIZE - 1)\n py = min(int(point['y'] * IMAGE_SIZE / 100), IMAGE_SIZE - 1)\n x_bin = min(px // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)\n y_bin = min(py // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)\n text += f\"<pointx_start><coord_bin_{x_bin}><pointx_end><pointy_start><coord_bin_{y_bin}><pointy_end>\"\n text += \"<result_end>\" + tokenizer.eos_token\n return text\n\ndef format_data_for_training(sample):\n \"\"\"Format data sample for training, adding continuous coordinates.\"\"\"\n try:\n image = Image.open(f\"{IMAGE_LOCATION}{sample['image_url']}\")\n image_tensor = image_to_tensor(image)\n\n prompt_text = f\"<point_start>{sample['label']}<point_end>\"\n target_text = format_point_text(sample['points'])\n\n prompt_tokens = tokenizer(prompt_text, return_tensors=\"pt\", max_length=PROMPT_LENGTH, truncation=True, padding=False)\n target_tokens = tokenizer(target_text, return_tensors=\"pt\", max_length=TEXT_LENGTH, truncation=True, padding=False)\n\n if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:\n return None\n\n # --- Add Continuous Coordinates ---\n # Assuming only one point for now, scale to [0, 1] range\n if len(sample['points']) == 1:\n point = sample['points'][0]\n # Scale coordinates to [0, 1] range relative to image size\n # Avoid dividing by zero if IMAGE_SIZE is not set\n img_w, img_h = IMAGE_SIZE, IMAGE_SIZE # Use constant size\n coord_x = min(max(point['x'] / 100.0, 0.0), 1.0) # Assuming points['x'] is percentage\n coord_y = min(max(point['y'] / 100.0, 0.0), 1.0)\n continuous_coords = torch.tensor([coord_x, coord_y], dtype=torch.float32)\n else:\n # Handle cases with zero or multiple points if necessary\n # For now, return None or a placeholder if not exactly one point\n print(f\"Warning: Skipping sample with {len(sample['points'])} points (expected 1). 
URL: {sample.get('image_url', 'N/A')}\")\n # If you want to handle multiple points, continuous_coords should be (N, 2)\n # and collate_fn needs modification.\n return None # Simplest for now: only train on single points\n\n return {\n \"image\": image_tensor,\n \"prompt_ids\": prompt_tokens.input_ids[0],\n \"target_ids\": target_tokens.input_ids[0],\n \"continuous_coords\": continuous_coords, # Add scaled coords (x, y)\n \"label\": sample['label'],\n \"image_url\": sample['image_url']\n }\n except FileNotFoundError:\n print(f\"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.\")\n return None\n except Exception as e:\n print(f\"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.\")\n return None\n\n\nclass PointDataset(Dataset):\n # __init__ remains largely the same, just needs to handle None return from format_data\n def __init__(self, data_path=\"active_point_dataset.pkl\", split=\"train\", test_size=1000):\n # ... (loading raw data, filtering for 1 point samples) ...\n with open(data_path, \"rb\") as f:\n raw_data = pickle.load(f)\n # Keep only single-point samples for simplicity with regression target\n raw_data = [sample for sample in raw_data if len(sample['points']) == 1]\n # ... 
(train/test split logic) ...\n total_samples = len(raw_data)\n if total_samples <= test_size:\n print(f\"Warning: Dataset size {total_samples} <= test_size {test_size}.\")\n test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0\n train_end = total_samples - test_size\n print(f\"Dataset: {total_samples} total, {train_end} train, {test_size} test (single point only)\")\n\n if split == \"train\": self.raw_data = raw_data[:train_end]\n elif split == \"test\": self.raw_data = raw_data[train_end:]\n else: raise ValueError(\"split must be 'train' or 'test'\")\n # Add small subset for debugging if needed\n # self.raw_data = self.raw_data[:100]\n\n print(f\"Loading {len(self.raw_data)} raw samples for {split}...\")\n self.data = []\n for sample in tqdm(self.raw_data, desc=f\"Processing {split} data\"):\n formatted = format_data_for_training(sample) # Handles exceptions inside\n if formatted is not None:\n self.data.append(formatted)\n\n print(f\"Successfully loaded {len(self.data)} samples for {split} set\")\n if len(self.data) == 0 and len(self.raw_data) > 0:\n print(\"ERROR: No samples loaded. Check paths/formatting/filtering.\")\n\n def __len__(self):\n return len(self.data)\n\n def __getitem__(self, idx):\n return self.data[idx]\n\n # Collate function needs to handle 'continuous_coords'\n @staticmethod\n def collate_fn(batch):\n batch = [item for item in batch if item is not None]\n if not batch: return None\n\n images = torch.stack([item['image'] for item in batch]).to(DTYPE)\n\n # --- Pad Prompt IDs ---\n # ... 
(same as before) ...\n max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)\n prompt_ids_padded, prompt_attention_mask = [], []\n for item in batch:\n ids, pad_len = item['prompt_ids'], max_prompt_len - item['prompt_ids'].size(0)\n prompt_ids_padded.append(torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))\n prompt_attention_mask.append(torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)]))\n prompt_ids = torch.stack(prompt_ids_padded)\n prompt_attention_mask = torch.stack(prompt_attention_mask)\n\n # --- Pad Target IDs & Create Generative Targets ---\n # ... (same as before) ...\n max_target_len = max(item['target_ids'].size(0) for item in batch)\n target_ids_padded, target_attention_mask, generative_targets = [], [], []\n for item in batch:\n ids, pad_len = item['target_ids'], max_target_len - item['target_ids'].size(0)\n padded_ids = torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)])\n target_ids_padded.append(padded_ids)\n mask = torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)])\n target_attention_mask.append(mask)\n targets = torch.full_like(padded_ids, -100)\n targets[:ids.size(0)-1] = ids[1:]\n if ids.numel() > 0 and ids[-1] == tokenizer.eos_token_id: targets[ids.size(0)-1] = tokenizer.eos_token_id\n generative_targets.append(targets)\n target_ids = torch.stack(target_ids_padded)\n target_attention_mask = torch.stack(target_attention_mask)\n generative_targets = torch.stack(generative_targets)\n\n # --- Stack Continuous Coords ---\n # Shape: (B, 2) where 2 is for (x, y)\n continuous_coords = torch.stack([item['continuous_coords'] for item in batch])\n\n labels = [item['label'] for item in batch]\n image_urls = [item.get('image_url', '') for item in batch]\n\n return {\n 'image': images,\n 'prompt_ids': prompt_ids,\n 'prompt_attention_mask': prompt_attention_mask,\n 'target_ids': target_ids,\n 
'target_attention_mask': target_attention_mask,\n 'generative_targets': generative_targets, # For bin classification loss\n 'continuous_coords': continuous_coords, # For regression loss\n 'label': labels,\n 'image_url': image_urls\n }","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0): # Use 0 workers for debugging\n dataset = PointDataset(split=\"train\")\n if len(dataset) == 0: return None\n return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=PointDataset.collate_fn, pin_memory=True, num_workers=num_workers)\n\ndef create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0):\n dataset = PointDataset(split=\"test\")\n # Allow empty test loader\n return DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=PointDataset.collate_fn, pin_memory=True, num_workers=num_workers)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class PatchEmbeddings(nn.Module):\n def __init__(self, patch_size=PATCH_SIZE, hidden_dim=HIDDEN_DIM):\n super().__init__()\n self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)\n\n def forward(self, X):\n X = self.conv(X) # (B, C, H/P, W/P)\n X = X.flatten(2) # (B, C, N) where N = (H/P)*(W/P)\n X = X.transpose(1, 2) # (B, N, C)\n return X\n\nclass Head(nn.Module):\n def __init__(self, n_embd, head_size, dropout=DROPOUT, is_decoder=False):\n super().__init__()\n self.key = nn.Linear(n_embd, head_size, bias=False)\n self.query = nn.Linear(n_embd, head_size, bias=False)\n self.value = nn.Linear(n_embd, head_size, bias=False)\n self.dropout = nn.Dropout(dropout)\n self.is_decoder = is_decoder\n # causal mask is registered persistent=False so it's not saved in state_dict\n if self.is_decoder:\n self.register_buffer(\"bias\", torch.tril(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH, dtype=torch.bool))\n .view(1, 
CONTEXT_LENGTH, CONTEXT_LENGTH), persistent=False)\n\n\n def forward(self, x, attention_mask=None):\n B, T, C = x.shape\n # print(f\"B = {B} T={T}, C={C}\")\n k = self.key(x) # (B, T, hs)\n q = self.query(x) # (B, T, hs)\n v = self.value(x) # (B, T, hs)\n\n # Compute attention scores (\"affinities\")\n wei = q @ k.transpose(-2, -1) * (k.size(-1)**-0.5) # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n\n if self.is_decoder:\n # Apply causal mask\n # Ensure the mask is sliced correctly if T < CONTEXT_LENGTH\n causal_mask = self.bias[:, :T, :T]\n wei = wei.masked_fill(causal_mask == 0, float('-inf'))\n\n if attention_mask is not None:\n # Apply padding mask (for text tokens)\n # attention_mask shape: (B, T_combined) -> needs expansion\n # Expand mask: (B, T) -> (B, 1, 1, T) or (B, 1, T, T) depending on what needs masking\n # Mask where attention_mask is 0\n # attention_mask shape: (B, T) == (B, T_key)\n # Expand mask to align with wei's key dimension for broadcasting across queries\n # Target shape for mask: [B, 1, T_key]\n # print(f\"attn mask = {attention_mask.shape}\")\n # print(f\"wei shape = {wei.shape}\")\n mask = attention_mask.unsqueeze(1) # Shape [B, 1, T]\n # Apply mask using broadcasting rules. 
masked_fill condition needs to be broadcastable to wei [B, T_query, T_key]\n # (mask == 0) gives a boolean tensor of shape [B, 1, T]\n # This broadcasts correctly: dim 2 (T vs T) matches, dim 1 (1 vs T) broadcasts 1->T, dim 0 (B vs B) matches.\n wei = wei.masked_fill(mask == 0, float('-inf'))\n\n\n # Apply softmax\n wei = F.softmax(wei, dim=-1)\n wei = self.dropout(wei)\n\n # Perform weighted aggregation of values\n out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n # print(f\"out shape = {out.shape}\")\n return out\n\nclass MultiHeadAttention(nn.Module):\n def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):\n super().__init__()\n assert n_embd % num_heads == 0\n head_size = n_embd // num_heads\n self.heads = nn.ModuleList([\n Head(n_embd, head_size, dropout, is_decoder)\n for _ in range(num_heads)\n ])\n self.proj = nn.Linear(n_embd, n_embd) # n_embd = num_heads * head_size\n self.dropout = nn.Dropout(dropout)\n self.is_decoder = is_decoder # Store is_decoder status\n\n def forward(self, x, attention_mask=None):\n # Pass attention_mask only if it's a decoder block dealing with combined sequence\n out = torch.cat([h(x, attention_mask=attention_mask if self.is_decoder else None) for h in self.heads], dim=-1)\n out = self.dropout(self.proj(out))\n return out\n\n\nclass FeedForward(nn.Module):\n \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n def __init__(self, n_embd, dropout=DROPOUT):\n super().__init__()\n self.net = nn.Sequential(\n nn.Linear(n_embd, 4 * n_embd),\n nn.GELU(), # Changed from ReLU to GELU, common in transformers\n nn.Linear(4 * n_embd, n_embd), # Projection back to residual stream\n nn.Dropout(dropout),\n )\n\n def forward(self, x):\n return self.net(x)\n\nclass Block(nn.Module):\n \"\"\" Transformer block: communication followed by computation \"\"\"\n def __init__(self, n_embd, num_heads=NUM_HEADS, dropout=DROPOUT, is_decoder=False):\n super().__init__()\n self.ln1 = nn.LayerNorm(n_embd)\n 
self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)\n self.ln2 = nn.LayerNorm(n_embd)\n self.ffn = FeedForward(n_embd, dropout)\n self.is_decoder = is_decoder # Store is_decoder status\n\n def forward(self, x, attention_mask=None):\n # Pass attention_mask only if it's a decoder block\n # print(f\"is decoder = {self.is_decoder} input shape = {x.shape}\")\n x = x + self.attn(self.ln1(x), attention_mask=attention_mask if self.is_decoder else None)\n x = x + self.ffn(self.ln2(x))\n # print(f\"output shape = {x.shape}\")\n return x\n\nclass ViT(nn.Module):\n def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_hiddens=HIDDEN_DIM,\n num_heads=NUM_HEADS, num_blks=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT):\n super().__init__()\n self.patch_embedding = PatchEmbeddings(patch_size, num_hiddens)\n self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))\n num_patches = (img_size // patch_size) ** 2\n self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens) * 0.02) # Smaller init\n self.dropout = nn.Dropout(emb_dropout)\n # ViT blocks are NOT decoders (no causal mask)\n self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])\n self.layer_norm = nn.LayerNorm(num_hiddens) # Final LN\n\n def forward(self, X):\n x = self.patch_embedding(X) # (B, N, C)\n cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # (B, 1, C)\n x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, C)\n # Add positional embedding\n x = x + self.pos_embedding # Uses broadcasting\n x = self.dropout(x)\n for block in self.blocks:\n # ViT blocks don't need attention_mask\n x = block(x)\n x = self.layer_norm(x) # Apply final layer norm\n return x\n\nclass MultiModalProjector(nn.Module):\n # Projects image embedding dim to text embedding dim\n def __init__(self, image_embed_dim=HIDDEN_DIM, text_embed_dim=HIDDEN_DIM, dropout=DROPOUT):\n super().__init__()\n self.net = nn.Sequential(\n 
nn.Linear(image_embed_dim, text_embed_dim * 4), # Intermediate expansion\n nn.GELU(),\n nn.Linear(text_embed_dim * 4, text_embed_dim),\n nn.Dropout(dropout)\n )\n\n def forward(self, x):\n return self.net(x)\n","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class DecoderLanguageModel(nn.Module):\n \"\"\"\n Transformer Decoder Language Model with optional coordinate regression head.\n\n Processes a combined sequence of embeddings (e.g., image + prompt + target).\n Outputs logits for token prediction (classification) and optionally\n regressed coordinates.\n \"\"\"\n def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,\n n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):\n super().__init__()\n # --- Input Embeddings ---\n self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n self.position_embedding_table = nn.Embedding(max_context, n_embd)\n self.dropout = nn.Dropout(dropout) # Dropout after embeddings + pos enc\n\n # --- Transformer Blocks ---\n # Ensure Block class definition is accessible and includes max_context if needed by Head\n self.blocks = nn.ModuleList([\n Block(n_embd, num_heads, dropout, is_decoder=True, max_context=max_context)\n for _ in range(n_layer)\n ])\n\n # --- Final Layer Norm (applied before output heads) ---\n self.ln_f = nn.LayerNorm(n_embd)\n\n # --- Output Heads ---\n # 1. Head for token classification (predicting token bins)\n self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)\n\n # 2. 
Head for direct coordinate regression (predicting continuous x, y)\n self.regression_head = nn.Sequential(\n nn.Linear(n_embd, n_embd // 2), # Intermediate layer\n nn.GELU(), # Non-linearity\n nn.Linear(n_embd // 2, 2), # Output: 2 values for (x, y)\n nn.Sigmoid() # Output activation to constrain coords to [0, 1]\n )\n # --- End Output Heads ---\n\n # Store config\n self.n_embd = n_embd\n self.max_context = max_context\n\n # Weight tying for LM head\n self.token_embedding_table.weight = self.lm_head.weight\n\n # Initialize weights\n self.apply(self._init_weights)\n print(f\"DecoderLanguageModel initialized with {n_layer} layers.\")\n\n def _init_weights(self, module):\n \"\"\"Initializes weights for linear and embedding layers.\"\"\"\n if isinstance(module, nn.Linear):\n torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n if module.bias is not None:\n torch.nn.init.zeros_(module.bias)\n elif isinstance(module, nn.Embedding):\n torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n elif isinstance(module, nn.LayerNorm):\n # Initialize LayerNorm bias to 0, weight to 1\n torch.nn.init.zeros_(module.bias)\n torch.nn.init.ones_(module.weight)\n\n def forward(self, combined_embeds, attention_mask=None, targets=None):\n \"\"\"\n Forward pass for training or inference where loss is calculated.\n\n Args:\n combined_embeds (torch.Tensor): Input embeddings combined from different\n modalities/sources (e.g., image+prompt+target).\n Shape: (B, T_combined, C)\n attention_mask (torch.Tensor, optional): Mask for padding tokens in the\n combined sequence. Shape: (B, T_combined)\n targets (torch.Tensor, optional): Target token IDs for classification loss,\n shifted left, with padding/ignored sections\n marked by -100. 
Shape: (B, T_combined)\n\n Returns:\n tuple:\n - logits (torch.Tensor): Output logits for token classification.\n Shape: (B, T_combined, VocabSize)\n - class_loss (torch.Tensor | None): Calculated cross-entropy loss for\n token prediction, or None if targets\n are not provided.\n - x_norm (torch.Tensor): The normalized hidden states *before* the\n output heads. Shape: (B, T_combined, C).\n Useful for passing to auxiliary heads (like regression)\n outside this module if needed.\n \"\"\"\n # --- Input Validation & Processing ---\n if combined_embeds.ndim != 3:\n raise ValueError(f\"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}\")\n\n B, T, C = combined_embeds.shape\n\n # Truncate sequence if longer than max context length\n if T > self.max_context:\n print(f\"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.\")\n # Keep the most recent 'max_context' tokens\n combined_embeds = combined_embeds[:, -self.max_context:, :]\n if attention_mask is not None:\n attention_mask = attention_mask[:, -self.max_context:]\n if targets is not None:\n targets = targets[:, -self.max_context:]\n T = self.max_context # Update sequence length\n\n # --- Positional Encoding ---\n # Create position indices: 0, 1, ..., T-1\n pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)\n # Clamp indices to be within the range of the position embedding table\n pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)\n pos_emb = self.position_embedding_table(pos) # Shape: (T, C)\n # Add positional embeddings (broadcasts along batch dim)\n x = combined_embeds + pos_emb.unsqueeze(0)\n x = self.dropout(x) # Apply dropout\n\n # --- Transformer Blocks ---\n # Pass through decoder layers\n for block in self.blocks:\n # Pass attention_mask to handle padding within the combined sequence\n x = block(x, attention_mask=attention_mask) # Output Shape: (B, T, C)\n\n # --- Final Layer Norm ---\n # Apply layer 
normalization before the output heads\n x_norm = self.ln_f(x) # Shape: (B, T, C)\n\n # --- Classification Head Output ---\n # Calculate logits for predicting the next token bin\n logits = self.lm_head(x_norm) # Shape: (B, T, VocabSize)\n\n # --- Classification Loss Calculation ---\n class_loss = None\n if targets is not None:\n # Calculate cross-entropy loss, ignoring padding/masked tokens (-100)\n # Reshape logits and targets for cross_entropy:\n # Logits: (B * T, VocabSize)\n # Targets: (B * T)\n try:\n class_loss = F.cross_entropy(\n logits.view(-1, logits.size(-1)),\n targets.view(-1),\n ignore_index=-100\n )\n # Handle potential NaN loss (e.g., if all targets are ignored)\n if torch.isnan(class_loss):\n print(\"Warning: class_loss is NaN.\")\n class_loss = None # Or set to zero tensor? torch.tensor(0.0, device=DEVICE, requires_grad=True)\n\n except Exception as e:\n print(f\"Error calculating cross_entropy: {e}\")\n print(f\"Logits shape: {logits.shape}, Targets shape: {targets.shape}\")\n # Potentially inspect targets for issues (e.g., all -100?)\n # print(f\"Unique target values: {torch.unique(targets)}\")\n class_loss = None\n\n # Note: Regression output/loss is calculated *outside* this module\n # in the main VisionLanguageModel forward pass, using x_norm.\n return logits, class_loss, x_norm\n\n # --- Generation Method (Example - if needed internally, otherwise VLM handles it) ---\n # If VLM needs this class to perform generation based on token IDs:\n def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n \"\"\"\n Autoregressive generation based on starting token IDs.\n NOTE: This version doesn't handle combined embeddings directly.\n The VisionLanguageModel should ideally use a method like\n generate_from_embeddings or implement the loop externally.\n \"\"\"\n self.eval()\n for _ in range(max_new_tokens):\n # --- Context Management ---\n # Crop idx if longer than context length\n idx_cond = idx if idx.size(1) <= self.max_context else 
idx[:, -self.max_context:]\n\n # --- Forward Pass ---\n # Get embeddings\n tok_embeds = self.token_embedding_table(idx_cond) # (B, T, C)\n # Get positional embeddings\n pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)\n pos = pos.clamp(max=self.max_context - 1)\n pos_emb = self.position_embedding_table(pos).unsqueeze(0) # (1, T, C)\n x = self.dropout(tok_embeds + pos_emb)\n # Pass through blocks (no padding mask needed here as we handle single sequence)\n for block in self.blocks:\n x = block(x, attention_mask=None) # Causal mask is internal to block/head\n # Final layer norm and head for the last token only\n x = self.ln_f(x[:, -1:, :]) # (B, 1, C)\n logits = self.lm_head(x) # (B, 1, V)\n logits = logits.squeeze(1) # (B, V)\n\n # --- Sampling ---\n logits = logits / temperature\n if top_k is not None and top_k > 0:\n v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n logits[logits < v[:, [-1]]] = -float('Inf')\n probs = F.softmax(logits, dim=-1)\n idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n\n # Append sampled token\n idx = torch.cat((idx, idx_next), dim=1)\n\n # Stop if EOS\n if hasattr(tokenizer, 'eos_token_id') and (idx_next == tokenizer.eos_token_id).all():\n break\n self.train()\n return idx","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class VisionLanguageModel(nn.Module):\n \"\"\"\n Vision Language Model integrating:\n - A Vision Transformer (ViT) for image encoding.\n - A multimodal projector for image features.\n - Contrastive alignment loss between image (CLS) and text prompt (last token).\n - A Transformer Decoder for autoregressive generation.\n - Dual output heads in the Decoder:\n - Classification head (lm_head) for predicting token bins.\n - Regression head for predicting continuous [0, 1] coordinates.\n - Combined loss calculation.\n \"\"\"\n def __init__(self,\n n_embd=HIDDEN_DIM,\n vocab_size=vocab_size,\n img_size=IMAGE_SIZE,\n 
patch_size=PATCH_SIZE,\n num_heads=NUM_HEADS,\n num_blks_vit=NUM_LAYERS, # Num layers for ViT\n num_blks_dec=NUM_LAYERS, # Num layers for Decoder\n emb_dropout=DROPOUT,\n blk_dropout=DROPOUT,\n max_context=CONTEXT_LENGTH,\n shared_embed_dim=SHARED_EMBED_DIM,\n lambda_contrastive=LAMBDA_CONTRASTIVE,\n lambda_regression=LAMBDA_REGRESSION\n ):\n super().__init__()\n\n # --- Vision Backbone ---\n self.vision_encoder = ViT(\n img_size=img_size,\n patch_size=patch_size,\n num_hiddens=n_embd, # Assuming ViT output dim matches decoder embed dim\n num_heads=num_heads,\n num_blks=num_blks_vit,\n emb_dropout=emb_dropout,\n blk_dropout=blk_dropout\n )\n\n # --- Multimodal Components ---\n # Projector for adapting ViT patch embeddings for the decoder sequence\n self.multimodal_projector = MultiModalProjector(\n image_embed_dim=n_embd, # Input from ViT\n text_embed_dim=n_embd, # Output matches decoder dim\n dropout=emb_dropout\n )\n # Projection heads for contrastive loss\n self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)\n self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)\n # Learnable temperature for contrastive loss\n self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))\n\n # --- Text Decoder ---\n # DecoderLanguageModel now includes both lm_head and regression_head internally\n self.decoder = DecoderLanguageModel(\n n_embd=n_embd,\n vocab_size=vocab_size,\n num_heads=num_heads,\n n_layer=num_blks_dec,\n max_context=max_context,\n dropout=blk_dropout # Use block dropout for decoder consistency\n )\n\n # --- Store Configuration ---\n self.n_embd = n_embd\n self.vocab_size = vocab_size # Store vocab size for resizing check\n self.num_patches = (img_size // patch_size)**2 + 1 # Including CLS token\n self.lambda_contrastive = lambda_contrastive\n self.lambda_regression = lambda_regression\n\n # Check and potentially resize embeddings after full init\n self._resize_embeddings_if_needed(self.vocab_size)\n 
print(\"VisionLanguageModel initialized.\")\n\n\n    def _resize_embeddings_if_needed(self, current_vocab_size):\n        \"\"\"Resize the decoder token embedding table and LM head to current_vocab_size.\n\n        Rows for token ids present in both the old and the new vocabulary are copied from\n        the old LM head weight, so tokens learned before the resize keep their values instead\n        of being re-initialized. The new layers are created on the old weight's device and\n        dtype, and the embedding table is tied to the new LM head afterwards.\n\n        Note: an optimizer constructed before this call still references the old parameters.\n        \"\"\"\n        decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings\n        if decoder_embedding_size == current_vocab_size:\n            return\n        print(f\"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}\")\n        old_weight = self.decoder.lm_head.weight  # (V_old, n_embd)\n        new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(device=old_weight.device, dtype=old_weight.dtype)\n        new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(device=old_weight.device, dtype=old_weight.dtype)\n        # Carry over the rows shared by both vocabularies instead of discarding them.\n        n_keep = min(decoder_embedding_size, current_vocab_size)\n        with torch.no_grad():\n            new_lm_head.weight[:n_keep].copy_(old_weight[:n_keep])\n        self.decoder.token_embedding_table = new_embedding\n        self.decoder.lm_head = new_lm_head\n        # Re-tie: the embedding table now shares the LM head's Parameter.\n        self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight\n        print(\"VLM decoder embeddings resized and weights retied.\")\n\n\n    def _calculate_contrastive_loss(self, image_features, text_features):\n        \"\"\"Symmetric InfoNCE (CLIP-style) loss between paired image and text features.\n\n        Args:\n            image_features: (B, E) projected image features.\n            text_features: (B, E) projected text features; row i is the positive for image row i.\n\n        Returns:\n            Scalar loss tensor, or None when the loss is NaN.\n        \"\"\"\n        image_features = F.normalize(image_features, dim=-1)\n        text_features = F.normalize(text_features, dim=-1)\n\n        # Cosine similarities scaled by the learnable temperature exp(logit_scale).\n        logit_scale = self.logit_scale.exp()\n        logits_per_image = logit_scale * image_features @ text_features.t()\n        logits_per_text = logits_per_image.t()\n\n        # Matching pairs lie on the diagonal; average the image->text and text->image losses.\n        labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)\n        loss_i = F.cross_entropy(logits_per_image, labels)\n        loss_t = F.cross_entropy(logits_per_text, labels)\n        contrastive_loss = (loss_i + loss_t) / 2.0\n\n        if torch.isnan(contrastive_loss):\n            print(\"Warning: Contrastive loss is NaN.\")\n            return None\n\n        return contrastive_loss\n\n    def forward(self,\n                img_array,                  # (B, 3, H, W)\n                prompt_ids,                 # (B, T_prompt)\n                prompt_attention_mask,      # (B, T_prompt)\n                target_ids,                 # (B, T_target) - Input sequence <result_start>...<eos>\n                target_attention_mask,      # (B, T_target) - Mask for target_ids padding\n                generative_targets=None,    # (B, T_target) - Shifted target tokens for CLASS loss (-100 padded)\n                continuous_coords=None      # (B, 2) - Ground truth [0,1] coords for REGRESSION loss\n                ):\n        \"\"\"\n        Main forward pass for training. Calculates the combined loss.\n\n        Args:\n            img_array: Batch of images.\n            prompt_ids: Batch of prompt token IDs.\n            prompt_attention_mask: Mask for prompt padding (prompts are assumed right-padded).\n            target_ids: Batch of target sequence token IDs (e.g., starting with <result_start>).\n            target_attention_mask: Mask for target_ids padding.\n            generative_targets: Shifted target IDs for cross-entropy loss (-100 ignored).\n            continuous_coords: Ground truth continuous coordinates [0,1] for regression loss.\n\n        Returns:\n            tuple:\n                - logits (torch.Tensor): Logits from the classification head.\n                - regression_output (torch.Tensor | None): Output from the regression head.\n                - total_loss (torch.Tensor): Combined weighted loss (replaced by zero if NaN).\n                - class_loss (torch.Tensor): Classification (cross-entropy) loss, NaN if unavailable.\n                - contrastive_loss (torch.Tensor): Contrastive alignment loss, NaN if unavailable.\n                - regression_loss (torch.Tensor): Coordinate regression (L1) loss, NaN if unavailable.\n        \"\"\"\n        device = img_array.device\n\n        # --- 1. Encode Image ---\n        image_embeds_raw = self.vision_encoder(img_array)  # (B, N_img, C_img); N_img = num_patches + 1 (CLS)\n        B, N_img, _ = image_embeds_raw.shape\n        img_cls_token = image_embeds_raw[:, 0]  # CLS token drives the contrastive path (B, C_img)\n\n        # --- 2. Contrastive Loss Path ---\n        image_features_contrast = self.image_contrastive_head(img_cls_token)  # (B, E)\n        # The token-embedding lookup runs without grad, so the contrastive loss does not update the\n        # shared token embeddings; text_contrastive_head below is outside no_grad and still trains.\n        with torch.no_grad():\n            prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)  # (B, T_prompt, C)\n        # Text representation = embedding of the last non-padded prompt token.\n        prompt_lengths = prompt_attention_mask.sum(dim=1)  # (B,)\n        last_token_indices = (prompt_lengths - 1).clamp(min=0).long()  # (B,)\n        # Expand over the text embedding width, which need not equal the image width.\n        gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, prompt_text_embeds_contrast.size(-1))\n        prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)  # (B, C)\n        text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)  # (B, E)\n        contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)\n\n        # --- 3. Generative / Regression Path ---\n        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)  # (B, N_img, C)\n        prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)  # (B, T_prompt, C)\n        target_embeds_decoder = self.decoder.token_embedding_table(target_ids)  # (B, T_target, C)\n        _, T_prompt, C = prompt_embeds_decoder.shape\n\n        # Decoder input: [image | prompt | target], with a matching attention mask.\n        combined_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder], dim=1)\n        combined_attention_mask = torch.cat([\n            torch.ones(B, N_img, dtype=torch.long, device=device),  # Image part is never padded\n            prompt_attention_mask,\n            target_attention_mask\n        ], dim=1)\n        T_combined = combined_embeds.shape[1]\n\n        # Classification targets: every image and prompt position is ignored (-100).\n        combined_class_targets = None\n        if generative_targets is not None:\n            combined_class_targets = torch.cat([\n                torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=device),\n                generative_targets  # Already shifted and -100 padded\n            ], dim=1)\n\n        # Decoder returns logits, classification loss and final normalized hidden states (B, T_combined, C).\n        logits, class_loss, x_norm = self.decoder(\n            combined_embeds,\n            attention_mask=combined_attention_mask,\n            targets=combined_class_targets\n        )\n\n        # --- Regression Output & Loss ---\n        regression_loss = None\n        regression_output = None\n        if continuous_coords is not None and x_norm is not None:\n            # Hidden state at target position (length - 2), i.e. the token before the last real\n            # target token; clamped for very short targets and to the sequence bounds.\n            target_lengths = target_attention_mask.sum(dim=1)  # (B,)\n            relative_target_idx = (target_lengths - 2).clamp(min=0)\n            absolute_idx = (N_img + T_prompt + relative_target_idx).clamp(max=T_combined - 1).long()\n            gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)\n            try:\n                hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1)  # (B, C)\n                regression_output = self.decoder.regression_head(hidden_state_for_regression)  # (B, 2)\n                regression_loss = F.l1_loss(regression_output, continuous_coords)  # Mean absolute error\n                if torch.isnan(regression_loss):\n                    print(\"Warning: Regression loss is NaN.\")\n                    regression_loss = None\n            except Exception as e:\n                print(f\"Error during regression calculation: {e}\")\n                print(f\"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}\")\n                regression_loss = None\n                regression_output = None\n\n        # --- 4. Combine All Losses ---\n        # Missing components become NaN placeholders so the caller can still log them.\n        weighted_losses = []\n        if class_loss is not None:\n            weighted_losses.append(class_loss)  # Weight 1.0\n        else:\n            class_loss = torch.tensor(float('nan'))\n        if contrastive_loss is not None:\n            weighted_losses.append(self.lambda_contrastive * contrastive_loss)\n        else:\n            contrastive_loss = torch.tensor(float('nan'))\n        if regression_loss is not None:\n            weighted_losses.append(self.lambda_regression * regression_loss)\n        else:\n            regression_loss = torch.tensor(float('nan'))\n\n        total_loss = torch.tensor(0.0, device=device)\n        for weighted_loss in weighted_losses:\n            total_loss = total_loss + weighted_loss\n\n        if torch.isnan(total_loss):\n            print(\"Warning: Total loss is NaN. Setting to zero.\")\n            total_loss = torch.tensor(0.0, device=device, requires_grad=True)\n\n        return logits, regression_output, total_loss, class_loss, contrastive_loss, regression_loss\n\n    @torch.no_grad()\n    def generate(self, img_array, idx_prompt, max_new_tokens,\n                 temperature=1.0, top_k=None,  # Defaults sample from the full distribution; temperature=0.0 or top_k=1 is greedy\n                 force_result_start=True       # Option to manually add <result_start>\n                 ):\n        \"\"\"\n        Generates token sequences autoregressively based on image and prompt.\n        Uses the classification head (lm_head).\n\n        Args:\n            img_array (torch.Tensor): Input image tensor (B, 3, H, W). Only the first item is used.\n            idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt). Only the first row is used.\n            max_new_tokens (int): Maximum number of new tokens to generate.\n            temperature (float): Softmax temperature; lower values are sharper. 0.0 selects greedy decoding.\n            top_k (int | None): If set, restricts sampling to the top K most likely tokens; 1 is greedy.\n            force_result_start (bool): If True, appends <result_start> after the prompt before the\n                generation loop; it is included in the returned ids.\n\n        Returns:\n            torch.Tensor: Prompt followed by generated ids, shape (1, T_prompt + T_generated).\n            The model's previous train/eval mode is restored afterwards.\n        \"\"\"\n        was_training = self.training\n        self.eval()\n        try:\n            return self._generate_single(img_array, idx_prompt, max_new_tokens,\n                                         temperature, top_k, force_result_start)\n        finally:\n            self.train(was_training)  # Restore the caller's mode instead of forcing train mode\n\n    def _generate_single(self, img_array, idx_prompt, max_new_tokens, temperature, top_k, force_result_start):\n        \"\"\"Decoding loop behind generate(); expects eval mode and a no_grad context.\"\"\"\n        device = img_array.device\n        if img_array.shape[0] > 1:\n            # Batch generation would need per-row EOS handling and padding inside the loop.\n            print(\"Warning: Generation function currently assumes batch size B=1.\")\n            img_array = img_array[:1]\n            idx_prompt = idx_prompt[:1]\n        B = 1\n        greedy = temperature == 0.0 or top_k == 1\n        eos_token_id = getattr(tokenizer, 'eos_token_id', None)  # global tokenizer defined outside this class\n        max_context = self.decoder.max_context\n\n        # --- 1. Prepare Initial Embeddings ---\n        image_embeds_decoder = self.multimodal_projector(self.vision_encoder(img_array))\n        prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)\n        current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)\n        generated_ids_list = []  # (1, 1) id tensors, concatenated at the end\n\n        if force_result_start:\n            try:\n                result_start_token_id = tokenizer.encode(\"<result_start>\", add_special_tokens=False)[0]\n                result_start_ids = torch.tensor([[result_start_token_id]], device=device)\n                result_start_embed = self.decoder.token_embedding_table(result_start_ids)\n                current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)\n                generated_ids_list.append(result_start_ids)\n            except Exception as e:\n                print(f\"Warning: Could not encode or add <result_start>: {e}\")\n\n        # --- 2. Autoregressive Loop ---\n        for _ in range(max_new_tokens):\n            # Keep only the most recent max_context positions.\n            if current_embeds.shape[1] > max_context:\n                current_embeds = current_embeds[:, -max_context:, :]\n            T_current = current_embeds.shape[1]\n\n            pos = torch.arange(0, T_current, dtype=torch.long, device=device)\n            x = current_embeds + self.decoder.position_embedding_table(pos).unsqueeze(0)\n            attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)  # No padding\n            for block in self.decoder.blocks:\n                x = block(x, attention_mask=attention_mask)\n\n            # Logits for the last position only: (B, V)\n            logits = self.decoder.lm_head(self.decoder.ln_f(x[:, -1:, :])).squeeze(1)\n\n            if greedy:\n                # Choose before any division: temperature == 0.0 would turn the logits into inf/NaN.\n                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)\n            else:\n                logits = logits / temperature\n                if top_k is not None and top_k > 0:\n                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n                    logits[logits < v[:, [-1]]] = -float('Inf')\n                probs = F.softmax(logits, dim=-1)\n                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)\n\n            generated_ids_list.append(idx_next)\n            if eos_token_id is not None and idx_next.item() == eos_token_id:\n                break\n\n            next_token_embed = self.decoder.token_embedding_table(idx_next)\n            current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)\n\n        # --- 3. Combine results ---\n        if not generated_ids_list:\n            return idx_prompt\n        return torch.cat([idx_prompt] + generated_ids_list, dim=1)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import math\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torch.optim.lr_scheduler import OneCycleLR\nfrom tqdm.auto import tqdm\nimport wandb\nfrom datetime import datetime\nimport numpy as np\n\n# Assumes these come from earlier cells:\n# - VisionLanguageModel class (with dual heads and combined loss)\n# - create_train_dataloader, create_test_dataloader (batches include 'continuous_coords')\n# - vocab_size and the constants DEVICE, HIDDEN_DIM, NUM_LAYERS, IMAGE_SIZE, PATCH_SIZE,\n#   NUM_HEADS, DROPOUT, CONTEXT_LENGTH, SHARED_EMBED_DIM, LAMBDA_CONTRASTIVE, LAMBDA_REGRESSION,\n#   LEARNING_RATE, BATCH_SIZE, GRAD_ACCUMULATION_STEPS, DTYPE\n\n# --- Constants and Configuration ---\nNUM_EPOCHS = 50 # Or your desired number of epochs\nLOGGING_STEPS = 10 # Log every N optimization steps\nMAX_GRAD_NORM = 1.0\nLOGIT_SCALE_MAX = math.log(100.0) # exp(logit_scale) is capped at 100 for contrastive stability\n\nprint(f\"Using device: {DEVICE}\")\nprint(f\"Vocab size: {vocab_size}\")\n\n# --- Initialize Model ---\nmodel = VisionLanguageModel(\n    n_embd=HIDDEN_DIM,\n    vocab_size=vocab_size,\n    img_size=IMAGE_SIZE,\n    patch_size=PATCH_SIZE,\n    num_heads=NUM_HEADS,\n    num_blks_vit=NUM_LAYERS, # Or specific value for ViT layers\n    num_blks_dec=NUM_LAYERS, # Or specific value for Decoder layers\n    emb_dropout=DROPOUT,\n    blk_dropout=DROPOUT,\n    max_context=CONTEXT_LENGTH,\n    shared_embed_dim=SHARED_EMBED_DIM,\n    lambda_contrastive=LAMBDA_CONTRASTIVE,\n    lambda_regression=LAMBDA_REGRESSION\n).to(DEVICE)\n\n# --- Optimizer ---\n# Created after the model, so it covers every parameter (including the regression head).\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), weight_decay=0.1)\n\n# --- Dataloaders ---\n# Batches must also provide 'continuous_coords' for the regression loss.\ntrain_loader = create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0) # num_workers=0 for easier debugging\ntest_loader = create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0)\nif train_loader is None:\n    # Raise so the cell stops with a clear error.\n    raise RuntimeError(\"Training loader failed to initialize.\")\ntest_loader_has_data = test_loader is not None and len(test_loader) > 0\n\n# --- LR Scheduler ---\nif len(train_loader) > 0 and NUM_EPOCHS > 0:\n    # The loop below steps every GRAD_ACCUMULATION_STEPS batches *and* on the last batch of an\n    # epoch, i.e. ceil(len(train_loader) / GRAD_ACCUMULATION_STEPS) times per epoch. Flooring this\n    # under-counts total_steps, and OneCycleLR raises once it is stepped past total_steps.\n    steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUMULATION_STEPS)\n    total_steps = steps_per_epoch * NUM_EPOCHS\n    warmup_steps = min(max(1, total_steps // 10), 10000) # At least 1, at most 10k warmup steps\n    print(f\"Total estimated optimization steps: {total_steps}, Warmup steps: {warmup_steps}\")\n    lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_steps, pct_start=warmup_steps / total_steps)\nelse:\n    print(\"Warning: Train loader empty. Using constant LR.\")\n    total_steps = 0\n    warmup_steps = 0\n    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)\n\n# --- Wandb Setup ---\ntry:\n    wandb.init(\n        project=\"point-language-model-regression\",\n        name=f\"point-vlm-dual-{datetime.now().strftime('%Y%m%d-%H%M%S')}\",\n        config={\n            \"image_size\": IMAGE_SIZE, \"patch_size\": PATCH_SIZE, \"hidden_dim\": HIDDEN_DIM,\n            \"context_length\": CONTEXT_LENGTH, \"dropout\": DROPOUT,\n            \"num_heads\": NUM_HEADS, \"num_layers\": NUM_LAYERS, \"batch_size\": BATCH_SIZE,\n            \"learning_rate\": LEARNING_RATE, \"grad_accum_steps\": GRAD_ACCUMULATION_STEPS,\n            \"shared_embed_dim\": SHARED_EMBED_DIM, \"lambda_contrastive\": LAMBDA_CONTRASTIVE,\n            \"lambda_regression\": LAMBDA_REGRESSION,\n            \"architecture\": \"VisionLanguageModel (Dual Head)\", \"optimizer\": \"AdamW\",\n            \"num_epochs\": NUM_EPOCHS, \"total_steps\": total_steps, \"warmup_steps\": warmup_steps\n        }\n    )\n    wandb_enabled = True\nexcept Exception as e:\n    print(f\"Wandb initialization failed: {e}. Running without wandb.\")\n    wandb_enabled = False\n\ndef _loss_value(loss):\n    \"\"\"Return a loss tensor as a Python float, or NaN when it is missing or NaN.\"\"\"\n    if torch.is_tensor(loss) and not torch.isnan(loss):\n        return loss.item()\n    return float('nan')\n\n# --- Training Loop ---\nprint(\"Starting training with Classification + Contrastive + Regression Loss...\")\nstep_counter = 0\noptimizer.zero_grad()\n\nfor epoch in range(NUM_EPOCHS):\n    model.train()\n    # Loss sums over the current logging window (reset after every log)\n    epoch_total_loss_accum = 0.0\n    epoch_class_loss_accum = 0.0\n    epoch_con_loss_accum = 0.0\n    epoch_reg_loss_accum = 0.0\n    batches_since_log = 0\n\n    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f\"Epoch {epoch+1}/{NUM_EPOCHS}\", leave=False)\n\n    for batch_idx, batch in pbar:\n        if batch is None:\n            continue # collate_fn may return None for an unusable batch\n\n        # --- Unpack Batch Data ---\n        try:\n            images = batch['image'].to(DEVICE, non_blocking=True).to(DTYPE)\n            prompt_ids = batch['prompt_ids'].to(DEVICE, non_blocking=True)\n            prompt_attention_mask = batch['prompt_attention_mask'].to(DEVICE, non_blocking=True)\n            target_ids = batch['target_ids'].to(DEVICE, non_blocking=True)\n            target_attention_mask = batch['target_attention_mask'].to(DEVICE, non_blocking=True)\n            generative_targets = batch['generative_targets'].to(DEVICE, non_blocking=True)\n            continuous_coords = batch['continuous_coords'].to(DEVICE, non_blocking=True) # Regression targets\n        except KeyError as e:\n            print(f\"Error: Missing key {e} in batch. Check dataloader and collate_fn.\")\n            continue\n\n        # Clamp logit_scale to [0, log(100)] for contrastive loss stability\n        with torch.no_grad():\n            model.logit_scale.clamp_(0, LOGIT_SCALE_MAX)\n\n        # --- Forward Pass ---\n        logits, reg_output, total_loss, class_loss, contrastive_loss, regression_loss = model(\n            img_array=images,\n            prompt_ids=prompt_ids,\n            prompt_attention_mask=prompt_attention_mask,\n            target_ids=target_ids,\n            target_attention_mask=target_attention_mask,\n            generative_targets=generative_targets,\n            continuous_coords=continuous_coords\n        )\n\n        # --- Loss Handling & Accumulation ---\n        if total_loss is None or torch.isnan(total_loss) or torch.isinf(total_loss):\n            print(f\"Warning: Invalid total_loss ({total_loss}) detected at Epoch {epoch+1}, Batch {batch_idx}. Skipping backward/step.\")\n            # Also discards gradients accumulated from earlier batches in this window\n            optimizer.zero_grad()\n            continue\n\n        # Scale loss for gradient accumulation\n        scaled_loss = total_loss / GRAD_ACCUMULATION_STEPS\n\n        # Component losses are NaN placeholders when unavailable, which then shows as NaN in the logs\n        epoch_total_loss_accum += total_loss.item()\n        epoch_class_loss_accum += class_loss.item() if torch.is_tensor(class_loss) else 0.0\n        epoch_con_loss_accum += contrastive_loss.item() if torch.is_tensor(contrastive_loss) else 0.0\n        epoch_reg_loss_accum += regression_loss.item() if torch.is_tensor(regression_loss) else 0.0\n        batches_since_log += 1\n\n        # --- Backward Pass ---\n        try:\n            scaled_loss.backward()\n        except Exception as e:\n            print(f\"Error during backward pass: {e}. Skipping step.\")\n            optimizer.zero_grad()\n            continue\n\n        # --- Gradient Accumulation Step (also taken on the last batch of the epoch) ---\n        if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(train_loader):\n            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n            optimizer.step()\n            lr_scheduler.step() # One scheduler step per optimizer step\n            optimizer.zero_grad()\n            step_counter += 1\n\n            # --- Logging ---\n            if step_counter % LOGGING_STEPS == 0 and batches_since_log > 0:\n                avg_total_loss = epoch_total_loss_accum / batches_since_log\n                avg_class_loss = epoch_class_loss_accum / batches_since_log\n                avg_con_loss = epoch_con_loss_accum / batches_since_log\n                avg_reg_loss = epoch_reg_loss_accum / batches_since_log\n                current_lr = optimizer.param_groups[0]['lr']\n\n                # --- Test Evaluation (first batch of a fresh test iterator) ---\n                test_class_loss_val = float('nan')\n                test_con_loss_val = float('nan')\n                test_reg_loss_val = float('nan')\n                if test_loader_has_data:\n                    model.eval()\n                    with torch.no_grad():\n                        try:\n                            test_batch = next(iter(test_loader))\n                            if test_batch:\n                                t_images = test_batch['image'].to(DEVICE).to(DTYPE)\n                                t_p_ids = test_batch['prompt_ids'].to(DEVICE)\n                                t_p_mask = test_batch['prompt_attention_mask'].to(DEVICE)\n                                t_t_ids = test_batch['target_ids'].to(DEVICE)\n                                t_t_mask = test_batch['target_attention_mask'].to(DEVICE)\n                                t_gen_targets = test_batch['generative_targets'].to(DEVICE)\n                                t_cont_coords = test_batch['continuous_coords'].to(DEVICE)\n                                _, _, _, t_class_loss, t_con_loss, t_reg_loss = model(\n                                    t_images, t_p_ids, t_p_mask, t_t_ids, t_t_mask,\n                                    t_gen_targets, t_cont_coords\n                                )\n                                test_class_loss_val = _loss_value(t_class_loss)\n                                test_con_loss_val = _loss_value(t_con_loss)\n                                test_reg_loss_val = _loss_value(t_reg_loss)\n                        except StopIteration:\n                            print(\"Info: Test loader exhausted during logging.\")\n                        except KeyError as e:\n                            print(f\"Error: Missing key {e} in test batch.\")\n                        except Exception as e:\n                            print(f\"Error during test evaluation: {e}\")\n                    model.train()\n\n                log_data = {\n                    \"train/total_loss\": avg_total_loss,\n                    \"train/class_loss\": avg_class_loss,\n                    \"train/contrastive_loss\": avg_con_loss,\n                    \"train/regression_loss\": avg_reg_loss,\n                    \"test/class_loss\": test_class_loss_val,\n                    \"test/contrastive_loss\": test_con_loss_val,\n                    \"test/regression_loss\": test_reg_loss_val,\n                    \"epoch\": epoch + ((batch_idx + 1) / len(train_loader)), # Fractional epoch\n                    \"step\": step_counter,\n                    \"learning_rate\": current_lr,\n                    \"gradient_norm\": grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm,\n                    \"logit_scale\": model.logit_scale.exp().item()\n                }\n                pbar.set_postfix({\n                    \"lr\": f\"{current_lr:.2e}\",\n                    \"loss\": f\"{avg_total_loss:.3f}\",\n                    \"cls\": f\"{avg_class_loss:.3f}\",\n                    \"con\": f\"{avg_con_loss:.3f}\",\n                    \"reg\": f\"{avg_reg_loss:.3f}\",\n                    \"gnorm\": f\"{log_data['gradient_norm']:.2f}\"\n                })\n                if wandb_enabled:\n                    wandb.log(log_data, step=step_counter)\n\n                # Reset accumulators for the next logging window\n                epoch_total_loss_accum, epoch_class_loss_accum, epoch_con_loss_accum, epoch_reg_loss_accum = 0.0, 0.0, 0.0, 0.0\n                batches_since_log = 0\n\n    # --- End of Epoch ---\n    print(f\"\\nEpoch {epoch+1}/{NUM_EPOCHS} completed.\")\n\n# --- End of Training ---\nprint(\"\\nTraining completed!\")\nif wandb_enabled:\n    wandb.finish()","metadata":{"trusted":true},"outputs":[],"execution_count":null}]} |