muverqqw committed on
Commit 4468908 · 1 Parent(s): 411b49d

FIX ATTENTION

Files changed (1)
  1. modeling_alinlight.py +77 -16
modeling_alinlight.py CHANGED
@@ -18,16 +18,17 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple, List, Union
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
-from transformers import GenerationMixin
-from configuration_alinlight import AlinlightConfig  # Import the config from the neighboring file
+
+from configuration_alinlight import AlinlightConfig
 
 class AlinlightRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.eps = eps
+
     def forward(self, x):
         input_dtype = x.dtype
         x = x.to(torch.float32)
@@ -42,8 +43,12 @@ class AlinlightRotaryEmbedding(nn.Module):
         self.base = base
         self.max_position_embeddings = max_position_embeddings
         self.scaling_factor = scaling_factor
+
+        # We calculate frequencies immediately upon initialization
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Initialize the cache
         self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -55,9 +60,14 @@ class AlinlightRotaryEmbedding(nn.Module):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
     def forward(self, x, seq_len=None):
+        # If the length is greater than the cache, we recalculate (a rare case, but needed for reliability)
         if seq_len > self.cos_cached.shape[0]:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-        return self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
+            self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device)
+        )
 
 def rotate_half(x):
     x1 = x[..., : x.shape[-1] // 2]
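For orientation only: the (cos, sin) pair returned by AlinlightRotaryEmbedding.forward is normally consumed with the rotate-half trick shown below. This helper is an illustrative sketch, not part of this diff; it assumes cos/sin caches of shape [seq_len, head_dim] (as built above) and position_ids of shape [bsz, seq_len].

# Illustrative sketch, not from this commit: standard rotate-half application of RoPE
# to query/key states, reusing the rotate_half defined in the hunk above.
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed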
@@ -95,18 +105,28 @@ class AlinlightAttention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.sliding_window = config.sliding_window
-        self.attention_dropout = config.attention_dropout
-
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
-    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, cos_sin=None):
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_ids=None,
+        past_key_value=None,
+        output_attentions=False,
+        use_cache=False,
+        cos_sin=None
+    ):
         bsz, q_len, _ = hidden_states.size()
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -119,19 +139,33 @@ class AlinlightAttention(nn.Module):
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        # Truncation logic for sliding window
+        # === TPU OPTIMIZATION: Physical cropping instead of mask ===
         if self.sliding_window is not None and key_states.shape[2] > self.sliding_window:
             key_states = key_states[:, :, -self.sliding_window:, :]
             value_states = value_states[:, :, -self.sliding_window:, :]
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # GQA / MQA processing
         if self.num_key_value_groups > 1:
-            key_states = key_states[:, :, None, :, :].expand(bsz, self.num_key_value_heads, self.num_key_value_groups, key_states.shape[-2], self.head_dim).reshape(bsz, self.num_heads, key_states.shape[-2], self.head_dim)
-            value_states = value_states[:, :, None, :, :].expand(bsz, self.num_key_value_heads, self.num_key_value_groups, value_states.shape[-2], self.head_dim).reshape(bsz, self.num_heads, value_states.shape[-2], self.head_dim)
+            key_states = key_states[:, :, None, :, :].expand(
+                bsz, self.num_key_value_heads, self.num_key_value_groups, key_states.shape[-2], self.head_dim
+            ).reshape(bsz, self.num_heads, key_states.shape[-2], self.head_dim)
+
+            value_states = value_states[:, :, None, :, :].expand(
+                bsz, self.num_key_value_heads, self.num_key_value_groups, value_states.shape[-2], self.head_dim
+            ).reshape(bsz, self.num_heads, value_states.shape[-2], self.head_dim)
 
-        # Use Scaled Dot Product Attention
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True)
+        # SDPA (Flash Attention backend compatible)
+        attn_output = F.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=True
+        )
+
         attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
         return self.o_proj(attn_output), None, past_key_value
 
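On the GQA block above: expand creates a view that repeats each key/value head across its query group, and the reshape then materializes the full head count. A standalone check (the shapes below are arbitrary illustration values, not the model's config) that this matches torch.repeat_interleave:

import torch

# Arbitrary illustration shapes: 4 KV heads, each shared by a group of 3 query heads.
bsz, n_kv_heads, n_groups, seq, head_dim = 2, 4, 3, 5, 8
k = torch.randn(bsz, n_kv_heads, seq, head_dim)

expanded = k[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, seq, head_dim)
expanded = expanded.reshape(bsz, n_kv_heads * n_groups, seq, head_dim)

# Same result as repeating each KV head n_groups times along the head axis.
assert torch.equal(expanded, torch.repeat_interleave(k, n_groups, dim=1))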
@@ -146,12 +180,16 @@ class AlinlightDecoderLayer(nn.Module):
     def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, cos_sin=None):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, _, present_key_value = self.self_attn(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cos_sin)
+        hidden_states, _, present_key_value = self.self_attn(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cos_sin
+        )
         hidden_states = residual + hidden_states
+
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
+
         return hidden_states, None, present_key_value
 
 class AlinlightModel(PreTrainedModel):
@@ -166,7 +204,16 @@ class AlinlightModel(PreTrainedModel):
         if config.rope_scaling and config.rope_scaling.get("type") == "linear":
             scaling_factor = config.rope_scaling.get("factor", 1.0)
 
-        self.rotary_emb = AlinlightRotaryEmbedding(config.hidden_size // config.num_attention_heads, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, scaling_factor=scaling_factor)
+        self.rotary_emb = AlinlightRotaryEmbedding(
+            config.hidden_size // config.num_attention_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor
+        )
+        self.post_init()
+
+    def get_input_embeddings(self): return self.embed_tokens
+    def set_input_embeddings(self, value): self.embed_tokens = value
 
     def forward(self, input_ids=None, past_key_values=None, use_cache=None, **kwargs):
         if input_ids is not None:
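For reference, the rope_scaling branch above only recognizes linear scaling. The snippet below is an illustration of the dict shape it reads; the values are placeholders, not taken from this repo's configuration.

# Placeholder values; only the keys "type" and "factor" are read by the branch above.
rope_scaling = {"type": "linear", "factor": 2.0}
scaling_factor = rope_scaling.get("factor", 1.0) if rope_scaling.get("type") == "linear" else 1.0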
@@ -174,6 +221,7 @@ class AlinlightModel(PreTrainedModel):
         else:
             inputs_embeds = kwargs.get("inputs_embeds")
 
+        # Length calculation for RoPE
         seq_len = inputs_embeds.shape[1]
         if past_key_values is not None:
             seq_len += past_key_values[0][0].shape[2]
@@ -190,7 +238,13 @@ class AlinlightModel(PreTrainedModel):
 
         for idx, layer in enumerate(self.layers):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
-            layer_outputs = layer(hidden_states, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, cos_sin=(cos, sin))
+            layer_outputs = layer(
+                hidden_states,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                cos_sin=(cos, sin)
+            )
             hidden_states = layer_outputs[0]
             if use_cache:
                 next_decoder_cache += (layer_outputs[2],)
@@ -210,9 +264,16 @@ class AlinlightForCausalLM(PreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.model = AlinlightModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.lm_head.weight = self.model.embed_tokens.weight
+        self.lm_head.weight = self.model.embed_tokens.weight  # Tie the weights (weight tying)
+        self.post_init()
+
+    def get_input_embeddings(self): return self.model.embed_tokens
+    def set_input_embeddings(self, value): self.model.embed_tokens = value
+    def get_output_embeddings(self): return self.lm_head
+    def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        # Optimization for generation: when a cache is present, feed only the last token
        if past_key_values:
            input_ids = input_ids[:, -1:]
 
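End-to-end smoke test, sketched under assumptions: the repo id below is a placeholder, and loading relies on trust_remote_code picking up modeling_alinlight.py / configuration_alinlight.py from the repository. generate() exercises prepare_inputs_for_generation and the attention path changed in this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "user/alinlight"  # placeholder repo id, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))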