# NOTE: import order is preserved deliberately. `from constants import *` can
# bind arbitrary names, so later imports (torch, utils.vocab_size, ...) must
# keep overriding it exactly as before.
from model_components import Block
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size


class DecoderLanguageModel(nn.Module):
    """
    Transformer decoder language model operating on pre-computed embeddings.

    ``forward`` takes a (B, T, C) sequence of input embeddings (the "combined"
    sequence; NOTE(review): presumably text-token embeddings concatenated with
    other embeddings by the vision-language wrapper — confirm). It does the
    following:

    - adds learned positional embeddings;
    - runs a stack of decoder ``Block``s and a final LayerNorm;
    - projects to vocabulary logits.

    It returns the logits, an optional cross-entropy loss and the normalized
    hidden states, so the caller can apply its own heads.

    ``regression_head`` (MAX_POINTS (x, y) pairs squashed to [0, 1]) is built
    here but is not called by ``forward`` or ``generate``. NOTE(review): it is
    presumably applied externally to the returned hidden states — confirm.
    """

    def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
                 n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
        """
        Build embeddings, decoder blocks and output heads.

        Args:
            n_embd: Hidden / embedding width.
            vocab_size: Number of token classes (rows of the token embedding
                and outputs of ``lm_head``).
            num_heads: Attention heads passed to each ``Block``.
            n_layer: Number of decoder ``Block``s.
            max_context: Maximum sequence length; size of the positional table.
            dropout: Dropout probability for the input embeddings, also passed
                to every ``Block``.
        """
        super().__init__()

        # --- Input embeddings ---
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(max_context, n_embd)
        self.dropout = nn.Dropout(dropout)

        # --- Transformer blocks ---
        self.blocks = nn.ModuleList([
            Block(n_embd, num_heads, dropout, is_decoder=True)
            for _ in range(n_layer)
        ])

        # --- Final layer norm ---
        self.ln_f = nn.LayerNorm(n_embd)

        # --- Output heads ---
        # 1. Token classification head.
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # 2. Direct coordinate regression: MAX_POINTS * (x, y) values in [0, 1].
        #    Not used by forward(); its parameters still live in the state_dict.
        self.regression_head = nn.Sequential(
            nn.Linear(n_embd, n_embd // 2),
            nn.GELU(),
            nn.Linear(n_embd // 2, MAX_POINTS * 2),
            nn.Sigmoid(),
        )

        self.n_embd = n_embd
        self.max_context = max_context

        # Weight tying: the input token embedding and the output projection
        # share a single (vocab_size, n_embd) parameter.
        self.token_embedding_table.weight = self.lm_head.weight

        self.apply(self._init_weights)
        print(f"DecoderLanguageModel initialized with {n_layer} layers.")

    def _init_weights(self, module):
        """
        Initialize one submodule; intended for use with ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02) and Linear
        biases are zeroed. LayerNorm is reset to the identity affine transform
        (weight=1, bias=0). Parameters that a module was built without (e.g. a
        LayerNorm with ``elementwise_affine=False``) are skipped.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)

    def _position_embeddings(self, seq_len, device):
        """
        Return learned positional embeddings for positions 0..seq_len-1.

        The result has shape (1, seq_len, n_embd), ready to broadcast over the
        batch. Positions are clamped to the last table row as a safety net;
        both callers crop sequences to ``max_context`` first.
        """
        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
        return self.position_embedding_table(pos).unsqueeze(0)

    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Run the decoder over pre-computed input embeddings.

        Any coordinate regression is applied by the caller to the returned
        hidden states, not inside this method.

        Args:
            combined_embeds: 3-D tensor (B, T, C). NOTE(review): C must equal
                ``n_embd`` for the positional addition to broadcast; this is
                not validated here. If T > ``max_context``, only the last
                ``max_context`` positions are kept.
            attention_mask: Optional mask passed unchanged to every ``Block``.
                It is truncated along dim 1 together with the embeddings, so it
                is presumably (B, T) — confirm against ``Block``.
            targets: Optional (B, T) class indices, compared position-by-position
                with the logits (no shifting is done here). Entries equal to
                -100 are ignored.

        Returns:
            A tuple ``(logits, class_loss, x_norm)``:

            - ``logits`` is (B, T, vocab_size).
            - ``class_loss`` is a scalar tensor, or None when no targets were
              given, the loss was NaN, or computing it failed. Failures are
              printed rather than raised.
            - ``x_norm`` holds the final LayerNorm outputs, (B, T, C).

        Raises:
            ValueError: If ``combined_embeds`` is not 3-D.
        """
        # --- Input validation & context truncation ---
        if combined_embeds.ndim != 3:
            raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
        T = combined_embeds.size(1)

        if T > self.max_context:
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None:
                # This column slice is non-contiguous for B > 1; see the
                # reshape (not view) used in the loss below.
                targets = targets[:, -self.max_context:]
            T = self.max_context

        # --- Positional encoding ---
        x = combined_embeds + self._position_embeddings(T, combined_embeds.device)
        x = self.dropout(x)

        # --- Transformer blocks ---
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)

        # --- Final layer norm; returned for heads applied by the caller ---
        x_norm = self.ln_f(x)  # (B, T, C)

        # --- Classification head ---
        logits = self.lm_head(x_norm)  # (B, T, vocab_size)

        # --- Classification loss (best effort: failures yield None) ---
        class_loss = None
        if targets is not None:
            try:
                # reshape rather than view: truncated targets are
                # non-contiguous, and .view(-1) would raise on them.
                class_loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    targets.reshape(-1),
                    ignore_index=-100,
                )
                if torch.isnan(class_loss):
                    print("Warning: class_loss is NaN.")
                    class_loss = None
            except Exception as e:
                print(f"Error calculating cross_entropy: {e}")
                print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
                class_loss = None

        return logits, class_loss, x_norm

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively sample tokens, starting from token ids.

        NOTE: this works from token ids only, not from combined embeddings.
        A vision-language wrapper should implement its own loop (or an
        embeddings-based variant) when it needs to condition on non-token
        inputs.

        Every step re-runs the full (cropped) prefix; there is no KV cache.
        Gradients are disabled. The model runs in eval mode during generation,
        and its previous train/eval mode is restored afterwards, even if an
        exception is raised.

        Args:
            idx: (B, T) tensor of token ids to continue.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature; 0 selects greedy (argmax) decoding.
            top_k: If a positive int, sample only among the top_k logits.

        Returns:
            ``idx`` with up to ``max_new_tokens`` sampled ids appended along
            dim 1. Generation stops early when every sequence in the batch
            emits the tokenizer's EOS id at the same step.
        """
        eos_token_id = getattr(tokenizer, 'eos_token_id', None)
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Keep at most the last max_context tokens as conditioning.
                idx_cond = idx[:, -self.max_context:]

                tok_embeds = self.token_embedding_table(idx_cond)  # (B, T, C)
                pos_emb = self._position_embeddings(idx_cond.size(1), idx.device)
                x = self.dropout(tok_embeds + pos_emb)

                # No padding mask: any causal masking is internal to Block.
                for block in self.blocks:
                    x = block(x, attention_mask=None)

                # Only the last position is needed to predict the next token.
                x = self.ln_f(x[:, -1:, :])           # (B, 1, C)
                logits = self.lm_head(x).squeeze(1)   # (B, V)

                if temperature == 0:
                    # Greedy decoding; dividing by 0 would produce inf/nan.
                    idx_next = logits.argmax(dim=-1, keepdim=True)  # (B, 1)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

                idx = torch.cat((idx, idx_next), dim=1)

                if eos_token_id is not None and (idx_next == eos_token_id).all():
                    break
        finally:
            # Restore the caller's mode instead of unconditionally re-enabling
            # training.
            self.train(was_training)
        return idx