ridger committed
Commit 1855d8a · verified · 1 parent: a2d3a54

Update modeling_ouro.py

Files changed (1)
  1. modeling_ouro.py +305 -56
modeling_ouro.py CHANGED
@@ -1,13 +1,17 @@
-from typing import Callable, Optional, Union
+import logging
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import nn
 
 from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.generation import GenerationMixin
 from transformers.integrations import use_kernel_forward_from_hub
-from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from transformers.masking_utils import (
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+)
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import (
     GenericForQuestionAnswering,
@@ -15,7 +19,10 @@ from transformers.modeling_layers import (
     GenericForTokenClassification,
     GradientCheckpointingLayer,
 )
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
@@ -24,6 +31,37 @@ from transformers.utils.generic import check_model_inputs
 from .configuration_ouro import OuroConfig
 
 
+logger = logging.getLogger(__name__)
+
+
+def needs_universal_cache(
+    cache: Optional[Cache], max_cache_size: Optional[int]
+) -> bool:
+    if cache is None:
+        return True
+    if isinstance(cache, UniversalTransformerCache):
+        return False
+    if not isinstance(cache, Cache):
+        return False
+    can_grow = getattr(cache, "layer_class_to_replicate", None) is not None
+    if can_grow:
+        # Dynamic caches can extend to any index, so let them be
+        return False
+    cache_layers = getattr(cache, "layers", [])
+    if max_cache_size is not None and len(cache_layers) < max_cache_size:
+        try:
+            cached_tokens = cache.get_seq_length()
+        except Exception:
+            cached_tokens = 0
+        if cached_tokens > 0:
+            raise ValueError(
+                "The provided cache cannot store all Universal Transformer iterations. Please "
+                "instantiate Ouro.modeling_ouro.UniversalTransformerCache and pass it as past_key_values."
+            )
+        return True
+    return False
+
+
 class OuroMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -82,10 +120,111 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
+class UniversalTransformerCache(Cache):
+    """Cache implementation that supports Ouro's multi-step Universal Transformer loops."""
+
+    def __init__(self, max_cache_size: Optional[int] = None):
+        # We intentionally don't call super().__init__ because the parent assumes static cache sizes.
+        self.key_cache: list[Optional[torch.Tensor]] = []
+        self.value_cache: list[Optional[torch.Tensor]] = []
+        self.layers: list[Any] = []  # attribute expected by HF Cache utilities
+        self._seen_tokens = 0
+        self.max_cache_size = max_cache_size
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if layer_idx < 0:
+            raise ValueError(f"layer_idx must be non-negative, got {layer_idx}")
+
+        if self.max_cache_size is not None and layer_idx >= self.max_cache_size:
+            raise IndexError(
+                f"Cache index {layer_idx} exceeds configured max_cache_size={self.max_cache_size}. "
+                "Check total_ut_steps and num_hidden_layers."
+            )
+
+        # Expand cache storage so the requested index is available.
+        while len(self.key_cache) <= layer_idx:
+            self.key_cache.append(None)
+            self.value_cache.append(None)
+
+        cached_key = self.key_cache[layer_idx]
+        cached_value = self.value_cache[layer_idx]
+
+        if cached_key is None:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            if (
+                key_states.shape[0] != cached_key.shape[0]
+                or key_states.shape[1] != cached_key.shape[1]
+                or key_states.shape[3] != cached_key.shape[3]
+            ):
+                raise ValueError(
+                    "Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
+                )
+            assert cached_value is not None
+            self.key_cache[layer_idx] = torch.cat([cached_key, key_states], dim=2)
+            self.value_cache[layer_idx] = torch.cat([cached_value, value_states], dim=2)
+
+        result_key = self.key_cache[layer_idx]
+        result_value = self.value_cache[layer_idx]
+        assert result_key is not None and result_value is not None
+
+        # Track sequence length using the first populated cache entry.
+        self._seen_tokens = result_key.shape[2]
+        return result_key, result_value
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        if layer_idx is None:
+            layer_idx = 0
+        if layer_idx < 0 or len(self.key_cache) <= layer_idx:
+            return 0
+        cached = self.key_cache[layer_idx]
+        if cached is None:
+            return 0
+        return cached.shape[2]
+
+    def get_max_length(self) -> Optional[int]:
+        return None
+
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
+        return self.get_seq_length(layer_idx)
+
+    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+        for idx, (key_entry, value_entry) in enumerate(
+            zip(self.key_cache, self.value_cache)
+        ):
+            if key_entry is None:
+                continue
+            assert value_entry is not None
+            device = key_entry.device
+            self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
+            self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
+
+    @property
+    def is_compileable(self) -> bool:
+        return False
+
+    def clear(self) -> None:
+        logger.debug("Clearing UniversalTransformerCache")
+        self.key_cache = []
+        self.value_cache = []
+        self._seen_tokens = 0
+
+
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
@@ -104,8 +243,12 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        query.dtype
+    )
+    attn_weights = nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
@@ -119,16 +262,32 @@ class OuroAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
+        )
+        self.sliding_window = (
+            config.sliding_window
+            if config.layer_types[layer_idx] == "sliding_attention"
+            else None
+        )
 
     def forward(
         self,
@@ -148,16 +307,25 @@ class OuroAttention(nn.Module):
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, current_ut * self.config.num_hidden_layers + self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_value.update(
+                key_states,
+                value_states,
+                current_ut * self.config.num_hidden_layers + self.layer_idx,
+                cache_kwargs,
+            )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -206,9 +374,15 @@ class OuroDecoderLayer(GradientCheckpointingLayer):
 
         self.mlp = OuroMLP(config)
         self.input_layernorm = OuroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.input_layernorm_2 = OuroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = OuroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm_2 = OuroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm_2 = OuroRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = OuroRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm_2 = OuroRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.attention_type = config.layer_types[layer_idx]
 
     def forward(
@@ -219,7 +393,9 @@ class OuroDecoderLayer(GradientCheckpointingLayer):
         past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[
+            tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
@@ -271,7 +447,9 @@ class OuroRotaryEmbedding(nn.Module):
         super().__init__()
         # BC: "rope_type" was originally "type"
         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type")
+            )
         else:
             self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
@@ -287,12 +465,23 @@ class OuroRotaryEmbedding(nn.Module):
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
         position_ids_expanded = position_ids[:, None, :].float()
 
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
             sin = emb.sin() * self.attention_scaling
@@ -307,9 +496,14 @@ class OuroModel(OuroPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
         self.layers = nn.ModuleList(
-            [OuroDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [
+                OuroDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
         )
         self.norm = OuroRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = OuroRotaryEmbedding(config=config)
@@ -334,18 +528,34 @@ class OuroModel(OuroPreTrainedModel):
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+        if use_cache is None:
+            use_cache = self.config.use_cache
+
+        max_cache_size: Optional[int] = None
+        if use_cache:
+            total_ut_steps = getattr(self.config, "total_ut_steps", 1) or 1
+            total_layers = getattr(self.config, "num_hidden_layers", None)
+            if total_layers is not None:
+                max_cache_size = total_layers * total_ut_steps
+
+        if needs_universal_cache(past_key_values, max_cache_size):
+            past_key_values = UniversalTransformerCache(max_cache_size)
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
            )
 
         if position_ids is None:
@@ -368,7 +578,9 @@
         }
         # The sliding window alternating layers are not always activated depending on the config
         if self.has_sliding_layers:
-            causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+            causal_mask_mapping["sliding_attention"] = (
+                create_sliding_window_causal_mask(**mask_kwargs)
+            )
 
         hidden_states = inputs_embeds
 
@@ -395,10 +607,14 @@
             hidden_states_list.append(hidden_states)
             gate_list.append(self.early_exit_gate(hidden_states))
 
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-        ), hidden_states_list, gate_list
+        return (
+            BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=past_key_values if use_cache else None,
+            ),
+            hidden_states_list,
+            gate_list,
+        )
 
 
 @auto_docstring
@@ -412,12 +628,11 @@ class OuroForCausalLM(OuroPreTrainedModel, GenerationMixin):
         self.model = OuroModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
+
         # Chunk size configuration
-        self.chunk_size = getattr(config, 'chunk_size', 2)  # default chunk size is 2
+        self.chunk_size = getattr(config, "chunk_size", 2)  # default chunk size is 2
         self.early_exit_step = getattr(config, "early_exit_step", None)
         self.early_exit_threshold = getattr(config, "early_exit_threshold", None)
-
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -449,13 +664,13 @@ class OuroForCausalLM(OuroPreTrainedModel, GenerationMixin):
         r"""
         Args:
             use_weighted_exit (`bool`, *optional*, defaults to `False`):
-                Whether to use weighted early exit. If `True`, the logits from all UT steps will be
+                Whether to use weighted early exit. If `True`, the logits from all UT steps will be
                 averaged according to the exit probability distribution.
             exit_at_step (`int`, *optional*):
-                Specifies which UT step to exit at. If set, the model will directly use the hidden states
+                Specifies which UT step to exit at. If set, the model will directly use the hidden states
                 from this step to generate logits, ignoring other exit strategies.
             exit_threshold (`float`, *optional*):
-                The cumulative probability threshold for early exit. When the cumulative exit probability
+                The cumulative probability threshold for early exit. When the cumulative exit probability
                 reaches this threshold, the model will exit at that step.
 
         Example:
@@ -471,8 +686,12 @@ class OuroForCausalLM(OuroPreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        exit_at_step = exit_at_step if exit_at_step is not None else self.early_exit_step
-        exit_threshold = exit_threshold if exit_threshold is not None else self.early_exit_threshold
+        exit_at_step = (
+            exit_at_step if exit_at_step is not None else self.early_exit_step
+        )
+        exit_threshold = (
+            exit_threshold if exit_threshold is not None else self.early_exit_threshold
+        )
 
         outputs, hidden_states_list, gate_list = self.model(
             input_ids=input_ids,
@@ -484,14 +703,20 @@
             cache_position=cache_position,
             **kwargs,
         )
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
 
         def _select_token_positions(tensor: torch.Tensor) -> torch.Tensor:
             if isinstance(slice_indices, slice):
                 return tensor[:, slice_indices, ...]
             if isinstance(slice_indices, torch.Tensor):
                 return tensor.index_select(1, slice_indices.to(tensor.device))
-            raise TypeError(f"Unsupported index type for logits_to_keep: {type(slice_indices)}")
+            raise TypeError(
+                f"Unsupported index type for logits_to_keep: {type(slice_indices)}"
+            )
 
         stacked_exit_pdf = None
         if gate_list:
@@ -520,8 +745,14 @@
             for step_idx, hidden in enumerate(hidden_states_list):
                 step_hidden = _select_token_positions(hidden)
                 step_logits = self.lm_head(step_hidden)
-                weight = token_exit_pdf[..., step_idx].unsqueeze(-1).to(step_logits.dtype)
-                expected_logits = step_logits * weight if expected_logits is None else expected_logits + step_logits * weight
+                weight = (
+                    token_exit_pdf[..., step_idx].unsqueeze(-1).to(step_logits.dtype)
+                )
+                expected_logits = (
+                    step_logits * weight
+                    if expected_logits is None
+                    else expected_logits + step_logits * weight
+                )
             expected_logits_cache = expected_logits
             return expected_logits_cache
 
@@ -533,10 +764,17 @@
         if logits is None:
             hidden_states = outputs.last_hidden_state
             logits = self.lm_head(_select_token_positions(hidden_states))
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
         else:
             if stacked_exit_pdf is not None and hidden_states_list:
-                if exit_at_step is not None and 0 <= exit_at_step < len(hidden_states_list):
+                if exit_at_step is not None and 0 <= exit_at_step < len(
+                    hidden_states_list
+                ):
                     selected_hidden = hidden_states_list[exit_at_step]
                     logits = self.lm_head(_select_token_positions(selected_hidden))
                 elif exit_threshold is not None:
@@ -551,8 +789,14 @@
                     never_exceeded = ~threshold_mask.any(dim=2)
                     exit_steps[never_exceeded] = last_step_idx
                     stacked_hidden = torch.stack(hidden_states_list, dim=2)
-                    gather_index = exit_steps.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, stacked_hidden.size(-1))
-                    final_hidden_states = torch.gather(stacked_hidden, 2, gather_index).squeeze(2)
+                    gather_index = (
+                        exit_steps.unsqueeze(-1)
+                        .unsqueeze(-1)
+                        .expand(-1, -1, 1, stacked_hidden.size(-1))
+                    )
+                    final_hidden_states = torch.gather(
+                        stacked_hidden, 2, gather_index
+                    ).squeeze(2)
                     logits = self.lm_head(_select_token_positions(final_hidden_states))
                 elif use_weighted_exit:
                     logits = compute_expected_logits()
@@ -572,7 +816,9 @@
         return result
 
 
-class OuroForSequenceClassification(GenericForSequenceClassification, OuroPreTrainedModel):
+class OuroForSequenceClassification(
+    GenericForSequenceClassification, OuroPreTrainedModel
+):
     pass
 
 
@@ -581,7 +827,9 @@ class OuroForTokenClassification(GenericForTokenClassification, OuroPreTrainedMo
 
 
 class OuroForQuestionAnswering(GenericForQuestionAnswering, OuroPreTrainedModel):
-    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
+    base_model_prefix = (
+        "transformer"  # For BC, where `transformer` was used instead of `model`
+    )
 
 
 __all__ = [
@@ -591,4 +839,5 @@ __all__ = [
     "OuroForSequenceClassification",
     "OuroForTokenClassification",
     "OuroForQuestionAnswering",
-]
+    "UniversalTransformerCache",
+]
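
Usage note (not part of the commit): a minimal sketch of how the pieces added here might be exercised together. The checkpoint id `your-org/ouro` is a placeholder; the early-exit keyword arguments (`exit_at_step`, `exit_threshold`, `use_weighted_exit`) come from the `OuroForCausalLM.forward` signature documented above, and `OuroModel.forward` allocates a `UniversalTransformerCache` on its own whenever `use_cache=True` and no compatible cache is supplied.

```python
# Hedged usage sketch; "your-org/ouro" is a hypothetical repo id that ships this modeling_ouro.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/ouro"  # placeholder checkpoint, substitute the real one
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")

with torch.no_grad():
    # use_cache=True lets OuroModel.forward build a UniversalTransformerCache sized for
    # num_hidden_layers * total_ut_steps entries; exit_threshold stops iterating once the
    # cumulative exit probability for a token reaches the given value.
    outputs = model(**inputs, use_cache=True, exit_threshold=0.9)

next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(next_token_id))
```

During iterative decoding the same cache object is reused, since `UniversalTransformerCache.update` concatenates new key/value states along the sequence dimension for every `(ut_step, layer)` index computed as `current_ut * num_hidden_layers + layer_idx`.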