Ram07 committed on
Commit 22f2b3e · verified · 1 Parent(s): b0880a6

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,161 @@
+ ---
+ license: mit
+ language:
+ - en
+ pipeline_tag: text-generation
+ tags:
+ - bitnet
+ - quantization
+ - early-exit
+ - layer-skipping
+ - efficient-transformers
+ datasets:
+ - roneneldan/TinyStories
+ ---
+
+ # llama3-earlyexit
+
+ Llama3-style baseline with full precision weights and activations
+
+ ## Model Description
+
+ This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.
+
+ ## Architecture Details
+
+ - **Layers**: 24
+ - **Hidden dimension**: 2048
+ - **Attention heads**: 32 (64-dimensional each)
+ - **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio)
+ - **FFN intermediate size**: 4096
+ - **Position embeddings**: Rotary Position Embeddings (RoPE)
+ - **Normalization**: RMSNorm
+ - **Activation**: SwiGLU (for MLP)
+ - **Parameters**: ~1.06B
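The ~1.06B figure can be reproduced from the dimensions listed above. The sketch below is a back-of-the-envelope count, assuming bias-free linear layers, three MLP projections per layer (gate, up, down), two RMSNorm weight vectors per layer plus a final norm, and a 50,257-token GPT-2 vocabulary. Whether the LM head shares its weights with the embedding is treated as an open assumption, so both totals are computed. The function name `count_parameters` is illustrative.

```python
def count_parameters(
    vocab_size=50257,
    hidden=2048,
    layers=24,
    heads=32,
    kv_heads=8,
    ffn=4096,
):
    """Return (total with a separate LM head, total with a tied LM head)."""
    head_dim = hidden // heads                       # 64
    kv_dim = kv_heads * head_dim                     # 512 under grouped-query attention
    attention = 2 * hidden * hidden + 2 * hidden * kv_dim  # q, o plus k, v
    mlp = 3 * hidden * ffn                           # gate, up, down projections
    norms = 2 * hidden                               # two RMSNorm weight vectors
    per_layer = attention + mlp + norms
    embedding = vocab_size * hidden
    tied = layers * per_layer + hidden + embedding   # "+ hidden" is the final RMSNorm
    untied = tied + embedding                        # separate LM head matrix
    return untied, tied


untied_total, tied_total = count_parameters()
print(f"separate LM head: {untied_total / 1e9:.2f}B")
print(f"tied LM head:     {tied_total / 1e9:.2f}B")
```

Under these assumptions the count with a separate LM head comes to about 1.06B, which matches the figure in the list, while sharing the head with the embedding would give about 0.96B unique parameters.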
+
+ ### Quantization Scheme
+
+ - **Weights**: Full precision (FP32)
+ - **Activations**: Full precision (FP32)
+ - **Hadamard**: No
+
+ ## Training Details
+
+ ### Dataset
+ - **Source**: TinyStories (2.1M stories)
+ - **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
+ - **Sequence length**: 512 tokens
+
+ ### Training Techniques
+
+ **Quadratic Layer Dropout:**
+ - Progressive dropout: p_l = 0.5 × (l/(L-1))² for layer index l = 0, ..., L-1
+ - Normalized so Σp_l = 1.0
+ - Never drops final layer
+ - Drops later layers more often, which pushes earlier layers to be more accurate on their own
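As a rough illustration of this schedule, the following sketch computes the per-layer drop probabilities in plain Python. It assumes zero-based layer indices and the `(i / (L - 1))²` form used in `models/llama3_earlyexit.py` in this commit; the function name `quadratic_drop_probs` is illustrative and not part of any library.

```python
def quadratic_drop_probs(num_layers, max_dropout_prob=0.5):
    """Per-layer drop probabilities: quadratic in depth, normalized to sum to 1."""
    denom = max(num_layers - 1, 1)
    raw = [max_dropout_prob * (i / denom) ** 2 for i in range(num_layers)]
    total = sum(raw)
    # After normalization the max_dropout_prob factor cancels out.
    return [p / total for p in raw] if total > 0 else raw


probs = quadratic_drop_probs(24)
# The final layer is never dropped, so only layers 0..22 can be skipped.
expected_dropped = sum(probs[:-1])
print(f"layer 0 p = {probs[0]:.4f}, layer 22 p = {probs[22]:.4f}")
print(f"expected layers skipped per forward pass: {expected_dropped:.3f}")
```

One consequence of normalizing to a sum of 1 is that the `max_dropout_prob` setting no longer changes the resulting probabilities, and for 24 layers fewer than one layer is skipped on average per training forward pass.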
+
+ **Early Exit Loss:**
+ - All layers share the same LM head
+ - Loss = main_loss + 0.3 × early_exit_loss
+ - Layer-proportional weighting: w_i = (i+1)/L, normalized so the weights sum to 1
+ - Enables flexible early exit at inference
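To make the weighting concrete, the sketch below combines per-layer losses in the way these bullets describe, using plain floats in place of tensors. It assumes the weights `(i+1)/L` are normalized to sum to 1 before use, as `compute_early_exit_loss` in this commit does; `exit_weights`, `combined_loss`, and the sample loss values are made up for illustration.

```python
def exit_weights(num_layers):
    """Layer-proportional weights (i+1)/L, normalized to sum to 1."""
    raw = [(i + 1) / num_layers for i in range(num_layers)]
    total = sum(raw)
    return [w / total for w in raw]


def combined_loss(main_loss, layer_losses, early_exit_weight=0.3):
    """main_loss + early_exit_weight * weighted average of per-layer losses."""
    weights = exit_weights(len(layer_losses))
    early_exit_loss = sum(w * l for w, l in zip(weights, layer_losses))
    return main_loss + early_exit_weight * early_exit_loss


# Hypothetical per-layer losses that shrink with depth (illustrative values only).
layer_losses = [4.0 - 0.1 * i for i in range(24)]
print(combined_loss(main_loss=1.7, layer_losses=layer_losses))
```

Because deeper layers receive larger weights, the auxiliary term is dominated by exits near the top of the stack, while shallow exits still contribute a small gradient signal.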
+
+ ### Hyperparameters
+
+ - **Optimizer**: AdamW
+ - **Learning rate**: 6e-4
+ - **Warmup steps**: 1000
+ - **Batch size**: 16 (effective: 64)
+ - **Training steps**: 50000
+ - **Gradient clipping**: 1.0
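For a sense of scale, these settings imply a token budget that can be worked out directly. The sketch assumes every step processes 64 full sequences of 512 tokens, and that the gap between the batch size of 16 and the effective batch of 64 comes from a factor of 4 (gradient accumulation or data parallelism; the README does not say which).

```python
batch_size = 16
effective_batch_size = 64
sequence_length = 512
training_steps = 50_000

# Factor bridging the per-step batch and the effective batch
# (accumulation steps or device count, depending on the setup).
scale_factor = effective_batch_size // batch_size
tokens_per_step = effective_batch_size * sequence_length
total_tokens = tokens_per_step * training_steps

print(f"scale factor:    {scale_factor}")
print(f"tokens per step: {tokens_per_step:,}")
print(f"total tokens:    {total_tokens:,}")
```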
+
+ ## Performance
+
+ ### Perplexity (TinyStories validation)
+
+ | Exit Layer | Perplexity | Speed (tok/s) |
+ |------------|------------|---------------|
+ | All layers | TBD | TBD |
+ | Layer 18 | TBD | TBD |
+ | Layer 12 | TBD | TBD |
+ | Layer 6 | TBD | TBD |
+
+ ### Training Stability
+
+ - **Gradient norms**: TBD
+ - **Final loss**: TBD
+
+ ## Usage
+
+ ### Installation
+
+ ```bash
+ pip install transformers torch
+ ```
+
+ ### Basic Inference
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load model (trust_remote_code=True is required to load the custom model class)
+ model = AutoModelForCausalLM.from_pretrained("your-username/llama3-earlyexit", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("your-username/llama3-earlyexit")
+
+ # Generate text
+ inputs = tokenizer("Once upon a time", return_tensors="pt")
+ outputs = model.generate(**inputs, max_length=100)
+ print(tokenizer.decode(outputs[0]))
+ ```
+
+ ### Early Exit Inference
+
+ ```python
+ # Exit at layer 12 for faster inference
+ model.set_exit_layer(12)
+ outputs = model.generate(**inputs, max_length=100)
+ # Intended to be 1.5-2x faster with minimal quality loss (not yet benchmarked; see Performance)
+ ```
+
+ ### Benchmark Different Exit Layers
+
+ ```python
+ for exit_layer in [6, 12, 18, 24]:
+     model.set_exit_layer(exit_layer)
+     outputs = model.generate(**inputs, max_length=100)
+     print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")
+ ```
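The Speed (tok/s) column in the Performance table is still TBD. One way to fill it in is to time generation at each exit layer. The helper below is a sketch under assumptions: `generate_fn` is a hypothetical zero-argument callable that runs one generation and returns the number of new tokens, `fake_generate` is a stand-in for a real model call, and wall-clock timing with `time.perf_counter` is used without any GPU synchronization.

```python
import time


def tokens_per_second(generate_fn, warmup=1, repeats=3):
    """Average throughput of generate_fn, which must return the new-token count."""
    for _ in range(warmup):
        generate_fn()  # warm caches before timing
    total_tokens = 0
    start = time.perf_counter()
    for _ in range(repeats):
        total_tokens += generate_fn()
    elapsed = time.perf_counter() - start
    return total_tokens / elapsed


def fake_generate():
    """Stand-in for a real model call; pretends to produce 100 tokens."""
    time.sleep(0.01)
    return 100


print(f"{tokens_per_second(fake_generate):.0f} tok/s")
```

With the real model, `generate_fn` could wrap `model.generate(...)` after a call to `model.set_exit_layer(n)` and return `outputs.shape[1] - inputs["input_ids"].shape[1]`. On a GPU, the timing would likely also need a `torch.cuda.synchronize()` call before reading the clock.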
+
+ ## Limitations
+
+ - **Inference speed**: Quantized models use fake quantization (QAT) without specialized kernels, resulting in slower inference than full-precision despite lower bit-width
+ - **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning
+ - **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @article{bitnet,
+   title={BitNet: Scaling 1-bit Transformers for Large Language Models},
+   author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
+   journal={arXiv preprint arXiv:2310.11453},
+   year={2023}
+ }
+
+ @article{layerskip,
+   title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
+   author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
+   journal={arXiv preprint arXiv:2404.16710},
+   year={2024}
+ }
+ ```
+
+ ## License
+
+ MIT License
+
+ ## Contact
+
+ For questions or issues, please open an issue on the model repository.
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "architectures": [
+     "Llama3ForCausalLMWithEarlyExit"
+   ],
+   "auto_map": {
+     "AutoConfig": "llama3_earlyexit.Llama3EarlyExitConfig",
+     "AutoModelForCausalLM": "llama3_earlyexit.Llama3ForCausalLMWithEarlyExit"
+   },
+   "early_exit_loss_weight": 0.3,
+   "hidden_size": 2048,
+   "inference_exit_layer": null,
+   "intermediate_size": 4096,
+   "max_dropout_prob": 0.5,
+   "max_position_embeddings": 2048,
+   "model_type": "llama3_earlyexit",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 24,
+   "num_key_value_heads": 8,
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 10000.0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.45.2",
+   "vocab_size": 50257
+ }
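With `num_attention_heads` set to 32 and `num_key_value_heads` set to 8, each key/value head is shared by four query heads. A small sketch of that mapping follows, assuming consecutive query heads share a key/value head (the layout that `repeat_interleave` produces in the attention module of this commit); `kv_head_for` is an illustrative name.

```python
config = {"hidden_size": 2048, "num_attention_heads": 32, "num_key_value_heads": 8}

head_dim = config["hidden_size"] // config["num_attention_heads"]
group_size = config["num_attention_heads"] // config["num_key_value_heads"]


def kv_head_for(query_head):
    """Index of the key/value head that a given query head attends with."""
    return query_head // group_size


mapping = [kv_head_for(h) for h in range(config["num_attention_heads"])]
print(f"head_dim={head_dim}, query heads per KV head={group_size}")
print(mapping)
```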
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.45.2"
+ }
inference.py ADDED
@@ -0,0 +1,37 @@
+ """
+ Inference script for llama3-earlyexit
+ """
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ def main():
+     # Load from HuggingFace Hub or local path
+     model_path = "."  # Current directory or specify repo_id
+
+     print("Loading model...")
+     model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+     model.eval()
+     print("Model loaded!")
+
+     # Example generation
+     prompt = "Once upon a time"
+     inputs = tokenizer(prompt, return_tensors="pt")
+
+     print(f"\nPrompt: {prompt}\n")
+
+     # Full model
+     print("Generating with all layers...")
+     outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
+     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+     # Early exit at layer 12
+     print("\nGenerating with early exit at layer 12...")
+     model.set_exit_layer(12)
+     outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
+     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+ if __name__ == "__main__":
+     main()
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:432ab72d0a19d57c95ff3997fd6d5e5e43f51972ae05c298c784f0965bdd4093
+ size 3834689608
models/__init__.py ADDED
@@ -0,0 +1 @@
+ """Model files for llama3-earlyexit"""
models/llama3_earlyexit.py ADDED
@@ -0,0 +1,408 @@
+ """
+ Llama3 with Early Exit Loss and Quadratic Dropout
+ - Full precision (no quantization)
+ - Quadratic layer dropout (normalized sum=1)
+ - Early exit loss from all layers
+ - HuggingFace compatible baseline
+ """
+
+ import torch
+ import torch.nn as nn
+ import math
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+ class Llama3EarlyExitConfig(PretrainedConfig):
+     model_type = "llama3_earlyexit"
+
+     def __init__(
+         self,
+         vocab_size=50257,
+         hidden_size=2048,
+         num_hidden_layers=24,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         intermediate_size=4096,
+         max_position_embeddings=2048,
+         rms_norm_eps=1e-5,
+         rope_theta=10000.0,
+         early_exit_loss_weight=0.3,
+         max_dropout_prob=0.5,
+         inference_exit_layer=None,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.intermediate_size = intermediate_size
+         self.max_position_embeddings = max_position_embeddings
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+         self.early_exit_loss_weight = early_exit_loss_weight
+         self.max_dropout_prob = max_dropout_prob
+         self.inference_exit_layer = inference_exit_layer
+         super().__init__(**kwargs)
+
+
+ class QuadraticLayerDropout(nn.Module):
+     """Quadratic layer dropout normalized to sum=1."""
+
+     def __init__(self, num_layers, max_dropout_prob=0.5):
+         super().__init__()
+         self.num_layers = num_layers
+
+         dropout_probs = []
+         for i in range(num_layers):
+             prob = max_dropout_prob * ((i / max(num_layers - 1, 1)) ** 2)
+             dropout_probs.append(prob)
+
+         total_prob = sum(dropout_probs)
+         if total_prob > 0:
+             dropout_probs = [p / total_prob for p in dropout_probs]
+
+         self.dropout_probs = dropout_probs
+
+     def should_drop_layer(self, layer_idx):
+         if not self.training or layer_idx >= self.num_layers - 1:
+             return False
+         return torch.rand(1).item() < self.dropout_probs[layer_idx]
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+         position_ids_expanded = position_ids[:, None, :].float()
+         freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+
+ def rotate_half(x):
+     x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class Llama3Attention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+         self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
+
+     def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         cos, sin = self.rotary_emb(value_states, position_ids)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
+         value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, None, past_key_value
+
+
+ class Llama3MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+     def forward(self, x):
+         return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Llama3DecoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self_attn = Llama3Attention(config)
+         self.mlp = Llama3MLP(config)
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, _, present_key_value = self.self_attn(
+             hidden_states, attention_mask, position_ids, past_key_value, use_cache
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return (hidden_states,) + ((present_key_value,) if use_cache else ())
+
+
+ class Llama3PreTrainedModel(PreTrainedModel):
+     config_class = Llama3EarlyExitConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+
+
+ class Llama3Model(Llama3PreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([Llama3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.gradient_checkpointing = False
+         self.layer_dropout = QuadraticLayerDropout(config.num_hidden_layers, config.max_dropout_prob)
+         self.post_init()
+
+     def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, output_hidden_states=False, return_all_layer_outputs=False):
+         hidden_states = self.embed_tokens(input_ids)
+
+         if position_ids is None:
+             position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+             position_ids = position_ids.unsqueeze(0)
+
+         next_decoder_cache = () if use_cache else None
+         all_layer_hidden_states = []
+
+         num_layers_to_run = self.config.inference_exit_layer if self.config.inference_exit_layer else len(self.layers)
+         num_layers_to_run = min(num_layers_to_run, len(self.layers))
+
+         for idx in range(num_layers_to_run):
+             layer = self.layers[idx]
+             past_key_value = past_key_values[idx] if past_key_values else None
+
+             if self.training and self.layer_dropout.should_drop_layer(idx):
+                 all_layer_hidden_states.append(hidden_states)
+                 continue
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     past_key_value,
+                     use_cache,
+                 )
+             else:
+                 layer_outputs = layer(hidden_states, attention_mask, position_ids, past_key_value, use_cache)
+
+             hidden_states = layer_outputs[0]
+             all_layer_hidden_states.append(hidden_states)
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+         all_layer_hidden_states.append(hidden_states)
+
+         if return_all_layer_outputs:
+             return hidden_states, next_decoder_cache, all_layer_hidden_states
+         else:
+             return hidden_states, next_decoder_cache, None
+
+
+ class Llama3ForCausalLMWithEarlyExit(Llama3PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = Llama3Model(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def compute_early_exit_loss(self, all_layer_hidden_states, labels):
+         """Compute early exit loss with layer-proportional weighting."""
+         num_layers = len(all_layer_hidden_states)
+
+         weights = [(i + 1) / num_layers for i in range(num_layers)]
+         weight_sum = sum(weights)
+         weights = [w / weight_sum for w in weights]
+
+         total_exit_loss = 0.0
+
+         for i, hidden_states in enumerate(all_layer_hidden_states):
+             logits = self.lm_head(hidden_states)
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             loss_fct = nn.CrossEntropyLoss()
+             layer_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
+
+             total_exit_loss += weights[i] * layer_loss
+
+         return total_exit_loss
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         position_ids=None,
+         past_key_values=None,
+         inputs_embeds=None,
+         labels=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         return_all = self.training and labels is not None
+
+         hidden_states, past_key_values_output, all_layer_hidden_states = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_hidden_states=output_hidden_states,
+             return_all_layer_outputs=return_all,
+         )
+
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             main_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
+
+             if all_layer_hidden_states is not None and len(all_layer_hidden_states) > 0:
+                 early_exit_loss = self.compute_early_exit_loss(all_layer_hidden_states[:-1], labels)
+                 loss = main_loss + self.config.early_exit_loss_weight * early_exit_loss
+             else:
+                 loss = main_loss
+
+         if not return_dict:
+             output = (logits,) + (past_key_values_output,)
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=past_key_values_output,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
+         if past_key_values is not None:
+             past_length = past_key_values[0][0].shape[2]
+             if input_ids.shape[1] > past_length:
+                 remove_prefix_length = past_length
+             else:
+                 remove_prefix_length = input_ids.shape[1] - 1
+             input_ids = input_ids[:, remove_prefix_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update({
+             "position_ids": position_ids,
+             "past_key_values": past_key_values,
+             "use_cache": kwargs.get("use_cache"),
+             "attention_mask": attention_mask,
+         })
+         return model_inputs
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (
+                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+             )
+         return reordered_past
+
+     def set_exit_layer(self, exit_layer):
+         self.config.inference_exit_layer = exit_layer
+         self.model.config.inference_exit_layer = exit_layer
+
+
+ Llama3EarlyExitConfig.register_for_auto_class()
+ Llama3ForCausalLMWithEarlyExit.register_for_auto_class("AutoModelForCausalLM")
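`Llama3Attention.forward` in this file adds `attention_mask` to the raw scores as-is, and nothing in the file constructs a causal mask, so it appears the caller is expected to supply an additive one. The following is a plain-Python sketch (no torch, so it runs anywhere) of what such a mask looks like and how it removes attention to future positions; `causal_mask` and `masked_softmax` are illustrative names and not part of the commit.

```python
import math


def causal_mask(seq_len):
    """Additive mask: 0 where key position <= query position, -inf otherwise."""
    return [[0.0 if k <= q else -math.inf for k in range(seq_len)]
            for q in range(seq_len)]


def masked_softmax(scores, mask):
    """Row-wise softmax of scores + mask."""
    out = []
    for score_row, mask_row in zip(scores, mask):
        shifted = [s + m for s, m in zip(score_row, mask_row)]
        peak = max(shifted)  # finite, because the diagonal is never masked
        exps = [math.exp(v - peak) for v in shifted]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


scores = [[0.0] * 4 for _ in range(4)]  # uniform scores for illustration
for row in masked_softmax(scores, causal_mask(4)):
    print([round(w, 3) for w in row])
```

In a torch setting the same idea would typically be expressed as a tensor that broadcasts against scores of shape `(batch, heads, q_len, kv_len)`.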
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "50256": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1024,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff