muverqqw committed on
Commit b708ac5 · 1 Parent(s): 4468908

Update modeling_alinlight.py

Files changed (1)
  1. modeling_alinlight.py +11 -14
modeling_alinlight.py CHANGED
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# -*- coding: utf-8 -*-
+# Copyright 2026 EngineerGL Research.
+
 import math
 import torch
 import torch.nn as nn
@@ -20,7 +23,6 @@ import torch.nn.functional as F
 from typing import Optional, Tuple, List, Union
 from transformers import PreTrainedModel, GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
-
 from configuration_alinlight import AlinlightConfig

 class AlinlightRMSNorm(nn.Module):
@@ -44,11 +46,8 @@ class AlinlightRotaryEmbedding(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.scaling_factor = scaling_factor

-        # We calculate frequencies immediately upon initialization
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Initialize the cache
         self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())

     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -60,10 +59,8 @@ class AlinlightRotaryEmbedding(nn.Module):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
-        # If the length is greater than the cache, we recalculate (a rare case, but needed for reliability)
         if seq_len > self.cos_cached.shape[0]:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
         return (
             self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
             self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device)
@@ -139,14 +136,13 @@ class AlinlightAttention(nn.Module):
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)

-        # === TPU OPTIMIZATION: Physical cropping instead of mask ===
+        # Truncation logic (TPU Optimization)
         if self.sliding_window is not None and key_states.shape[2] > self.sliding_window:
             key_states = key_states[:, :, -self.sliding_window:, :]
             value_states = value_states[:, :, -self.sliding_window:, :]

         past_key_value = (key_states, value_states) if use_cache else None

-        # GQA / MQA processing
         if self.num_key_value_groups > 1:
             key_states = key_states[:, :, None, :, :].expand(
                 bsz, self.num_key_value_heads, self.num_key_value_groups, key_states.shape[-2], self.head_dim
@@ -156,14 +152,18 @@ class AlinlightAttention(nn.Module):
                 bsz, self.num_key_value_heads, self.num_key_value_groups, value_states.shape[-2], self.head_dim
             ).reshape(bsz, self.num_heads, value_states.shape[-2], self.head_dim)

-        # SDPA (Flash Attention backend compatible)
+        # FIX: dynamic is_causal flag
+        # If q_len > 1 (training or prefill) -> True (mask out the future)
+        # If q_len == 1 (generation) -> False (see all of the past held in key_states)
+        is_causal = q_len > 1
+
         attn_output = F.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=None,
             dropout_p=0.0,
-            is_causal=True
+            is_causal=is_causal
         )

         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
@@ -189,7 +189,6 @@ class AlinlightDecoderLayer(nn.Module):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
         return hidden_states, None, present_key_value

 class AlinlightModel(PreTrainedModel):
@@ -221,7 +220,6 @@ class AlinlightModel(PreTrainedModel):
         else:
             inputs_embeds = kwargs.get("inputs_embeds")

-        # Length calculation for RoPE
         seq_len = inputs_embeds.shape[1]
         if past_key_values is not None:
             seq_len += past_key_values[0][0].shape[2]
@@ -264,7 +262,7 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = AlinlightModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.lm_head.weight = self.model.embed_tokens.weight  # Tie the weights (weight tying)
+        self.lm_head.weight = self.model.embed_tokens.weight
         self.post_init()

     def get_input_embeddings(self): return self.model.embed_tokens
@@ -273,7 +271,6 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
     def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings

     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        # Optimization for generation: if there is a cache, we serve only the last token
         if past_key_values:
             input_ids = input_ids[:, -1:]

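A few notes on the hunks above; the sketches below are standalone illustrations, not part of the commit.

Sliding-window truncation (hunk @@ -139,14 +136,13 @@): instead of masking out positions older than the window, the cache tensors themselves are cropped to the last sliding_window key/value positions, which keeps the cached shapes bounded. A minimal sketch with made-up shapes:

import torch

sliding_window = 4
key_states = torch.randn(1, 2, 10, 8)    # (batch, kv_heads, cached_len=10, head_dim)
value_states = torch.randn(1, 2, 10, 8)

# Physically crop the cache to the most recent `sliding_window` positions,
# mirroring the check in AlinlightAttention.forward.
if key_states.shape[2] > sliding_window:
    key_states = key_states[:, :, -sliding_window:, :]
    value_states = value_states[:, :, -sliding_window:, :]

print(key_states.shape)   # torch.Size([1, 2, 4, 8])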
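GQA / MQA expansion (same hunk): the expand(...).reshape(...) pattern replicates each key/value head across its query group before attention. Under the illustrative shapes below it matches torch.repeat_interleave along the head dimension:

import torch

bsz, num_kv_heads, groups, seq_len, head_dim = 2, 2, 4, 5, 8   # num_heads = num_kv_heads * groups
kv = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# Insert a group axis, broadcast it, then fold it into the head axis.
repeated = kv[:, :, None, :, :].expand(bsz, num_kv_heads, groups, seq_len, head_dim)
repeated = repeated.reshape(bsz, num_kv_heads * groups, seq_len, head_dim)

print(torch.equal(repeated, kv.repeat_interleave(groups, dim=1)))   # True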
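Dynamic is_causal (hunk @@ -156,14 +152,18 @@): F.scaled_dot_product_attention aligns its built-in causal mask to the top-left of the (q_len, kv_len) score matrix, so with a KV cache and a single-token query the old hard-coded is_causal=True let the query see only the first cached position. A minimal sketch of the difference, again with made-up shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 1, 64)   # one decoded token: (batch, heads, q_len=1, head_dim)
k = torch.randn(1, 4, 8, 64)   # eight positions already in the KV cache
v = torch.randn(1, 4, 8, 64)

masked = F.scaled_dot_product_attention(q, k, v, is_causal=True)    # old path: only key position 0 is visible
full   = F.scaled_dot_product_attention(q, k, v, is_causal=False)   # new decode path: all cached positions visible

print(torch.allclose(masked, full))   # False: the static flag was silently dropping cached context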
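Weight tying (hunk @@ -264,7 +262,7 @@): assigning lm_head.weight = embed_tokens.weight makes the output projection and the input embedding share a single (vocab_size, hidden_size) parameter. A tiny standalone sketch:

import torch.nn as nn

vocab_size, hidden_size = 100, 16
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Both modules now point at the same Parameter, so training updates one shared tensor.
lm_head.weight = embed_tokens.weight

print(lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr())   # True: shared storage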