| """ | |
| model.py | |
| ======== | |
| Complete SmolLM2-135M model implementation | |
| Architecture: | |
| - 30 transformer blocks | |
| - 576 hidden dimensions | |
| - 9 query heads, 3 KV heads (Grouped Query Attention) | |
| - SwiGLU feed-forward network | |
| - RoPE position embeddings | |
| - RMSNorm layer normalization | |
| - Weight tying (embeddings = lm_head) | |
| Total parameters: 134,515,008 (~135M) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from components import RMSNorm, TransformerBlock | |
| from transformers import AutoConfig | |
class SmolLM2Model(nn.Module):
    """
    SmolLM2-135M Language Model

    A decoder-only transformer based on the Llama architecture with:
    - Grouped Query Attention (memory efficient)
    - SwiGLU FFN (improved expressiveness)
    - RoPE position embeddings (length extrapolation)
    - RMSNorm (faster than LayerNorm)

    Model configuration:
    - Layers: 30
    - Hidden size: 576
    - Attention heads: 9 (Q) / 3 (KV)
    - FFN size: 1536
    - Vocab size: 49,152
    - Context length: 2048
    """

    def __init__(self, config):
        """
        Initialize the SmolLM2 model

        Args:
            config: Model configuration object with attributes:
                - vocab_size: Size of vocabulary (49152)
                - hidden_size: Model dimension (576)
                - num_hidden_layers: Number of transformer blocks (30)
                - tie_word_embeddings: Whether to tie input/output embeddings
                - rms_norm_eps: Epsilon for RMSNorm
        """
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Transformer blocks (30 layers)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])

        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Language modeling head (output projection)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: share embeddings with output projection
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        print(f"✓ Model initialized with {config.num_hidden_layers} transformer blocks")
        print(f"✓ Weight tying: {config.tie_word_embeddings}")

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass through the model

        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Logits over vocabulary [batch, seq_len, vocab_size]
        """
        batch_size, seq_len = input_ids.shape

        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=input_ids.device)

        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Pass through all transformer blocks
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)

        # Final normalization
        hidden_states = self.norm(hidden_states)

        # Project to vocabulary
        logits = self.lm_head(hidden_states)

        return logits

    def generate(
        self,
        input_ids,
        max_new_tokens=50,
        temperature=1.0,
        top_p=0.9,
        top_k=None,
        do_sample=True,
    ):
        """
        Generate text autoregressively

        Supports multiple sampling strategies:
        - Greedy decoding (temperature=0 or do_sample=False)
        - Temperature sampling
        - Nucleus (top-p) sampling
        - Top-k sampling

        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            max_new_tokens (int): Number of tokens to generate
            temperature (float): Sampling temperature (0 = greedy, >1 = more random)
            top_p (float): Nucleus sampling threshold (0-1)
            top_k (int, optional): Top-k sampling threshold
            do_sample (bool): Whether to sample or use greedy decoding

        Returns:
            torch.Tensor: Generated token IDs [batch, seq_len + max_new_tokens]
        """
        self.eval()

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # Forward pass
                logits = self(input_ids)

                # Get next token logits
                next_token_logits = logits[:, -1, :]

                # Apply temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                # Greedy decoding
                if not do_sample or temperature == 0:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                else:
                    # Top-k sampling
                    if top_k is not None:
                        top_k = min(top_k, next_token_logits.size(-1))
                        indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Nucleus (top-p) sampling
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                        # Remove tokens with cumulative probability above threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Keep at least one token
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = False

                        # Scatter to original indexing
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')

                    # Sample from distribution
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def get_num_params(self, non_embedding=False):
        """
        Count model parameters

        Args:
            non_embedding (bool): If True, exclude embedding parameters

        Returns:
            int: Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())

        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
            # When weights are tied, lm_head shares the embedding tensor and
            # parameters() counts it only once, so only subtract the separate
            # lm_head weight in the untied case.
            if not self.config.tie_word_embeddings:
                n_params -= self.lm_head.weight.numel()

        return n_params
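

# The helper below is an added sketch, not part of the original training code:
# it shows how a standard causal language-modeling loss could be computed from
# the logits returned by SmolLM2Model.forward(), assuming the usual
# shift-by-one next-token objective.
def compute_causal_lm_loss(logits, input_ids):
    """
    Minimal cross-entropy loss sketch for next-token prediction (added example).

    Position t of the shifted logits is scored against token t+1, which is the
    standard causal LM objective.
    """
    shift_logits = logits[:, :-1, :].contiguous()  # predictions for positions 0..T-2
    shift_labels = input_ids[:, 1:].contiguous()   # targets are tokens 1..T-1
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )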


def initialize_weights(model, config):
    """
    Initialize model weights using GPT-style initialization

    Strategy:
    - All weights: Normal(0, 0.02)
    - Residual projections: scaled by 1/sqrt(2 * num_layers)
    - RMSNorm: initialized to 1.0 (PyTorch default)

    The residual scaling prevents variance explosion in deep networks.

    Args:
        model (SmolLM2Model): Model to initialize
        config: Model configuration
    """
    std = 0.02
    num_layers = config.num_hidden_layers

    # Residual scaling factor: 1/sqrt(2 * num_layers)
    residual_scaling = 1.0 / math.sqrt(2 * num_layers)

    print(f"Initializing weights with std={std}, residual_scaling={residual_scaling:.6f}")

    # Initialize embeddings
    nn.init.normal_(model.embed_tokens.weight, mean=0.0, std=std)

    # Initialize each transformer block
    for layer in model.layers:
        # Attention projections
        nn.init.normal_(layer.self_attn.q_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.k_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.v_proj.weight, mean=0.0, std=std)
        # Attention output projection with residual scaling
        nn.init.normal_(layer.self_attn.o_proj.weight, mean=0.0, std=std * residual_scaling)

        # FFN projections
        nn.init.normal_(layer.mlp.gate_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.mlp.up_proj.weight, mean=0.0, std=std)
        # FFN output projection with residual scaling
        nn.init.normal_(layer.mlp.down_proj.weight, mean=0.0, std=std * residual_scaling)

    # RMSNorm weights are initialized to 1.0 by default (PyTorch)
    print(f"✓ Initialized {sum(1 for _ in model.parameters())} weight tensors")


def load_pretrained_weights(our_model, official_model, device='cuda'):
    """
    Load weights from the official HuggingFace model

    Maps weight names from the official model to our implementation:
    - model.embed_tokens.weight -> embed_tokens.weight
    - model.layers.{i}.*        -> layers[i].*
    - model.norm.weight         -> norm.weight
    - lm_head.weight (tied with embeddings)

    Args:
        our_model (SmolLM2Model): Our model to load weights into
        official_model: HuggingFace official model
        device (str): Device to load weights to

    Returns:
        int: Number of weight tensors loaded
    """
    print("=" * 70)
    print("LOADING PRETRAINED WEIGHTS")
    print("=" * 70)

    official_state = official_model.state_dict()
    loaded_count = 0

    # 1. Load token embeddings
    our_model.embed_tokens.weight.data = official_state['model.embed_tokens.weight'].clone().to(device)
    loaded_count += 1

    # 2. Load all transformer blocks
    num_layers = our_model.config.num_hidden_layers
    for layer_idx in range(num_layers):
        prefix = f'model.layers.{layer_idx}'

        # Layer norms
        our_model.layers[layer_idx].input_layernorm.weight.data = \
            official_state[f'{prefix}.input_layernorm.weight'].clone().to(device)
        our_model.layers[layer_idx].post_attention_layernorm.weight.data = \
            official_state[f'{prefix}.post_attention_layernorm.weight'].clone().to(device)

        # Attention projections
        our_model.layers[layer_idx].self_attn.q_proj.weight.data = \
            official_state[f'{prefix}.self_attn.q_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.k_proj.weight.data = \
            official_state[f'{prefix}.self_attn.k_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.v_proj.weight.data = \
            official_state[f'{prefix}.self_attn.v_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.o_proj.weight.data = \
            official_state[f'{prefix}.self_attn.o_proj.weight'].clone().to(device)

        # FFN projections
        our_model.layers[layer_idx].mlp.gate_proj.weight.data = \
            official_state[f'{prefix}.mlp.gate_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.up_proj.weight.data = \
            official_state[f'{prefix}.mlp.up_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.down_proj.weight.data = \
            official_state[f'{prefix}.mlp.down_proj.weight'].clone().to(device)

        loaded_count += 9  # 2 norms + 4 attn + 3 ffn

    # 3. Load final norm
    our_model.norm.weight.data = official_state['model.norm.weight'].clone().to(device)
    loaded_count += 1

    print(f"\n✓ Loaded {num_layers} transformer blocks")
    print(f"✓ Total loaded: {loaded_count} weight tensors")
    print("=" * 70)

    return loaded_count
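

# Added usage sketch (not in the original file): one way to exercise
# load_pretrained_weights is to pull the official checkpoint with
# transformers.AutoModelForCausalLM, copy its weights into this implementation,
# and compare logits on a small random input. Requires network access (or a
# cached checkpoint) and assumes both models sit on the same device.
def verify_against_official(our_model, device='cpu'):
    from transformers import AutoModelForCausalLM

    official = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    our_model = our_model.to(device)
    official = official.to(device)

    load_pretrained_weights(our_model, official, device=device)

    test_input = torch.randint(0, our_model.config.vocab_size, (1, 8), device=device)
    with torch.no_grad():
        ours = our_model(test_input)
        theirs = official(test_input).logits

    max_diff = (ours - theirs).abs().max().item()
    print(f"Max absolute logit difference vs official model: {max_diff:.6f}")
    return max_diff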
| if __name__ == "__main__": | |
| """Test model creation and parameter count""" | |
| # Load config | |
| config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M") | |
| # Create model | |
| model = SmolLM2Model(config) | |
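
    # Added usage example: apply the GPT-style initialization defined above.
    # The model already has PyTorch's default init at this point, so this only
    # re-draws the random weights; it does not change the shape or count tests.
    initialize_weights(model, config)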

    # Count parameters
    total_params = model.get_num_params()
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Expected: 134,515,008")
    print(f"Match: {total_params == 134_515_008}")

    # Test forward pass
    test_input = torch.randint(0, config.vocab_size, (1, 10))
    output = model(test_input)
    print(f"\nForward pass test:")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  Expected: torch.Size([1, 10, 49152])")

    # Test generation
    generated = model.generate(test_input, max_new_tokens=5)
    print(f"\nGeneration test:")
    print(f"  Generated shape: {generated.shape}")
    print(f"  Expected: torch.Size([1, 15])")