appledora committed
Commit 7a1d06b
1 Parent(s): 3f11215

Upload 6 files

__init__.py ADDED
@@ -0,0 +1,31 @@
+ from transformers.utils import (
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_torch_available,
+ )
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     pass
+ else:
+     from .modeling_recastmlp_llama import (
+         RECASTMLP_llamaModel,
+         RECASTMLP_LlamaForCausalLM,
+     )
+     from .configuration_recastmlp_llama import RECASTMLP_llama
+
+     from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+     # Register your models with Auto classes
+     AutoConfig.register("recastmlp_llama", RECASTMLP_llama)
+     AutoModel.register(RECASTMLP_llama, RECASTMLP_llamaModel)
+     AutoModelForCausalLM.register(RECASTMLP_llama, RECASTMLP_LlamaForCausalLM)
+
+ _import_structure = {
+     "configuration_recastmlp_llama": ["RECASTMLP_llama"],
+     "modeling_recastmlp_llama": ["RECASTMLP_llamaModel", "RECASTMLP_LlamaForCausalLM"],
+ }
+
+ __all__ = ["RECASTMLP_llamaModel", "RECASTMLP_LlamaForCausalLM", "RECASTMLP_llama"]
config.json ADDED
@@ -0,0 +1,97 @@
+ {
+   "vocab_size": 128256,
+   "max_position_embeddings": 131072,
+   "hidden_size": 4096,
+   "intermediate_size": 14336,
+   "num_hidden_layers": 32,
+   "num_attention_heads": 32,
+   "num_key_value_heads": 8,
+   "hidden_act": "silu",
+   "initializer_range": 0.02,
+   "rms_norm_eps": 1e-05,
+   "pretraining_tp": 1,
+   "use_cache": true,
+   "mlp_bias": false,
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "rope_theta": 500000.0,
+   "rope_scaling": {
+     "factor": 8.0,
+     "low_freq_factor": 1.0,
+     "high_freq_factor": 4.0,
+     "original_max_position_embeddings": 8192,
+     "rope_type": "llama3"
+   },
+   "torch_dtype": null,
+   "num_templates": 4,
+   "num_groups": 8,
+   "num_cf": 1,
+   "return_dict": true,
+   "output_hidden_states": false,
+   "output_attentions": false,
+   "torchscript": false,
+   "use_bfloat16": false,
+   "tf_legacy_loss": false,
+   "pruned_heads": {},
+   "tie_word_embeddings": false,
+   "chunk_size_feed_forward": 0,
+   "is_encoder_decoder": false,
+   "is_decoder": false,
+   "cross_attention_hidden_size": null,
+   "add_cross_attention": false,
+   "tie_encoder_decoder": false,
+   "max_length": 20,
+   "min_length": 0,
+   "do_sample": false,
+   "early_stopping": false,
+   "num_beams": 1,
+   "num_beam_groups": 1,
+   "diversity_penalty": 0.0,
+   "temperature": 1.0,
+   "top_k": 50,
+   "top_p": 1.0,
+   "typical_p": 1.0,
+   "repetition_penalty": 1.0,
+   "length_penalty": 1.0,
+   "no_repeat_ngram_size": 0,
+   "encoder_no_repeat_ngram_size": 0,
+   "bad_words_ids": null,
+   "num_return_sequences": 1,
+   "output_scores": false,
+   "return_dict_in_generate": false,
+   "forced_bos_token_id": null,
+   "forced_eos_token_id": null,
+   "remove_invalid_values": false,
+   "exponential_decay_length_penalty": null,
+   "suppress_tokens": null,
+   "begin_suppress_tokens": null,
+   "architectures": [
+     "RECASTMLP_LlamaForCausalLM"
+   ],
+   "finetuning_task": null,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1
+   },
+   "tokenizer_class": null,
+   "prefix": null,
+   "bos_token_id": 128000,
+   "pad_token_id": null,
+   "eos_token_id": 128001,
+   "sep_token_id": null,
+   "decoder_start_token_id": null,
+   "task_specific_params": null,
+   "problem_type": null,
+   "_name_or_path": "",
+   "transformers_version": "4.36.0",
+   "model_type": "recastmlp_llama",
+   "auto_map": {
+     "AutoConfig": "configuration_recastmlp_llama.RECASTMLP_llama",
+     "AutoModel": "modeling_recastmlp_llama.RECASTMLP_llamaModel",
+     "AutoModelForCausalLM": "modeling_recastmlp_llama.RECASTMLP_LlamaForCausalLM"
+   }
+ }
configuration_recastmlp_llama.py ADDED
@@ -0,0 +1,77 @@
+ from transformers import PretrainedConfig
+
+
+ class RECASTMLP_llama(PretrainedConfig):
+     model_type = "recastmlp_llama"
+     attribute_map = {
+         "hidden_size": "hidden_size",
+         "num_attention_heads": "num_attention_heads",
+     }
+
+     def __init__(
+         self,
+         vocab_size=128256,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         hidden_act="silu",
+         max_position_embeddings=131072,
+         initializer_range=0.02,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=128000,
+         eos_token_id=128001,
+         pretraining_tp=1,
+         tie_word_embeddings=False,
+         rope_theta=500000.0,
+         rope_scaling={
+             "factor": 8.0,
+             "low_freq_factor": 1.0,
+             "high_freq_factor": 4.0,
+             "original_max_position_embeddings": 8192,
+             "rope_type": "llama3",
+         },
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         # Template-specific configs
+         num_templates=4,
+         num_groups=8,
+         num_cf=1,
+         torch_dtype="bfloat16",
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.pretraining_tp = pretraining_tp
+         self.use_cache = use_cache
+         self.mlp_bias = mlp_bias
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.torch_dtype = torch_dtype
+
+         # Template-specific configs
+         self.num_templates = num_templates
+         self.num_groups = num_groups
+         self.num_cf = num_cf
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs
+         )
metadata.json ADDED
@@ -0,0 +1 @@
+ {"library_name": "transformers", "model_type": "recastmlp_llama", "architectures": ["RECASTMLP_llamaModel"]}
model_card.md ADDED
@@ -0,0 +1,60 @@
+ ---
+ language: en
+ tags:
+ - llama
+ - template-mlp
+ - parameter-efficient
+ - mlp-modification
+ datasets:
+ - none
+ license: apache-2.0
+ pipeline_tag: text-generation
+ library_name: transformers
+ ---
+
+ # RECASTMLP-LLaMA
+
+ This model implements a parameter-efficient modification of the LLaMA architecture by replacing the standard MLP layers with template-based shared MLPs. The model maintains LLaMA's attention mechanism while reducing the number of parameters in the feed-forward networks.
+
+ ## Model Description
+
+ ### Overview
+ RECASTMLP-LLaMA modifies the original LLaMA architecture by introducing template banks for the MLP layers. Instead of having separate MLP weights for each transformer layer, it uses a shared set of template weights that are combined using learned coefficients.
+
+ ### Architecture Details
+ - **Base Model:** LLaMA 3.1 8B
+ - **Number of Templates:** 4
+ - **Number of Groups:** 8
+ - **Coefficients per Template:** 1
+ - **Coefficients:** 392
+ - **Hidden Size:** 4096
+ - **Intermediate Size:** 14336
+ - **Number of Attention Heads:** 32
+ - **Number of Key-Value Heads:** 8
+ - **Number of Layers:** 32
+ - **Max Position Embeddings:** 131072
+ - **Vocabulary Size:** 128256
+
+
+ ### Key Features
+ 1. **Template Banks:** Uses shared template weights across groups of layers
+ 2. **Parameter Efficiency:** Reduces the total number of parameters by sharing MLP weights
+ 3. **Group-wise Sharing:** Organizes layers into groups that share template banks
+ 4. **Coefficient Learning:** Uses learned coefficients to combine template weights (see the sketch below)
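+
+ The sketch below illustrates the weight-generation step. It is a simplified, illustrative version of `MLPTemplateBank` / `SharedLlamaMLP` in `modeling_recastmlp_llama.py`; the toy sizes and variable names here are not part of the actual implementation.
+
+ ```python
+ import torch
+
+ # Toy sizes for illustration only; the released model uses 4 templates,
+ # hidden_size=4096 and intermediate_size=14336.
+ num_templates, hidden_size, intermediate_size = 4, 8, 16
+
+ # One shared bank of gate-projection templates for a group of layers
+ gate_templates = torch.randn(num_templates, intermediate_size, hidden_size)
+
+ # Each layer owns a small coefficient tensor of shape (num_templates, 1, 1)
+ layer_coeffs = torch.randn(num_templates, 1, 1)
+
+ # The layer's effective gate weight is the coefficient-weighted sum of the templates
+ gate_weight = (gate_templates * layer_coeffs).sum(0)  # (intermediate_size, hidden_size)
+
+ # The same pattern is repeated for the up and down projections, and the result is
+ # used like an ordinary linear weight, e.g. torch.nn.functional.linear(x, gate_weight)
+ ```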
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+
+ # Load model and tokenizer
+ model = AutoModel.from_pretrained("appledora/RECASTMLP-llama3.1-f8t4", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8b")
+
+ # Prepare input
+ text = "Hello, how are you?"
+ inputs = tokenizer(text, return_tensors="pt")
+
+ # Forward pass to get hidden states
+ outputs = model(**inputs)
+ hidden_states = outputs.last_hidden_state
+ ```
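+
+ Because `config.json` maps `AutoModelForCausalLM` to `RECASTMLP_LlamaForCausalLM`, text generation should also work through the standard API. The following is a minimal sketch, assuming the hub checkpoint loads correctly through that mapping (generation support comes from `GenerationMixin`):
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "appledora/RECASTMLP-llama3.1-f8t4", trust_remote_code=True
+ )
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8b")
+
+ inputs = tokenizer("Hello, how are you?", return_tensors="pt")
+ output_ids = model.generate(**inputs, max_new_tokens=32)
+ print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+ ```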
modeling_recastmlp_llama.py ADDED
@@ -0,0 +1,738 @@
+ # filename: recastmlp_llama_model.py
+ from .configuration_recastmlp_llama import RECASTMLP_llama
+ from transformers import PreTrainedModel
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple, Union, List
+ from transformers import AutoConfig
+ from transformers.utils import logging
+ from transformers.cache_utils import Cache, StaticCache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ logger = logging.get_logger(__name__)
+
+
+ class MLPTemplateBank(nn.Module):
+     def __init__(self, config, num_templates):
+         """
+         Initialize template bank for MLP layers
+         Args:
+             config: LlamaConfig instance
+             num_templates: Number of templates in bank
+         """
+         super().__init__()
+         self.num_templates = config.num_templates
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Create templates for gate, up and down projections
+         self.gate_templates = nn.Parameter(
+             torch.stack(
+                 [
+                     torch.empty(self.intermediate_size, self.hidden_size)
+                     for _ in range(self.num_templates)
+                 ]
+             )
+         )
+
+         self.up_templates = nn.Parameter(
+             torch.stack(
+                 [
+                     torch.empty(self.intermediate_size, self.hidden_size)
+                     for _ in range(self.num_templates)
+                 ]
+             )
+         )
+
+         self.down_templates = nn.Parameter(
+             torch.stack(
+                 [
+                     torch.empty(self.hidden_size, self.intermediate_size)
+                     for _ in range(self.num_templates)
+                 ]
+             )
+         )
+
+         # Initialize templates
+         for i in range(self.num_templates):
+             nn.init.kaiming_normal_(self.gate_templates[i])
+             nn.init.kaiming_normal_(self.up_templates[i])
+             nn.init.kaiming_normal_(self.down_templates[i])
+
+         self.coefficient_shape = (self.num_templates, 1, 1)
+
+     def forward(self, gate_coeffs, up_coeffs, down_coeffs):
+         """Generate weights from coefficients"""
+         gate_weights = (self.gate_templates * gate_coeffs).sum(0)
+         up_weights = (self.up_templates * up_coeffs).sum(0)
+         down_weights = (self.down_templates * down_coeffs).sum(0)
+         return gate_weights, up_weights, down_weights
+
+     def __repr__(self):
+         return f"MLPTemplateBank(num_templates={self.num_templates}, hidden_size={self.hidden_size}, intermediate_size={self.intermediate_size})"
+
+
+ class SharedLlamaMLP(nn.Module):
+     def __init__(self, config, bank):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.bank = bank
+         num_cf = config.num_cf
+
+         # Coefficients for template bank
+         self.gate_coefficients = nn.ParameterList(
+             [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
+         )
+         self.up_coefficients = nn.ParameterList(
+             [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
+         )
+         self.down_coefficients = nn.ParameterList(
+             [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
+         )
+
+         # Initialize coefficients
+         for cf in self.gate_coefficients:
+             nn.init.orthogonal_(cf)
+         for cf in self.up_coefficients:
+             nn.init.orthogonal_(cf)
+         for cf in self.down_coefficients:
+             nn.init.orthogonal_(cf)
+
+         # Biases
+         self.gate_bias = (
+             nn.Parameter(torch.zeros(self.intermediate_size))
+             if config.mlp_bias
+             else None
+         )
+         self.up_bias = (
+             nn.Parameter(torch.zeros(self.intermediate_size))
+             if config.mlp_bias
+             else None
+         )
+         self.down_bias = (
+             nn.Parameter(torch.zeros(self.hidden_size)) if config.mlp_bias else None
+         )
+
+         # Activation
+         # self.act_fn = nn.functional.__dict__[config.hidden_act]
+         # self.act_fn = keras.activations.swish
+         self.act_fn = F.silu
+
+     def forward(self, x):
+         # Generate weights using coefficients
+         gate_weights = []
+         up_weights = []
+         down_weights = []
+
+         for i in range(len(self.gate_coefficients)):
+             gate, up, down = self.bank(
+                 self.gate_coefficients[i],
+                 self.up_coefficients[i],
+                 self.down_coefficients[i],
+             )
+             gate_weights.append(gate)
+             up_weights.append(up)
+             down_weights.append(down)
+
+         gate_weights = torch.stack(gate_weights).mean(0)
+         up_weights = torch.stack(up_weights).mean(0)
+         down_weights = torch.stack(down_weights).mean(0)
+
+         # Apply MLP operations
+         gate_output = F.linear(x, gate_weights, self.gate_bias)
+         up_output = F.linear(x, up_weights, self.up_bias)
+
+         # Apply activation and down projection
+         hidden_states = self.act_fn(gate_output) * up_output
+         output = F.linear(hidden_states, down_weights, self.down_bias)
+
+         return output
+
+     def __repr__(self):
+         return (
+             f"SharedLlamaMLP(hidden_size={self.hidden_size}, "
+             f"intermediate_size={self.intermediate_size}, "
+             f"gate_coefficients={len(self.gate_coefficients)}, "
+             f"up_coefficients={len(self.up_coefficients)}, "
+             f"down_coefficients={len(self.down_coefficients)})"
+         )
+
+
+ def fixed_cross_entropy(
+     source,
+     target,
+     num_items_in_batch: int = None,
+     ignore_index: int = -100,
+     **kwargs,
+ ):
+     reduction = "sum" if num_items_in_batch is not None else "mean"
+     loss = nn.functional.cross_entropy(
+         source, target, ignore_index=ignore_index, reduction=reduction
+     )
+     if reduction == "sum":
+         loss = loss / num_items_in_batch
+     return loss
+
+
+ from transformers.models.llama.modeling_llama import (
+     LlamaDecoderLayer,
+     LlamaRotaryEmbedding,
+     LlamaRMSNorm,
+     apply_rotary_pos_emb,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+
+
+ class RECASTMLP_llamaModel(PreTrainedModel):
+     config_class = RECASTMLP_llama
+     base_model_prefix = "llama"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size, config.hidden_size, self.padding_idx
+         )
+         # Initialize rotary embeddings
+         rope_config = config.rope_scaling
+         if rope_config:
+             rope_type = rope_config.get("rope_type", "default")
+             scaling_factor = rope_config.get("factor", 1.0)
+         else:
+             rope_type = "default"
+             scaling_factor = None
+         original_config = AutoConfig.from_pretrained(
+             "meta-llama/Llama-3.1-8b", trust_remote_code=True
+         )
+         self.rotary_emb = LlamaRotaryEmbedding(
+             config=original_config,
+         )
+
+         # Create template banks first
+         self.banks = []
+         layers_per_group = config.num_hidden_layers // config.num_groups
+         for _ in range(config.num_groups):
+             bank = MLPTemplateBank(config, config.num_templates)
+             self.banks.append(bank)
+
+         # Create layers using LlamaDecoderLayer but replace MLPs
+         self.layers = nn.ModuleList()
+         for layer_idx in range(config.num_hidden_layers):
+             # Create standard LlamaDecoderLayer
+             decoder_layer = LlamaDecoderLayer(config, layer_idx)
+
+             # Replace its MLP with our SharedLlamaMLP
+             group_idx = layer_idx // layers_per_group
+             group_bank = self.banks[group_idx]
+             decoder_layer.mlp = SharedLlamaMLP(config, bank=group_bank)
+
+             self.layers.append(decoder_layer)
+
+         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **flash_attn_kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds"
+             )
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         # Create position embeddings to be shared across the decoder layers
+         if position_ids is None:
+             past_seen_tokens = (
+                 past_key_values.get_seq_length() if past_key_values is not None else 0
+             )
+             position_ids = torch.arange(
+                 past_seen_tokens,
+                 past_seen_tokens + inputs_embeds.shape[1],
+                 device=inputs_embeds.device,
+             ).unsqueeze(0)
+
+         position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
+         hidden_states = inputs_embeds
+
+         # Get updated causal mask
+         causal_mask = self._update_causal_mask(
+             attention_mask,
+             inputs_embeds,
+             cache_position,
+             past_key_values,
+             output_attentions,
+         )
+
+         # Initialize outputs
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         # Process through layers
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     position_embeddings=position_embeddings,
+                     **flash_attn_kwargs,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         # Final layer norm
+         hidden_states = self.norm(hidden_states)
+
+         # Add last hidden state
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         if isinstance(
+             pretrained_model_name_or_path, str
+         ) and pretrained_model_name_or_path.endswith(".pt"):
+             print("Loading from local checkpoint")
+             # Load from local checkpoint
+             config = kwargs.get("config", None)
+             if config is None:
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=True
+                 )
+
+             model = cls(config)
+             checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
+             state_dict = checkpoint["model_state_dict"]
+             logger.info(
+                 f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
+             )
+
+             missing_keys, unexpected_keys = model.load_state_dict(
+                 state_dict, strict=False
+             )
+
+             if len(missing_keys) > 0:
+                 logger.warning(f"Missing keys: {missing_keys}")
+             if len(unexpected_keys) > 0:
+                 logger.warning(f"Unexpected keys: {unexpected_keys}")
+
+             return model
+         else:
+             print("Loading from hub")
+             # Load from hub using parent's from_pretrained
+             return super().from_pretrained(
+                 pretrained_model_name_or_path, *model_args, **kwargs
+             )
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = (
+             past_key_values.get_seq_length() if past_key_values is not None else 0
+         )
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if (
+             self.config._attn_implementation == "sdpa"
+             and not using_static_cache
+             and not output_attentions
+         ):
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(
+                 causal_mask, min_dtype
+             )
+
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = (
+                     causal_mask.clone()
+                 )  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+
+         return causal_mask
+
+
+ class RECASTMLP_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     config_class = RECASTMLP_llama
+     base_model_prefix = "llama"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = RECASTMLP_llamaModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def loss_function(
+         self,
+         logits,
+         labels,
+         vocab_size: int,
+         num_items_in_batch: int = None,
+         ignore_index: int = -100,
+         **kwargs,
+     ):
+         # Upcast to float if we need to compute the loss to avoid potential precision issues
+         logits = logits.float()
+         # Shift so that tokens < n predict n
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         # Flatten the tokens
+         shift_logits = shift_logits.view(-1, vocab_size)
+         shift_labels = shift_labels.view(-1)
+         # Enable model parallelism
+         shift_labels = shift_labels.to(shift_logits.device)
+         loss = fixed_cross_entropy(
+             shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
+         )
+         return loss
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+         **kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         """
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should be in
+                 `[0, ..., config.vocab_size]` or -100 (masked tokens).
+             num_logits_to_keep (`int`, *optional*):
+                 Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+         # Only compute necessary logits
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+         loss = None
+         if labels is not None:
+             # Calculate batch size for loss function
+             num_items_in_batch = (
+                 input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
+             )
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 vocab_size=self.config.vocab_size,
+                 num_items_in_batch=num_items_in_batch,
+                 **kwargs,
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         **kwargs,
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         if isinstance(
+             pretrained_model_name_or_path, str
+         ) and pretrained_model_name_or_path.endswith(".pt"):
+             print("Loading from local checkpoint")
+             config = kwargs.get("config", None)
+             if config is None:
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=True
+                 )
+
+             model = cls(config)
+             checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
+             state_dict = checkpoint["model_state_dict"]
+
+             missing_keys, unexpected_keys = model.load_state_dict(
+                 state_dict, strict=False
+             )
+
+             if len(missing_keys) > 0:
+                 logger.warning(f"Missing keys: {missing_keys}")
+             if len(unexpected_keys) > 0:
+                 logger.warning(f"Unexpected keys: {unexpected_keys}")
+
+             return model
+         else:
+             print("Loading from hub")
+             return super().from_pretrained(
+                 pretrained_model_name_or_path, *model_args, **kwargs
+             )