MartialTerran committed on
Commit 95b814a
1 Parent(s): 0d7b558

Update SmolLM2_360M_model_debugging.py

Files changed (1)
  1. SmolLM2_360M_model_debugging.py +26 -25
SmolLM2_360M_model_debugging.py CHANGED
@@ -48,12 +48,12 @@ import torch.nn.functional as F
 # --- Utility Functions ---

 def load_json(file_path: str) -> Dict:
-    """Load JSON data from a file."""
+    ###Load JSON data from a file.###
     with open(file_path, 'r', encoding='utf-8') as f:
         return json.load(f)

 def timed_step(start: float, step_name: str) -> float:
-    """Print time taken for a step and return new start time."""
+    ###Print time taken for a step and return new start time.###
     end = time.time()
     print(f"Time taken for {step_name}: {end - start:.4f} seconds")
     return end
@@ -61,23 +61,23 @@ def timed_step(start: float, step_name: str) -> float:
 # --- Model Architecture ---

 class RMSNorm(nn.Module):
-    """Root Mean Square Normalization."""
+    ###Root Mean Square Normalization.###
     def __init__(self, dim: int, eps: float = 1e-5):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Apply RMS normalization."""
+        ###Apply RMS normalization.###
         norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
         return self.weight * norm_x

 def silu(x: torch.Tensor) -> torch.Tensor:
-    """SiLU activation function."""
+    ###SiLU activation function.###
     return x * torch.sigmoid(x)

 class RotaryEmbedding(nn.Module):
-    """Rotary Positional Embedding."""
+    ###Rotary Positional Embedding.###
     def __init__(self, dim: int, base: int = 10000):
         super().__init__()
         self.dim = dim
@@ -85,23 +85,23 @@ class RotaryEmbedding(nn.Module):
         self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))

     def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Generate rotary embeddings for a given sequence length."""
+        ###Generate rotary embeddings for a given sequence length.###
         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
         freqs = torch.outer(t, self.inv_freq)
         return torch.cat((freqs, freqs), dim=-1)

 def apply_rotary_emb(pos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-    """Apply rotary embeddings to the given tensor."""
+    ###Apply rotary embeddings to the given tensor.###
     return (t * torch.cos(pos)) + (rotate_half(t) * torch.sin(pos))

 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    """Rotate half of the tensor."""
+    ###Rotate half of the tensor.###
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)

 class LlamaAttention(nn.Module):
-    """Multi-headed attention layer for LLaMA."""
+    ###Multi-headed attention layer for LLaMA.###
     def __init__(self, config: Dict):
         super().__init__()
         self.config = config
@@ -121,7 +121,7 @@ class LlamaAttention(nn.Module):
         self.attn_dropout = nn.Dropout(config['attention_dropout'])

     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Compute multi-headed attention."""
+        ###Compute multi-headed attention.###

         batch_size, seq_length, _ = hidden_states.size()
         query_states = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
@@ -170,7 +170,7 @@ class LlamaAttention(nn.Module):
         return attn_output, present_key_value

 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """Repeat hidden states n_rep times for key/value heads."""
+    ###Repeat hidden states n_rep times for key/value heads.###
     #Stitch1
     batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
     if n_rep == 1:
@@ -179,7 +179,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

 class LlamaMLP(nn.Module):
-    """Multi-Layer Perceptron for LLaMA."""
+    ###Multi-Layer Perceptron for LLaMA.###
     def __init__(self, config: Dict):
         super().__init__()
         hidden_size = config['hidden_size']
@@ -190,11 +190,11 @@ class LlamaMLP(nn.Module):
         self.act_fn = silu if config['hidden_act'] == 'silu' else getattr(F, config['hidden_act'])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Apply MLP to the input tensor."""
+        ###Apply MLP to the input tensor.###
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

 class LlamaBlock(nn.Module):
-    """LLaMA block containing attention and MLP layers."""
+    ###LLaMA block containing attention and MLP layers.###
     def __init__(self, config: Dict):
         super().__init__()
         self.hidden_size = config['hidden_size']
@@ -204,7 +204,7 @@ class LlamaBlock(nn.Module):
         self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=config['rms_norm_eps'])

     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Apply the LLaMA block."""
+        ###Apply the LLaMA block.###
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
@@ -216,7 +216,7 @@ class LlamaBlock(nn.Module):
         return hidden_states, present_key_value

 class SmolLM2_360M(nn.Module):
-    """SmolLM2-360M model implementation."""
+    ###SmolLM2-360M model implementation.###
     def __init__(self, config_path: str):
         super().__init__()
         self.config = load_json(config_path)
@@ -247,7 +247,7 @@ class SmolLM2_360M(nn.Module):
         self.past_keys_values = None

     def load_weights(self, weights_path: str):
-        """Load weights from a safetensors file."""
+        ###Load weights from a safetensors file.###
         start = time.time()
         try:
             from safetensors import safe_open
@@ -274,7 +274,7 @@ class SmolLM2_360M(nn.Module):
         end = timed_step(start, "Weight Loading")

     def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: Optional[bool] = None) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
-        """Forward pass of the model."""
+        ###Forward pass of the model.###
         use_cache = use_cache if use_cache is not None else self.use_cache
         batch_size, seq_length = input_ids.shape
         if position_ids is None:
@@ -311,7 +311,7 @@ class SmolLM2_360M(nn.Module):
 # --- Tokenizer ---

 class SmolLM2Tokenizer:
-    """Tokenizer for SmolLM2-360M using SentencePiece or a rudimentary BPE."""
+    ###Tokenizer for SmolLM2-360M using SentencePiece or a rudimentary BPE.###
     def __init__(self, tokenizer_path: str = ".", special_tokens_map_path: str = ".", config_path: str = "."):
         self.tokenizer_path = tokenizer_path
         self.special_tokens_map_path = special_tokens_map_path
@@ -355,7 +355,7 @@ class SmolLM2Tokenizer:
         self.additional_special_tokens_ids = [self.token_to_id.get(token, -1) for token in self.additional_special_tokens]

     def update_special_tokens_from_sp(self):
-        """Update special token IDs from SentencePiece model, if present."""
+        ###Update special token IDs from SentencePiece model, if present.###
         for token_name, token_data in self.special_tokens_map.items():
             sp_id = self.sp_model.piece_to_id(token_data['content'])
             if sp_id != self.sp_model.unk_id():
@@ -387,7 +387,7 @@ class SmolLM2Tokenizer:


     def bpe(self, token: str) -> List[str]:
-        """Rudimentary BPE tokenization."""
+        ###Rudimentary BPE tokenization.###
         if not self.use_sentencepiece:
             word = list(token)
             while len(word) > 1:
@@ -412,7 +412,7 @@ class SmolLM2Tokenizer:
         return [] # If SentencePiece is used, this function is not called.

     def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
-        """Encode text to token IDs."""
+        ###Encode text to token IDs.###
         if self.use_sentencepiece:
             if add_special_tokens:
                 return self.sp_model.encode(text, out_type=int) #add_bos=True, add_eos=True if needed, adjust as per model requirement
@@ -428,7 +428,7 @@ class SmolLM2Tokenizer:
         return token_ids

     def decode(self, token_ids: List[int]) -> str:
-        """Decode token IDs to text."""
+        ###Decode token IDs to text.###
         if self.use_sentencepiece:
             return self.sp_model.decode(token_ids)
         else:
@@ -439,7 +439,7 @@ class SmolLM2Tokenizer:
 # --- Inference ---

 def generate_text(model: SmolLM2_360M, tokenizer: SmolLM2Tokenizer, prompt: str, MAX_GENERATION_LENGTH: int = 100, device: torch.device = 'cpu') -> str:
-    """Generate text using greedy decoding."""
+    ###Generate text using greedy decoding.###
     input_ids = tokenizer.encode(prompt, add_special_tokens=True)
     input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

@@ -503,4 +503,5 @@ if __name__ == "__main__":
         generated_text = generate_text(model, tokenizer, user_input, MAX_GENERATION_LENGTH=MAX_GENERATION_LENGTH, device=device)
         print(f"Generated Text: {generated_text}")
         end = timed_step(start, "Prompt Generation")
+

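
For quick reference, below is a minimal standalone sketch of three of the helpers touched in this diff (RMSNorm, rotate_half, repeat_kv), with toy shape checks. The expand step of repeat_kv is not visible in the hunks above, so the version here follows the common Llama-style implementation, and the tensor sizes are illustrative assumptions, not values read from the model's config.json.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Root Mean Square normalization, mirroring the class shown in the diff.
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each vector by the reciprocal of its root-mean-square.
        norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * norm_x

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension, negating the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat grouped key/value heads so they line up with the query heads.
    # (The expand step is assumed here; it is not shown in the hunks above.)
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

# Toy shape checks (batch=1, seq=8, 15 query heads, 5 KV heads, head_dim=64 -- illustrative only).
x = torch.randn(1, 8, 960)
q = torch.randn(1, 15, 8, 64)
kv = torch.randn(1, 5, 8, 64)
assert RMSNorm(960)(x).shape == x.shape
assert rotate_half(q).shape == q.shape
assert repeat_kv(kv, n_rep=3).shape == (1, 15, 8, 64)

All three asserts should pass with these toy shapes; in the script itself the real head counts and hidden size come from the config.json loaded by SmolLM2_360M.__init__ via load_json.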