""" SmolLM2-135M Text Generation - Gradio App Trained from scratch on Suits TV series scripts """ import torch import torch.nn as nn from torch.nn import functional as F from dataclasses import dataclass from typing import Optional, Tuple import tiktoken import gradio as gr # ============================================================================ # Model Architecture # ============================================================================ @dataclass class SmolLM2Config: """SmolLM2-135M Configuration""" vocab_size: int = 50304 hidden_size: int = 576 intermediate_size: int = 1536 num_hidden_layers: int = 30 num_attention_heads: int = 9 num_key_value_heads: int = 3 max_position_embeddings: int = 2048 rms_norm_eps: float = 1e-5 rope_theta: float = 10000.0 hidden_act: str = "silu" initializer_range: float = 0.041666666666666664 tie_word_embeddings: bool = True bos_token_id: int = 0 eos_token_id: int = 0 @property def head_dim(self) -> int: return self.hidden_size // self.num_attention_heads class RMSNorm(nn.Module): """Root Mean Square Layer Normalization""" def __init__(self, hidden_size: int, eps: float = 1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * rms * self.weight class RotaryEmbedding(nn.Module): """Rotary Position Embedding (RoPE)""" def __init__(self, dim: int, max_position_embeddings: int = 2048, theta: float = 10000.0): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.theta = theta inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._set_cos_sin_cache(max_position_embeddings) def _set_cos_sin_cache(self, seq_len: int): self.max_seq_len_cached = seq_len t = torch.arange(seq_len, dtype=torch.float32) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len) return ( self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype) ) def rotate_half(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class GroupedQueryAttention(nn.Module): """Grouped Query Attention (GQA)""" def __init__(self, config: SmolLM2Config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.head_dim = config.head_dim self.num_kv_groups = self.num_heads // self.num_kv_heads self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = RotaryEmbedding( self.head_dim, max_position_embeddings=config.max_position_embeddings, theta=config.rope_theta ) def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, C = x.size() q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(x, T) q, k = apply_rotary_pos_emb(q, k, cos, sin) k = k.repeat_interleave(self.num_kv_groups, dim=1) v = v.repeat_interleave(self.num_kv_groups, dim=1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) y = y.transpose(1, 2).contiguous().view(B, T, self.hidden_size) y = self.o_proj(y) return y class SwiGLUMLP(nn.Module): """SwiGLU Feed-Forward Network""" def __init__(self, config: SmolLM2Config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class SmolLM2Block(nn.Module): """SmolLM2 Transformer Block""" def __init__(self, config: SmolLM2Config): super().__init__() self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.self_attn = GroupedQueryAttention(config) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = SwiGLUMLP(config) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.self_attn(self.input_layernorm(x)) x = x + self.mlp(self.post_attention_layernorm(x)) return x class SmolLM2(nn.Module): """SmolLM2-135M Model""" def __init__(self, config: SmolLM2Config): super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([SmolLM2Block(config) for _ in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: B, T = input_ids.size() x = self.embed_tokens(input_ids) for layer in self.layers: x = layer(x) x = self.norm(x) logits = self.lm_head(x) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss # ============================================================================ # Load Model # ============================================================================ device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load model config = SmolLM2Config() model = SmolLM2(config) # Load trained weights checkpoint = torch.load("smollm2_135m_final.pt", map_location=device, weights_only=False) model.load_state_dict(checkpoint['model_state_dict']) model.to(device) model.eval() print(f"Model loaded successfully! Parameters: {sum(p.numel() for p in model.parameters()):,}") # Load tokenizer tokenizer = tiktoken.get_encoding('gpt2') # ============================================================================ # Generation Function # ============================================================================ def generate_text( prompt: str, max_new_tokens: int = 100, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9, ) -> str: """Generate text from a prompt""" if not prompt.strip(): return "Please enter a prompt." model.eval() # Encode prompt tokens = tokenizer.encode(prompt) tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0) # Generate tokens with torch.no_grad(): for _ in range(max_new_tokens): # Crop to max position embeddings idx_cond = tokens[:, -config.max_position_embeddings:] # Get predictions logits, _ = model(idx_cond) logits = logits[:, -1, :] / temperature # Top-k filtering if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = float('-inf') # Top-p (nucleus) filtering if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = float('-inf') # Sample probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) tokens = torch.cat([tokens, next_token], dim=1) # Decode generated = tokenizer.decode(tokens[0].tolist()) return generated # ============================================================================ # Gradio Interface # ============================================================================ title = "SmolLM2-135M Text Generator" description = """ ## About This Model This is a **SmolLM2-135M** model trained from scratch on dialogue scripts from the TV series "Suits". ### Model Architecture - **Type**: Llama-based decoder-only transformer - **Parameters**: ~135M - **Features**: RMSNorm, RoPE, Grouped Query Attention (GQA), SwiGLU MLP ### Training Details - Trained for 5,050 steps - Sequence length: 1024 tokens - Uses GPT-2 tokenizer Enter a prompt below and adjust the generation parameters to see what the model generates! """ examples = [ ["Harvey walked into the office and said,"], ["The legal case was complicated because"], ["Once upon a time"], ["In a world where lawyers"], ["Mike looked at the contract and noticed"], ] # Create interface with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo: gr.Markdown(f"# {title}") gr.Markdown(description) with gr.Row(): with gr.Column(scale=2): prompt_input = gr.Textbox( label="Prompt", placeholder="Enter your prompt here...", lines=3, ) with gr.Row(): max_tokens_slider = gr.Slider( minimum=10, maximum=500, value=100, step=10, label="Max New Tokens", ) temperature_slider = gr.Slider( minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature", ) with gr.Row(): top_k_slider = gr.Slider( minimum=1, maximum=100, value=50, step=1, label="Top-K", ) top_p_slider = gr.Slider( minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus)", ) generate_btn = gr.Button("Generate", variant="primary") with gr.Column(scale=2): output_text = gr.Textbox( label="Generated Text", lines=15, ) gr.Markdown("### Example Prompts") gr.Examples( examples=examples, inputs=prompt_input, ) # Connect the generate button generate_btn.click( fn=generate_text, inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider], outputs=output_text, ) # Also generate on Enter key prompt_input.submit( fn=generate_text, inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider], outputs=output_text, ) gr.Markdown(""" --- ### Parameter Guide - **Temperature**: Higher = more creative/random, Lower = more focused/deterministic - **Top-K**: Only sample from the top K most likely tokens - **Top-P**: Only sample from tokens whose cumulative probability is below P - **Max New Tokens**: Maximum number of tokens to generate --- *Model trained from scratch using PyTorch. Architecture based on SmolLM2-135M (Llama-style).* """) if __name__ == "__main__": demo.launch()