MartialTerran committed
Commit 95b814a • 1 Parent(s): 0d7b558
Update SmolLM2_360M_model_debugging.py
Files changed: SmolLM2_360M_model_debugging.py (+26 -25)
SmolLM2_360M_model_debugging.py
CHANGED
@@ -48,12 +48,12 @@ import torch.nn.functional as F
 # --- Utility Functions ---
 
 def load_json(file_path: str) -> Dict:
-
+    """Load JSON data from a file."""
     with open(file_path, 'r', encoding='utf-8') as f:
         return json.load(f)
 
 def timed_step(start: float, step_name: str) -> float:
-
+    """Print time taken for a step and return new start time."""
     end = time.time()
     print(f"Time taken for {step_name}: {end - start:.4f} seconds")
     return end
@@ -61,23 +61,23 @@ def timed_step(start: float, step_name: str) -> float:
 # --- Model Architecture ---
 
 class RMSNorm(nn.Module):
-
+    """Root Mean Square Normalization."""
     def __init__(self, dim: int, eps: float = 1e-5):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        """Apply RMS normalization."""
         norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
         return self.weight * norm_x
 
 def silu(x: torch.Tensor) -> torch.Tensor:
-
+    """SiLU activation function."""
     return x * torch.sigmoid(x)
 
 class RotaryEmbedding(nn.Module):
-
+    """Rotary Positional Embedding."""
     def __init__(self, dim: int, base: int = 10000):
         super().__init__()
         self.dim = dim
@@ -85,23 +85,23 @@ class RotaryEmbedding(nn.Module):
         self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
 
     def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-
+        """Generate rotary embeddings for a given sequence length."""
         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
         freqs = torch.outer(t, self.inv_freq)
         return torch.cat((freqs, freqs), dim=-1)
 
 def apply_rotary_emb(pos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-
+    """Apply rotary embeddings to the given tensor."""
     return (t * torch.cos(pos)) + (rotate_half(t) * torch.sin(pos))
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-
+    """Rotate half of the tensor."""
     x1 = x[..., : x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
 class LlamaAttention(nn.Module):
-
+    """Multi-headed attention layer for LLaMA."""
     def __init__(self, config: Dict):
         super().__init__()
         self.config = config
@@ -121,7 +121,7 @@ class LlamaAttention(nn.Module):
         self.attn_dropout = nn.Dropout(config['attention_dropout'])
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-
+        """Compute multi-headed attention."""
 
         batch_size, seq_length, _ = hidden_states.size()
         query_states = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
@@ -170,7 +170,7 @@ class LlamaAttention(nn.Module):
         return attn_output, present_key_value
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-
+    """Repeat hidden states n_rep times for key/value heads."""
     #Stitch1
     batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
     if n_rep == 1:
@@ -179,7 +179,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
 
 class LlamaMLP(nn.Module):
-
+    """Multi-Layer Perceptron for LLaMA."""
     def __init__(self, config: Dict):
         super().__init__()
         hidden_size = config['hidden_size']
@@ -190,11 +190,11 @@ class LlamaMLP(nn.Module):
         self.act_fn = silu if config['hidden_act'] == 'silu' else getattr(F, config['hidden_act'])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        """Apply MLP to the input tensor."""
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 class LlamaBlock(nn.Module):
-
+    """LLaMA block containing attention and MLP layers."""
     def __init__(self, config: Dict):
         super().__init__()
         self.hidden_size = config['hidden_size']
@@ -204,7 +204,7 @@ class LlamaBlock(nn.Module):
         self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=config['rms_norm_eps'])
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-
+        """Apply the LLaMA block."""
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
@@ -216,7 +216,7 @@ class LlamaBlock(nn.Module):
         return hidden_states, present_key_value
 
 class SmolLM2_360M(nn.Module):
-
+    """SmolLM2-360M model implementation."""
     def __init__(self, config_path: str):
         super().__init__()
         self.config = load_json(config_path)
@@ -247,7 +247,7 @@ class SmolLM2_360M(nn.Module):
         self.past_keys_values = None
 
     def load_weights(self, weights_path: str):
-
+        """Load weights from a safetensors file."""
         start = time.time()
         try:
             from safetensors import safe_open
@@ -274,7 +274,7 @@ class SmolLM2_360M(nn.Module):
         end = timed_step(start, "Weight Loading")
 
     def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: Optional[bool] = None) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
-
+        """Forward pass of the model."""
         use_cache = use_cache if use_cache is not None else self.use_cache
         batch_size, seq_length = input_ids.shape
         if position_ids is None:
@@ -311,7 +311,7 @@ class SmolLM2_360M(nn.Module):
 # --- Tokenizer ---
 
 class SmolLM2Tokenizer:
-
+    """Tokenizer for SmolLM2-360M using SentencePiece or a rudimentary BPE."""
     def __init__(self, tokenizer_path: str = ".", special_tokens_map_path: str = ".", config_path: str = "."):
         self.tokenizer_path = tokenizer_path
         self.special_tokens_map_path = special_tokens_map_path
@@ -355,7 +355,7 @@ class SmolLM2Tokenizer:
         self.additional_special_tokens_ids = [self.token_to_id.get(token, -1) for token in self.additional_special_tokens]
 
     def update_special_tokens_from_sp(self):
-
+        """Update special token IDs from SentencePiece model, if present."""
         for token_name, token_data in self.special_tokens_map.items():
             sp_id = self.sp_model.piece_to_id(token_data['content'])
             if sp_id != self.sp_model.unk_id():
@@ -387,7 +387,7 @@ class SmolLM2Tokenizer:
 
 
     def bpe(self, token: str) -> List[str]:
-
+        """Rudimentary BPE tokenization."""
         if not self.use_sentencepiece:
             word = list(token)
             while len(word) > 1:
@@ -412,7 +412,7 @@ class SmolLM2Tokenizer:
         return [] # If SentencePiece is used, this function is not called.
 
     def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
-
+        """Encode text to token IDs."""
         if self.use_sentencepiece:
             if add_special_tokens:
                 return self.sp_model.encode(text, out_type=int) #add_bos=True, add_eos=True if needed, adjust as per model requirement
@@ -428,7 +428,7 @@ class SmolLM2Tokenizer:
         return token_ids
 
     def decode(self, token_ids: List[int]) -> str:
-
+        """Decode token IDs to text."""
         if self.use_sentencepiece:
             return self.sp_model.decode(token_ids)
         else:
@@ -439,7 +439,7 @@ class SmolLM2Tokenizer:
 # --- Inference ---
 
 def generate_text(model: SmolLM2_360M, tokenizer: SmolLM2Tokenizer, prompt: str, MAX_GENERATION_LENGTH: int = 100, device: torch.device = 'cpu') -> str:
-
+    """Generate text using greedy decoding."""
     input_ids = tokenizer.encode(prompt, add_special_tokens=True)
     input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
 
@@ -503,4 +503,5 @@ if __name__ == "__main__":
         generated_text = generate_text(model, tokenizer, user_input, MAX_GENERATION_LENGTH=MAX_GENERATION_LENGTH, device=device)
         print(f"Generated Text: {generated_text}")
         end = timed_step(start, "Prompt Generation")
+
 
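For orientation, below is a minimal usage sketch of the functions whose docstrings this commit adds. It is illustrative only and not part of the commit; the file names (config.json, model.safetensors) and tokenizer directory are assumptions about the repository layout, not taken from the diff.

# Illustrative sketch only; file paths below are assumed, not part of this commit.
import torch
from SmolLM2_360M_model_debugging import SmolLM2_360M, SmolLM2Tokenizer, generate_text

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from its JSON config, then load safetensors weights via load_weights().
model = SmolLM2_360M(config_path="config.json").to(device)
model.load_weights("model.safetensors")
model.eval()

# The tokenizer reads its tokenizer, special-tokens-map, and config files from the given directories.
tokenizer = SmolLM2Tokenizer(tokenizer_path=".", special_tokens_map_path=".", config_path=".")

# generate_text() encodes the prompt, runs greedy decoding, and returns the decoded string.
text = generate_text(model, tokenizer, "Hello, world!", MAX_GENERATION_LENGTH=50, device=device)
print(text)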