rtaAILabs committed
Commit f5c80ef · verified · Parent: 8f0707c

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
chat_template.jinja ADDED
@@ -0,0 +1,3 @@
+ {% for message in messages %}{% if loop.first %}<|im_start|><|system|>You are Nandi-Mini, a helpful, concise, and accurate AI assistant by Rta AI Labs that provides clear answers, asks for clarification when needed, and avoids harmful or incorrect information.<|endoftext|>
+ {% endif %}{% if message['role'] == 'user' %}<|user|>{{ message['content'] }}<|endoftext|>
+ <|assistant|>{% endif %}{% if message['role'] == 'assistant' %}{{ message['content'] }}<|endoftext|>{% endif %}{% endfor %}
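For reference, a minimal sketch of how this template renders a conversation through `apply_chat_template`. It assumes the tokenizer files from this commit are available under the `Rta-AILabs/Nandi-150M-remote` repo id used elsewhere in this upload (a local path with the same files works too):

```python
from transformers import AutoTokenizer

# Assumption: the tokenizer.json / tokenizer_config.json added in this commit
# are published under this repo id (taken from the NandiConfig docstring).
tokenizer = AutoTokenizer.from_pretrained("Rta-AILabs/Nandi-150M-remote")

messages = [{"role": "user", "content": "What is Nandi-Mini?"}]

# The Jinja template above emits the system preamble once (loop.first), wraps the
# user turn in <|user|> ... <|endoftext|>, and leaves a trailing <|assistant|>
# tag for the model to complete.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)
```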
config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "architectures": [
+     "NandiForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_nandi.NandiConfig",
+     "AutoModel": "modeling_nandi.NandiModel",
+     "AutoModelForCausalLM": "modeling_nandi.NandiForCausalLM"
+   },
+   "bos_token_id": 1,
+   "dtype": "bfloat16",
+   "embedding_rank": 196,
+   "eos_token_id": 0,
+   "factorized_embedding": true,
+   "head_dim": 52,
+   "hidden_act": "silu",
+   "hidden_size": 832,
+   "initializer_range": 0.02,
+   "intermediate_size": 2496,
+   "layer_sharing": true,
+   "layer_sharing_repeats": 2,
+   "max_position_embeddings": 2048,
+   "mlp_bias": false,
+   "model_type": "nandi",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 16,
+   "num_key_value_heads": 4,
+   "pad_token_id": 0,
+   "pretraining_tp": 1,
+   "rms_norm_eps": 1e-05,
+   "rope_parameters": {
+     "rope_theta": 100000,
+     "rope_type": "default"
+   },
+   "tie_word_embeddings": true,
+   "transformers_version": "5.4.0",
+   "use_cache": false,
+   "vocab_size": 131072
+ }
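The `factorized_embedding` / `embedding_rank` fields describe a low-rank token embedding: tokens are looked up at rank 196 and projected up to the 832-dimensional hidden size (see `embed_tokens` and `embedding_proj` in modeling_nandi.py below), which keeps the 131,072-entry vocabulary cheap. A back-of-the-envelope check of the parameter counts implied by the values above (an illustrative sketch, not a figure taken from the repo):

```python
vocab_size, hidden_size, embedding_rank = 131072, 832, 196

full = vocab_size * hidden_size                                    # unfactorized embedding table
factorized = vocab_size * embedding_rank + embedding_rank * hidden_size

print(f"full embedding:       {full:,} params")        # 109,051,904
print(f"factorized embedding: {factorized:,} params")  # 25,853,184
```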
configuration_nandi.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright 2026 RTA AI Labs. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class NandiConfig(PretrainedConfig):
+     r"""
+     Configuration class for the Nandi model.
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoConfig, AutoModelForCausalLM
+
+     >>> configuration = AutoConfig.from_pretrained("Rta-AILabs/Nandi-150M-remote", trust_remote_code=True)
+
+     >>> model = AutoModelForCausalLM.from_pretrained("Rta-AILabs/Nandi-150M-remote", trust_remote_code=True)
+
+     >>> configuration = model.config
+     ```
+     """
+
+     model_type = "nandi"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+
+     def __init__(
+         self,
+         vocab_size=131072,
+         hidden_size=832,
+         intermediate_size=2496,
+         num_hidden_layers=16,
+         num_attention_heads=16,
+         num_key_value_heads=4,
+         head_dim=None,
+         hidden_act="silu",
+         max_position_embeddings=2048,
+         initializer_range=0.008,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=1,
+         eos_token_id=0,
+         pretraining_tp=1,
+         tie_word_embeddings=True,
+         rope_parameters=None,
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         factorized_embedding=True,
+         embedding_rank=196,
+         layer_sharing=True,
+         layer_sharing_repeats=2,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+         self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.pretraining_tp = pretraining_tp
+         self.rope_parameters = rope_parameters if rope_parameters is not None else {"rope_theta": 100000.0}
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.mlp_bias = mlp_bias
+         self.factorized_embedding = factorized_embedding
+         self.embedding_rank = embedding_rank
+         self.layer_sharing = layer_sharing
+         self.layer_sharing_repeats = layer_sharing_repeats if layer_sharing else 1
+
+         if self.factorized_embedding and self.embedding_rank <= 0:
+             raise ValueError(
+                 f"`embedding_rank` must be positive when `factorized_embedding=True`, got {self.embedding_rank}."
+             )
+         if self.hidden_size % self.num_attention_heads != 0:
+             raise ValueError(
+                 f"`hidden_size` ({self.hidden_size}) must be divisible by "
+                 f"`num_attention_heads` ({self.num_attention_heads})."
+             )
+         if self.layer_sharing_repeats < 1:
+             raise ValueError(f"`layer_sharing_repeats` must be >= 1, got {self.layer_sharing_repeats}.")
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["NandiConfig"]
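A quick illustration of the layer-sharing fields (hypothetical usage, mirroring the defaults above): with `num_hidden_layers=16` and `layer_sharing_repeats=2`, each decoder layer is applied twice in a row per forward pass, as implemented in modeling_nandi.py below, giving 32 layer applications from 16 sets of weights.

```python
# Assumes configuration_nandi.py from this commit is importable from the working directory.
from configuration_nandi import NandiConfig

config = NandiConfig()  # most defaults mirror the shipped config.json
effective_depth = config.num_hidden_layers * config.layer_sharing_repeats
print(config.num_hidden_layers, config.layer_sharing_repeats, effective_depth)  # 16 2 32
```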
generation_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": [
+     0
+   ],
+   "pad_token_id": 0,
+   "transformers_version": "5.4.0",
+   "use_cache": true
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4524b11c4106416720bd24fbb58950a3fc279f89be85d3e47eb57a45bcc0b6a9
+ size 306842392
modeling_nandi.py ADDED
@@ -0,0 +1,480 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/nandi/modular_nandi.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_nandi.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections.abc import Callable
+
+ import torch
+ import torch.nn as nn
+
+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache, DynamicLayer
+ from transformers.generation import GenerationMixin
+ from transformers.integrations import use_kernel_forward_from_hub
+ from transformers.masking_utils import create_causal_mask
+ from transformers.modeling_layers import GradientCheckpointingLayer
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, auto_docstring
+ from transformers.utils.deprecation import deprecate_kwarg
+ from transformers.utils.generic import can_return_tuple, merge_with_config_defaults
+ from transformers.utils.output_capturing import capture_outputs
+ from .configuration_nandi import NandiConfig
+
+
+ @use_kernel_forward_from_hub("RMSNorm")
+ class NandiRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+     def extra_repr(self):
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ class NandiRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor
+
+     def __init__(self, config: NandiConfig, device=None):
+         super().__init__()
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_type = self.config.rope_parameters.get("rope_type", "default")
+         rope_init_fn: Callable = self.compute_default_rope_parameters
+         if self.rope_type != "default":
+             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+     @staticmethod
+     def compute_default_rope_parameters(
+         config: NandiConfig | None = None,
+         device: torch.device | None = None,
+         seq_len: int | None = None,
+     ) -> tuple[torch.Tensor, float]:
+         del seq_len
+         base = config.rope_parameters["rope_theta"]
+         dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+         attention_factor = 1.0
+         inv_freq = 1.0 / (
+             base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+         )
+         return inv_freq, attention_factor
+
+     @torch.no_grad()
+     @dynamic_rope_update
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     del position_ids
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     del kwargs
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ class NandiAttention(nn.Module):
+     def __init__(self, config: NandiConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = config.head_dim
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = nn.Linear(
+             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+         )
+
+     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: torch.Tensor | None,
+         past_key_values: Cache | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_values is not None:
+             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+         attention_interface: Callable = eager_attention_forward
+         if self.config._attn_implementation != "eager":
+             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
+
+
+ class NandiMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
+         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
+         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class NandiDecoderLayer(GradientCheckpointingLayer):
+     def __init__(self, config: NandiConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = NandiAttention(config=config, layer_idx=layer_idx)
+         self.mlp = NandiMLP(config)
+         self.input_layernorm = NandiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = NandiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         use_cache: bool | None = False,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         hidden_states, _ = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class _VirtualLayerCache:
+     """Proxy that shifts cache layer indices by `offset` to give each repeat its own virtual slots."""
+
+     def __init__(self, cache: Cache, offset: int):
+         self._cache = cache
+         self._offset = offset
+
+     def __getattr__(self, name):
+         return getattr(self._cache, name)
+
+     def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
+         virtual_idx = layer_idx + self._offset
+         # grow the backing cache if generate() pre-allocated fewer slots than needed
+         while len(self._cache.layers) <= virtual_idx:
+             self._cache.layers.append(DynamicLayer())
+         return self._cache.update(key_states, value_states, virtual_idx, cache_kwargs)
+
+     def get_seq_length(self, layer_idx: int = 0) -> int:
+         return self._cache.get_seq_length(layer_idx + self._offset)
+
+
+ @auto_docstring
+ class NandiPreTrainedModel(PreTrainedModel):
+     config: NandiConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["NandiDecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn = True
+     _supports_sdpa = True
+     _supports_flex_attn = True
+     _can_compile_fullgraph = True
+     _supports_attention_backend = True
+     _can_record_outputs = {
+         "hidden_states": NandiDecoderLayer,
+         "attentions": NandiAttention,
+     }
+
+     def __init__(self, config: NandiConfig):
+         super().__init__(config)
+
+
+ @auto_docstring
+ class NandiModel(NandiPreTrainedModel):
+     def __init__(self, config: NandiConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         embedding_dim = config.embedding_rank if config.factorized_embedding else config.hidden_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, embedding_dim, self.padding_idx)
+         self.embedding_proj = (
+             nn.Linear(config.embedding_rank, config.hidden_size, bias=False) if config.factorized_embedding else None
+         )
+         self.layers = nn.ModuleList(
+             [NandiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = NandiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = NandiRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         self.post_init()
+
+     @merge_with_config_defaults
+     @capture_outputs
+     @auto_docstring
+     def forward(
+         self,
+         input_ids: torch.LongTensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         inputs_embeds: torch.FloatTensor | None = None,
+         use_cache: bool | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> BaseModelOutputWithPast:
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if self.embedding_proj is not None:
+             inputs_embeds = self.embedding_proj(inputs_embeds)
+
+         repeats = self.config.layer_sharing_repeats if self.config.layer_sharing else 1
+
+         if use_cache and past_key_values is None:
+             # Use lazy DynamicCache (no config) so it grows to accommodate
+             # num_hidden_layers * repeats virtual slots for layer-sharing.
+             past_key_values = DynamicCache()
+
+         if position_ids is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+             position_ids = position_ids.unsqueeze(0)
+
+         causal_mask = create_causal_mask(
+             config=self.config,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             position_ids=position_ids,
+         )
+
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             for repeat_idx in range(repeats):
+                 # Each repeat gets its own virtual cache slots offset by num_hidden_layers,
+                 # so repeat 0 uses slots 0..N-1 and repeat 1 uses slots N..2N-1, etc.
+                 repeat_cache = (
+                     _VirtualLayerCache(past_key_values, repeat_idx * self.config.num_hidden_layers)
+                     if (past_key_values is not None and repeat_idx > 0)
+                     else past_key_values
+                 )
+                 hidden_states = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_embeddings=position_embeddings,
+                     position_ids=position_ids,
+                     past_key_values=repeat_cache,
+                     use_cache=use_cache,
+                     **kwargs,
+                 )
+
+         hidden_states = self.norm(hidden_states)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values,
+         )
+
+
+ @auto_docstring
+ class NandiForCausalLM(NandiPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+     _tp_plan = {"lm_head": "colwise_gather_output"}
+     _pp_plan = {
+         "lm_head_proj": (["hidden_states"], ["hidden_states"]),
+         "lm_head": (["hidden_states"], ["logits"]),
+     }
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = NandiModel(config)
+         self.vocab_size = config.vocab_size
+
+         lm_head_in_features = config.embedding_rank if config.factorized_embedding else config.hidden_size
+         self.lm_head_proj = (
+             nn.Linear(config.hidden_size, config.embedding_rank, bias=False) if config.factorized_embedding else None
+         )
+         self.lm_head = nn.Linear(lm_head_in_features, config.vocab_size, bias=False)
+
+         self.post_init()
+
+     @can_return_tuple
+     @auto_docstring
+     def forward(
+         self,
+         input_ids: torch.LongTensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         inputs_embeds: torch.FloatTensor | None = None,
+         labels: torch.LongTensor | None = None,
+         use_cache: bool | None = None,
+         logits_to_keep: int | torch.Tensor = 0,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> CausalLMOutputWithPast:
+         outputs: BaseModelOutputWithPast = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             **kwargs,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         if self.lm_head_proj is not None:
+             hidden_states = self.lm_head_proj(hidden_states)
+
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ __all__ = ["NandiPreTrainedModel", "NandiModel", "NandiForCausalLM"]
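A minimal end-to-end sketch of running this checkpoint with the custom code above. It assumes the commit is published under the `Rta-AILabs/Nandi-150M-remote` repo id referenced in the config docstring; loading goes through `configuration_nandi.py` / `modeling_nandi.py`, so `trust_remote_code=True` is required:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: this commit is published as "Rta-AILabs/Nandi-150M-remote"
# (the repo id used in the NandiConfig docstring); adjust if it lives elsewhere.
repo_id = "Rta-AILabs/Nandi-150M-remote"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

messages = [{"role": "user", "content": "Give me one sentence about the Himalayas."}]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")

# With layer_sharing_repeats=2, each of the 16 decoder layers runs twice per step,
# and the DynamicCache grows to 16 * 2 = 32 virtual layer slots during generation.
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=64)

print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```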
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9fd2911e5e02cb959f6a77a1ebd4bba088d4ec2e0bc0a208b3c1e0ca2278791
+ size 12460626
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "backend": "tokenizers",
+   "bos_token": "<|im_start|>",
+   "eos_token": "<|endoftext|>",
+   "is_local": true,
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "TokenizersBackend",
+   "unk_token": "<|endoftext|>"
+ }
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff