LeroyDyer committed on
Commit 23544c8
1 Parent(s): dd9fb65

Upload 2 files

Files changed (2):
  1. configuration_mistral.py +0 -2
  2. modeling_mistral.py +2065 -14
configuration_mistral.py CHANGED
@@ -190,8 +190,6 @@ class MistralConfig(PretrainedConfig):
190
  self.use_complex_think_head = use_complex_think_head
191
  self.use_complex_talk_head = use_complex_talk_head
192
  self.use_weighted_talk_head = use_weighted_talk_head
193
-
194
-
195
 
196
  super().__init__(
197
  pad_token_id=pad_token_id,
 
190
  self.use_complex_think_head = use_complex_think_head
191
  self.use_complex_talk_head = use_complex_talk_head
192
  self.use_weighted_talk_head = use_weighted_talk_head
 
 
193
 
194
  super().__init__(
195
  pad_token_id=pad_token_id,
modeling_mistral.py CHANGED
@@ -36,25 +36,31 @@ import warnings
36
  from collections import defaultdict
37
  from typing import List, Optional, Tuple, Union
38
 
 
39
  import torch
40
  import torch.nn.functional as F
41
  import torch.utils.checkpoint
42
  from torch import nn
43
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
44
-
45
- from ...activations import ACT2FN
46
- from ...cache_utils import Cache, DynamicCache
47
- from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
48
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
49
- from ...modeling_utils import PreTrainedModel
50
- from ...utils import (
51
- add_start_docstrings,
52
- add_start_docstrings_to_model_forward,
53
- is_flash_attn_2_available,
54
- is_flash_attn_greater_or_equal_2_10,
55
- logging,
56
- replace_return_docstrings,
57
  )
 
58
  from .configuration_mistral import MistralConfig
59
 
60
 
@@ -64,6 +70,14 @@ if is_flash_attn_2_available():
64
 
65
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
66
 
67
 
68
  logger = logging.get_logger(__name__)
69
 
@@ -134,6 +148,116 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
134
  previous_text = current_text
135
  c.showPage()
136
  c.save()
137
 
138
 
139
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -148,6 +272,88 @@ def _get_unpad_data(attention_mask):
148
  max_seqlen_in_batch,
149
  )
150
 
151
 
152
  # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
153
  class MistralRMSNorm(nn.Module):
@@ -476,10 +682,11 @@ class MistralAttention(nn.Module):
476
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
477
 
478
  self._init_rope()
479
- )
480
 
481
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
482
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
483
  def _init_rope(self):
484
  if self.config.rope_scaling is None:
485
  self.rotary_emb = MistralRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)
@@ -2574,3 +2781,1847 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
2574
  hidden_states=transformer_outputs.hidden_states,
2575
  attentions=transformer_outputs.attentions,
2576
  )
36
  from collections import defaultdict
37
  from typing import List, Optional, Tuple, Union
38
 
39
+
40
  import torch
41
  import torch.nn.functional as F
42
  import torch.utils.checkpoint
43
  from torch import nn
44
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
45
+ from transformers.generation.utils import GenerationMixin
46
+ from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
47
+ from transformers import TextStreamer, AutoTokenizer
48
+ import transformers
49
+
50
+ from transformers.activations import ACT2FN
51
+ from transformers.cache_utils import Cache, DynamicCache
52
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
53
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
54
+ from transformers.modeling_utils import PreTrainedModel
55
+ from transformers.utils import (
56
+ add_start_docstrings,
57
+ add_start_docstrings_to_model_forward,
58
+ is_flash_attn_2_available,
59
+ is_flash_attn_greater_or_equal_2_10,
60
+ logging,
61
+ replace_return_docstrings,
62
  )
63
+
64
  from .configuration_mistral import MistralConfig
65
 
66
 
 
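The hunk above replaces the package-relative imports removed on the old side (`from ...activations import ACT2FN`, `from ...cache_utils import Cache, DynamicCache`, ...) with absolute `transformers.*` imports and additionally pulls in `GenerationMixin`, `AutoTokenizer`, and the stopping-criteria utilities. A minimal sketch, assuming the file is loaded from the Hub as remote code with `trust_remote_code=True`, of why the absolute form is needed (illustrative only, not part of the commit):

import torch
# A module fetched via trust_remote_code is imported outside the transformers package,
# so a package-relative import such as "from ...activations import ACT2FN" fails with an
# ImportError, while the absolute form below resolves normally.
from transformers.activations import ACT2FN

act = ACT2FN["silu"]                        # same registry lookup the MLP blocks use via config.hidden_act
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # tensor([-0.2689,  0.0000,  0.7311])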
70
 
71
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
72
 
73
+ from .configuration_quiet import QuietConfig
74
+
75
+ import time
76
+ from typing import Optional, List
77
+
78
+
79
+
80
+
81
 
82
  logger = logging.get_logger(__name__)
83
 
 
148
  previous_text = current_text
149
  c.showPage()
150
  c.save()
151
+ def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
152
+ # Compute the attention mask correctly
153
+ bsz, tgt_len = input_shape
154
+
155
+ # Create a 4D attention mask from a 2D tensor mask.
156
+ # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len)
157
+ # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
158
+ combined_attention_mask = None
159
+ if attention_mask is not None:
160
+ # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len, src_len)
161
+ # In this case, we can just use it directly.
162
+ if attention_mask.dim() == 4:
163
+ combined_attention_mask = attention_mask
164
+ # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len)
165
+ # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
166
+ elif attention_mask.dim() == 3:
167
+ expanded_attn_mask = attention_mask[:, None, :, :]
168
+ combined_attention_mask = expanded_attn_mask
169
+ # What if attention_mask is not None and has a shape of (batch_size, tgt_len)
170
+ # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
171
+ elif attention_mask.dim() == 2:
172
+ # Provided a padding mask of dimensions [batch_size, seq_length]
173
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
174
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
175
+ if past_key_values_length > 0:
176
+ attention_mask = attention_mask.to(dtype=torch.long)
177
+ attention_mask = attention_mask[:, past_key_values_length:]
178
+ expanded_attn_mask = attention_mask[:, None, None, :]
179
+ combined_attention_mask = expanded_attn_mask
180
+ else:
181
+ raise ValueError(
182
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
183
+ input_shape, attention_mask.shape
184
+ )
185
+ )
186
+
187
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
188
+ # masked positions, this operation will create a tensor which is 0.0 for
189
+ # positions we want to attend and -10000.0 for masked positions.
190
+ # Since we are adding it to the raw scores before the softmax, this is
191
+ # effectively the same as removing these entirely.
192
+ if combined_attention_mask is not None:
193
+ # Ensure the attention mask values are within a reasonable range
194
+ combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
195
+
196
+ # Convert the attention mask to bfloat16
197
+ combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
198
+
199
+ # Normalize the attention mask values to be between 0 and 1
200
+ combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
201
+ else:
202
+ combined_attention_mask = torch.zeros(
203
+ (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
204
+ )
205
+
206
+ return combined_attention_mask
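# Illustrative sketch (not part of the commit): the core of the helper above — a 2D padding
# mask becomes a (bsz, 1, tgt_len, src_len) additive mask, 0.0 where attention is allowed and
# -10000.0 where it is masked (the helper additionally clamps and casts to bfloat16).
import torch
pad_mask = torch.tensor([[1, 1, 0]])           # (bsz=1, seq_len=3); last position is padding
expanded = pad_mask[:, None, None, :].float()  # (1, 1, 1, 3)
additive = (1.0 - expanded) * -10000.0
print(additive)                                # tensor([[[[-0., -0., -10000.]]]])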
207
+
208
+
209
+
210
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
211
+ class QuietRMSNorm(nn.Module):
212
+ def __init__(self, hidden_size, eps=1e-6):
213
+ super().__init__()
214
+ self.weight = nn.Parameter(torch.ones(hidden_size))
215
+ self.variance_epsilon = eps
216
+
217
+
218
+ def forward(self, hidden_states):
219
+ input_dtype = hidden_states.dtype
220
+ hidden_states = hidden_states.to(torch.float32)
221
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
222
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
223
+ return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
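# Illustrative sketch (not part of the commit): QuietRMSNorm above scales activations by the
# reciprocal root-mean-square of the last dimension, then applies a learned per-channel weight.
import torch
x = torch.randn(2, 4, 8)
weight = torch.ones(8)
variance = x.pow(2).mean(-1, keepdim=True)
normed = x * torch.rsqrt(variance + 1e-6) * weight
print(normed.pow(2).mean(-1).sqrt())  # each entry is ~1.0 once normalised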
224
+
225
+
226
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
227
+ class QuietRotaryEmbedding(nn.Module):
228
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
229
+ super().__init__()
230
+
231
+ self.dim = dim
232
+ self.max_position_embeddings = max_position_embeddings
233
+ self.base = base
234
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
235
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
236
+
237
+ # Build here to make `torch.jit.trace` work.
238
+ self._set_cos_sin_cache(
239
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
240
+ )
241
+
242
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
243
+ self.max_seq_len_cached = seq_len
244
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
245
+
246
+ freqs = torch.outer(t, self.inv_freq)
247
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
248
+ emb = torch.cat((freqs, freqs), dim=-1)
249
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
250
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
251
+
252
+ def forward(self, x, seq_len=None):
253
+ # x: [bs, num_attention_heads, seq_len, head_size]
254
+ if seq_len > self.max_seq_len_cached:
255
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
256
+
257
+ return (
258
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
259
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
260
+ )
261
 
262
 
263
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
272
  max_seqlen_in_batch,
273
  )
274
 
275
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
276
+ def _make_causal_mask(
277
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
278
+ ):
279
+ """
280
+ Make causal mask used for bi-directional self-attention.
281
+ """
282
+ bsz, tgt_len = input_ids_shape
283
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
284
+ mask_cond = torch.arange(mask.size(-1), device=device)
285
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
286
+ mask = mask.to(dtype)
287
+
288
+ if past_key_values_length > 0:
289
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
290
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
291
+
292
+ def _make_sliding_window_causal_mask(
293
+ input_ids_shape: torch.Size,
294
+ dtype: torch.dtype,
295
+ device: torch.device,
296
+ past_key_values_length: int = 0,
297
+ sliding_window: int = 4096,
298
+ ):
299
+ """
300
+ Make causal mask used for sliding window attention
301
+ """
302
+ bsz, tgt_len = input_ids_shape
303
+
304
+ tensor = torch.full(
305
+ (tgt_len, tgt_len),
306
+ fill_value=1,
307
+ device=device,
308
+ )
309
+ mask = torch.tril(tensor, diagonal=0)
310
+ # make the mask banded to account for sliding window
311
+ mask = torch.triu(mask, diagonal=-sliding_window)
312
+ mask = torch.log(mask).to(dtype)
313
+
314
+ if past_key_values_length > 0:
315
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
316
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
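# Illustrative sketch (not part of the commit): the banded pattern built above for a tiny
# case — torch.tril gives the causal triangle, torch.triu(..., diagonal=-window) drops
# positions outside the sliding window, and torch.log maps 1 -> 0.0 and 0 -> -inf.
import torch
tgt_len, window = 4, 2
band = torch.triu(torch.tril(torch.ones(tgt_len, tgt_len)), diagonal=-window)
print(torch.log(band))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [-inf, 0., 0., 0.]])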
317
+
318
+
319
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
320
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
321
+ """
322
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
323
+ """
324
+ bsz, src_len = mask.size()
325
+ tgt_len = tgt_len if tgt_len is not None else src_len
326
+
327
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
328
+
329
+ inverted_mask = 1.0 - expanded_mask
330
+
331
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
332
+
333
+ # Inverse dim formula to find dim based on number of rotations
334
+ def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
335
+ return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
336
+
337
+ # Find dim range bounds based on rotations
338
+ def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
339
+ low = math.floor(_yarn_find_correction_dim(
340
+ low_rot, dim, base, max_position_embeddings))
341
+ high = math.ceil(_yarn_find_correction_dim(
342
+ high_rot, dim, base, max_position_embeddings))
343
+ return max(low, 0), min(high, dim-1) # Clamp values just in case
344
+
345
+ def _yarn_linear_ramp_mask(min, max, dim):
346
+ if min == max:
347
+ max += 0.001 # Prevent singularity
348
+
349
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
350
+ ramp_func = torch.clamp(linear_func, 0, 1)
351
+ return ramp_func
352
+
353
+ def _yarn_get_mscale(scale=1):
354
+ if scale <= 1:
355
+ return 1.0
356
+ return 0.07 * math.log(scale) + 1.0
357
 
358
  # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
359
  class MistralRMSNorm(nn.Module):
 
682
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
683
 
684
  self._init_rope()
685
+
686
 
687
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
688
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
689
+
690
  def _init_rope(self):
691
  if self.config.rope_scaling is None:
692
  self.rotary_emb = MistralRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)
 
2781
  hidden_states=transformer_outputs.hidden_states,
2782
  attentions=transformer_outputs.attentions,
2783
  )
2784
+
2785
+ class QuietMLP(nn.Module):
2786
+ def __init__(self, config):
2787
+ super().__init__()
2788
+ self.config = config
2789
+ self.hidden_size = config.hidden_size
2790
+ self.intermediate_size = config.intermediate_size
2791
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
2792
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
2793
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
2794
+ self.act_fn = ACT2FN[config.hidden_act]
2795
+
2796
+ def forward(self, x):
2797
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
2798
+
2799
+
2800
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
2801
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
2802
+ """
2803
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
2804
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
2805
+ """
2806
+
2807
+ # pdb.set_trace()
2808
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
2809
+ if n_rep == 1:
2810
+ return hidden_states
2811
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
2812
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
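# Illustrative sketch (not part of the commit): repeat_kv above matches
# torch.repeat_interleave along the head dimension, as its docstring states.
import torch
kv = torch.randn(2, 4, 5, 8)   # (batch, num_kv_heads, seq_len, head_dim)
n_rep = 3
expanded = kv[:, :, None, :, :].expand(2, 4, n_rep, 5, 8).reshape(2, 4 * n_rep, 5, 8)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))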
2813
+
2814
+
2815
+ class QuietAttention(nn.Module):
2816
+ """
2817
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
2818
+ and "Generating Long Sequences with Sparse Transformers".
2819
+ """
2820
+
2821
+ def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
2822
+ super().__init__()
2823
+ self.config = config
2824
+ self.layer_idx = layer_idx
2825
+ if layer_idx is None:
2826
+ logger.warning_once(
2827
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
2828
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
2829
+ "when creating this class."
2830
+ )
2831
+
2832
+ self.hidden_size = config.hidden_size
2833
+ self.num_heads = config.num_attention_heads
2834
+ self.head_dim = self.hidden_size // self.num_heads
2835
+ self.num_key_value_heads = config.num_key_value_heads
2836
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
2837
+ self.max_position_embeddings = config.max_position_embeddings
2838
+ self.rope_theta = config.rope_theta
2839
+ self.is_causal = True
2840
+ self.attention_dropout = config.attention_dropout
2841
+ self._attn_implementation = config._attn_implementation
2842
+
2843
+ if (self.head_dim * self.num_heads) != self.hidden_size:
2844
+ raise ValueError(
2845
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
2846
+ f" and `num_heads`: {self.num_heads})."
2847
+ )
2848
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
2849
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
2850
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
2851
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
2852
+
2853
+ self.rotary_emb = QuietRotaryEmbedding(
2854
+ self.head_dim,
2855
+ max_position_embeddings=self.max_position_embeddings,
2856
+ base=self.rope_theta,
2857
+ )
2858
+
2859
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
2860
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
2861
+
2862
+ def forward(
2863
+ self,
2864
+ hidden_states: torch.Tensor,
2865
+ attention_mask: Optional[torch.Tensor] = None,
2866
+ position_ids: Optional[torch.LongTensor] = None,
2867
+ past_key_value: Optional[Cache] = None,
2868
+ output_attentions: bool = False,
2869
+ use_cache: bool = False,
2870
+ **kwargs,
2871
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
2872
+ if "padding_mask" in kwargs:
2873
+ warnings.warn(
2874
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
2875
+ )
2876
+ bsz, q_len, _ = hidden_states.size()
2877
+
2878
+ query_states = self.q_proj(hidden_states)
2879
+ key_states = self.k_proj(hidden_states)
2880
+ value_states = self.v_proj(hidden_states)
2881
+
2882
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
2883
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
2884
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
2885
+
2886
+ kv_seq_len = key_states.shape[-2]
2887
+ if past_key_value is not None:
2888
+ if self.layer_idx is None:
2889
+ raise ValueError(
2890
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
2891
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
2892
+ "with a layer index."
2893
+ )
2894
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
2895
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
2896
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
2897
+
2898
+ if past_key_value is not None:
2899
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
2900
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
2901
+
2902
+ # repeat k/v heads if n_kv_heads < n_heads
2903
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
2904
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
2905
+
2906
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
2907
+
2908
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
2909
+ raise ValueError(
2910
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
2911
+ f" {attn_weights.size()}"
2912
+ )
2913
+ if self._attn_implementation == "flash_attention_2":
2914
+ # Prepare attention mask for flash-attn
2915
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
2916
+ elif self._attn_implementation == "sdpa":
2917
+ # Prepare attention mask for SDPA
2918
+ if attention_mask is None or attention_mask.dim() == 2:
2919
+ attention_mask = _prepare_4d_causal_attention_mask(
2920
+ attention_mask,
2921
+ (batch_size, seq_length),
2922
+ inputs_embeds,
2923
+ past_key_values_length,
2924
+ sliding_window=self.config.sliding_window,
2925
+ )
2926
+ else:
2927
+ # Prepare attention mask for other implementations
2928
+ if attention_mask is None or attention_mask.dim() == 2:
2929
+ attention_mask = _prepare_4d_causal_attention_mask(
2930
+ attention_mask,
2931
+ (batch_size, seq_length),
2932
+ inputs_embeds,
2933
+ past_key_values_length,
2934
+ sliding_window=self.config.sliding_window,
2935
+ )
2936
+
2937
+ if attention_mask is not None:
2938
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
2939
+ raise ValueError(
2940
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
2941
+ )
2942
+
2943
+ attn_weights = attn_weights + attention_mask
2944
+
2945
+ # upcast attention to fp32
2946
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
2947
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
2948
+ attn_output = torch.matmul(attn_weights, value_states)
2949
+
2950
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
2951
+ raise ValueError(
2952
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
2953
+ f" {attn_output.size()}"
2954
+ )
2955
+
2956
+ attn_output = attn_output.transpose(1, 2).contiguous()
2957
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
2958
+
2959
+ attn_output = self.o_proj(attn_output)
2960
+
2961
+ if not output_attentions:
2962
+ attn_weights = None
2963
+
2964
+ return attn_output, attn_weights, past_key_value
2965
+
2966
+
2967
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
2968
+ class QuietSdpaAttention(QuietAttention):
2969
+ """
2970
+ Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
2971
+ `QuietAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
2972
+ SDPA API.
2973
+ """
2974
+
2975
+ # Adapted from QuietAttention.forward
2976
+ def forward(
2977
+ self,
2978
+ hidden_states: torch.Tensor,
2979
+ attention_mask: Optional[torch.Tensor] = None,
2980
+ position_ids: Optional[torch.LongTensor] = None,
2981
+ past_key_value: Optional[Cache] = None,
2982
+ output_attentions: bool = False,
2983
+ use_cache: bool = False,
2984
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
2985
+ if output_attentions:
2986
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
2987
+ logger.warning_once(
2988
+ "QuietModel is using QuietSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
2989
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
2990
+ )
2991
+ return super().forward(
2992
+ hidden_states=hidden_states,
2993
+ attention_mask=attention_mask,
2994
+ position_ids=position_ids,
2995
+ past_key_value=past_key_value,
2996
+ output_attentions=output_attentions,
2997
+ use_cache=use_cache,
2998
+ )
2999
+ bsz, q_len, _ = hidden_states.size()
3000
+
3001
+ query_states = self.q_proj(hidden_states)
3002
+ key_states = self.k_proj(hidden_states)
3003
+ value_states = self.v_proj(hidden_states)
3004
+
3005
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
3006
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
3007
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
3008
+
3009
+ kv_seq_len = key_states.shape[-2]
3010
+ if past_key_value is not None:
3011
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
3012
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
3013
+
3014
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
3015
+
3016
+ if past_key_value is not None:
3017
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
3018
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
3019
+
3020
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
3021
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
3022
+
3023
+ if attention_mask is not None:
3024
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
3025
+ raise ValueError(
3026
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
3027
+ )
3028
+ attention_mask = attention_mask.to(query_states.dtype)
3029
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
3030
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
3031
+ if query_states.device.type == "cuda" and attention_mask is not None:
3032
+ query_states = query_states.contiguous()
3033
+ key_states = key_states.contiguous()
3034
+ value_states = value_states.contiguous()
3035
+
3036
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
3037
+ query_states,
3038
+ key_states,
3039
+ value_states,
3040
+ attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
3041
+ dropout_p=self.attention_dropout if self.training else 0.0,
3042
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
3043
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
3044
+ )
3045
+
3046
+ attn_output = attn_output.transpose(1, 2).contiguous()
3047
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
3048
+
3049
+ attn_output = self.o_proj(attn_output)
3050
+
3051
+ return attn_output, None, past_key_value
3052
+
3053
+
3054
+ QUIET_ATTENTION_CLASSES = {
3055
+ "eager": QuietAttention,
3056
+ "sdpa": QuietSdpaAttention,
3057
+ }
3058
+
3059
+
3060
+ class QuietDecoderLayer(nn.Module):
3061
+ def __init__(self, config: QuietConfig, layer_idx: int):
3062
+ super().__init__()
3063
+ self.hidden_size = config.hidden_size
3064
+
3065
+ self.self_attn = QUIET_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
3066
+
3067
+ self.mlp = QuietMLP(config)
3068
+ self.input_layernorm = QuietRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
3069
+ self.post_attention_layernorm = QuietRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
3070
+
3071
+ def forward(
3072
+ self,
3073
+ hidden_states: torch.Tensor,
3074
+ attention_mask: Optional[torch.Tensor] = None,
3075
+ position_ids: Optional[torch.LongTensor] = None,
3076
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
3077
+ output_attentions: Optional[bool] = False,
3078
+ use_cache: Optional[bool] = False,
3079
+ **kwargs,
3080
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
3081
+ if "padding_mask" in kwargs:
3082
+ warnings.warn(
3083
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
3084
+ )
3085
+ """
3086
+ Args:
3087
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
3088
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
3089
+ `(batch, sequence_length)` where padding elements are indicated by 0.
3090
+ output_attentions (`bool`, *optional*):
3091
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
3092
+ returned tensors for more detail.
3093
+ use_cache (`bool`, *optional*):
3094
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
3095
+ (see `past_key_values`).
3096
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
3097
+ """
3098
+
3099
+ residual = hidden_states
3100
+
3101
+ hidden_states = self.input_layernorm(hidden_states)
3102
+
3103
+ # Self Attention
3104
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
3105
+ hidden_states=hidden_states,
3106
+ attention_mask=attention_mask,
3107
+ position_ids=position_ids,
3108
+ past_key_value=past_key_value,
3109
+ output_attentions=output_attentions,
3110
+ use_cache=use_cache,
3111
+ )
3112
+ hidden_states = residual.to(hidden_states.device) + hidden_states
3113
+
3114
+ # Fully Connected
3115
+ residual = hidden_states
3116
+ hidden_states = self.post_attention_layernorm(hidden_states)
3117
+ hidden_states = self.mlp(hidden_states)
3118
+ hidden_states = residual + hidden_states
3119
+
3120
+ outputs = (hidden_states,)
3121
+
3122
+ if output_attentions:
3123
+ outputs += (self_attn_weights,)
3124
+
3125
+ if use_cache:
3126
+ outputs += (present_key_value,)
3127
+
3128
+ return outputs
3129
+
3130
+
3131
+ QUIET_START_DOCSTRING = r"""
3132
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
3133
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
3134
+ etc.)
3135
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
3136
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
3137
+ and behavior.
3138
+ Parameters:
3139
+ config ([`QuietConfig`]):
3140
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
3141
+ load the weights associated with the model, only the configuration. Check out the
3142
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
3143
+ """
3144
+
3145
+
3146
+ @add_start_docstrings(
3147
+ "The bare Quiet Model outputting raw hidden-states without any specific head on top.",
3148
+ QUIET_START_DOCSTRING,
3149
+ )
3150
+ class QuietPreTrainedModel(PreTrainedModel):
3151
+ config_class = QuietConfig
3152
+ base_model_prefix = "model"
3153
+ supports_gradient_checkpointing = True
3154
+ _no_split_modules = ["QuietDecoderLayer"]
3155
+ _skip_keys_device_placement = "past_key_values"
3156
+ _supports_flash_attn_2 = True
3157
+ _supports_sdpa = True
3158
+ _supports_cache_class = True
3159
+
3160
+ def _init_weights(self, module):
3161
+ std = self.config.initializer_range
3162
+ if isinstance(module, nn.Linear):
3163
+ module.weight.data.normal_(mean=0.0, std=std)
3164
+ if module.bias is not None:
3165
+ module.bias.data.zero_()
3166
+ elif isinstance(module, nn.Embedding):
3167
+ module.weight.data.normal_(mean=0.0, std=std)
3168
+ if module.padding_idx is not None:
3169
+ module.weight.data[module.padding_idx].zero_()
3170
+
3171
+
3172
+ QUIET_INPUTS_DOCSTRING = r"""
3173
+ Args:
3174
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
3175
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
3176
+ it.
3177
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
3178
+ [`PreTrainedTokenizer.__call__`] for details.
3179
+ [What are input IDs?](../glossary#input-ids)
3180
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
3181
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
3182
+ - 1 for tokens that are **not masked**,
3183
+ - 0 for tokens that are **masked**.
3184
+ [What are attention masks?](../glossary#attention-mask)
3185
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
3186
+ [`PreTrainedTokenizer.__call__`] for details.
3187
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
3188
+ `past_key_values`).
3189
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
3190
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
3191
+ information on the default strategy.
3192
+ - 1 indicates the head is **not masked**,
3193
+ - 0 indicates the head is **masked**.
3194
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
3195
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
3196
+ config.n_positions - 1]`.
3197
+ [What are position IDs?](../glossary#position-ids)
3198
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
3199
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
3200
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
3201
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
3202
+ Two formats are allowed:
3203
+ - a [`~cache_utils.Cache`] instance;
3204
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
3205
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
3206
+ cache format.
3207
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
3208
+ legacy cache format will be returned.
3209
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
3210
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
3211
+ of shape `(batch_size, sequence_length)`.
3212
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
3213
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
3214
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
3215
+ model's internal embedding lookup matrix.
3216
+ use_cache (`bool`, *optional*):
3217
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
3218
+ `past_key_values`).
3219
+ output_attentions (`bool`, *optional*):
3220
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
3221
+ tensors for more detail.
3222
+ output_hidden_states (`bool`, *optional*):
3223
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
3224
+ more detail.
3225
+ return_dict (`bool`, *optional*):
3226
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
3227
+ """
3228
+
3229
+
3230
+ @add_start_docstrings(
3231
+ "The bare Quiet Model outputting raw hidden-states without any specific head on top.",
3232
+ QUIET_START_DOCSTRING,
3233
+ )
3234
+ class QuietModel(QuietPreTrainedModel):
3235
+ """
3236
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QuietDecoderLayer`]
3237
+ Args:
3238
+ config: QuietConfig
3239
+ """
3240
+
3241
+ def __init__(self, config: QuietConfig):
3242
+ super().__init__(config)
3243
+ self.padding_idx = config.pad_token_id
3244
+ self.vocab_size = config.vocab_size
3245
+
3246
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
3247
+ self.layers = nn.ModuleList(
3248
+ [QuietDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
3249
+ )
3250
+ self._attn_implementation = config._attn_implementation
3251
+ self.norm = QuietRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
3252
+
3253
+ self.gradient_checkpointing = False
3254
+ # Initialize weights and apply final processing
3255
+ self.post_init()
3256
+
3257
+ def get_input_embeddings(self):
3258
+ return self.embed_tokens
3259
+
3260
+ def set_input_embeddings(self, value):
3261
+ self.embed_tokens = value
3262
+
3263
+ @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
3264
+ def forward(
3265
+ self,
3266
+ input_ids: torch.LongTensor = None,
3267
+ attention_mask: Optional[torch.Tensor] = None,
3268
+ position_ids: Optional[torch.LongTensor] = None,
3269
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
3270
+ inputs_embeds: Optional[torch.FloatTensor] = None,
3271
+ use_cache: Optional[bool] = None,
3272
+ output_attentions: Optional[bool] = None,
3273
+ output_hidden_states: Optional[bool] = None,
3274
+ return_dict: Optional[bool] = None,
3275
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
3276
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
3277
+ output_hidden_states = (
3278
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
3279
+ )
3280
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
3281
+
3282
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
3283
+
3284
+ # retrieve input_ids and inputs_embeds
3285
+ if input_ids is not None and inputs_embeds is not None:
3286
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
3287
+ elif input_ids is not None:
3288
+ batch_size, seq_length = input_ids.shape
3289
+ elif inputs_embeds is not None:
3290
+ batch_size, seq_length, _ = inputs_embeds.shape
3291
+ else:
3292
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
3293
+
3294
+ if self.gradient_checkpointing and self.training:
3295
+ if use_cache:
3296
+ logger.warning_once(
3297
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
3298
+ )
3299
+ use_cache = False
3300
+
3301
+ past_key_values_length = 0
3302
+
3303
+ if use_cache:
3304
+ use_legacy_cache = not isinstance(past_key_values, Cache)
3305
+ if use_legacy_cache:
3306
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
3307
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
3308
+
3309
+ if position_ids is None:
3310
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
3311
+ position_ids = torch.arange(
3312
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
3313
+ )
3314
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
3315
+ else:
3316
+ position_ids = position_ids.view(-1, seq_length).long()
3317
+
3318
+ if inputs_embeds is None:
3319
+ inputs_embeds = self.embed_tokens(input_ids)
3320
+
3321
+ if self._attn_implementation == "flash_attention_2":
3322
+ # 2d mask is passed through the layers
3323
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
3324
+ elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask is not None and attention_mask.dim() == 2:
3325
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
3326
+ # the manual implementation that requires a 4D causal mask in all cases.
3327
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
3328
+ attention_mask,
3329
+ (batch_size, seq_length),
3330
+ inputs_embeds,
3331
+ past_key_values_length,
3332
+ )
3333
+ elif attention_mask is None or (attention_mask is not None and attention_mask.dim() == 2):
3334
+ # 4d mask is passed through the layers
3335
+ attention_mask = _prepare_4d_causal_attention_mask(
3336
+ attention_mask,
3337
+ (batch_size, seq_length),
3338
+ inputs_embeds,
3339
+ past_key_values_length,
3340
+ sliding_window=self.config.sliding_window,
3341
+ )
3342
+
3343
+
3344
+ hidden_states = inputs_embeds
3345
+
3346
+ # decoder layers
3347
+ all_hidden_states = () if output_hidden_states else None
3348
+ all_self_attns = () if output_attentions else None
3349
+ next_decoder_cache = None
3350
+
3351
+ for decoder_layer in self.layers:
3352
+ if output_hidden_states:
3353
+ all_hidden_states += (hidden_states,)
3354
+
3355
+ if self.gradient_checkpointing and self.training:
3356
+ layer_outputs = self._gradient_checkpointing_func(
3357
+ decoder_layer.__call__,
3358
+ hidden_states,
3359
+ attention_mask,
3360
+ position_ids,
3361
+ past_key_values,
3362
+ output_attentions,
3363
+ use_cache,
3364
+ )
3365
+ else:
3366
+ layer_outputs = decoder_layer(
3367
+ hidden_states,
3368
+ attention_mask=attention_mask,
3369
+ position_ids=position_ids,
3370
+ past_key_value=past_key_values,
3371
+ output_attentions=output_attentions,
3372
+ use_cache=use_cache,
3373
+ )
3374
+
3375
+ hidden_states = layer_outputs[0]
3376
+
3377
+ if use_cache:
3378
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
3379
+
3380
+ if output_attentions:
3381
+ all_self_attns += (layer_outputs[1],)
3382
+
3383
+ hidden_states = self.norm(hidden_states)
3384
+
3385
+ # add hidden states from the last decoder layer
3386
+ if output_hidden_states:
3387
+ all_hidden_states += (hidden_states,)
3388
+
3389
+ next_cache = None
3390
+ if use_cache:
3391
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
3392
+
3393
+ if not return_dict:
3394
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
3395
+ return BaseModelOutputWithPast(
3396
+ last_hidden_state=hidden_states,
3397
+ past_key_values=next_cache,
3398
+ hidden_states=all_hidden_states,
3399
+ attentions=all_self_attns,
3400
+ )
3401
+
3402
+ def nonzero_mean(x, axis=None):
3403
+ if axis is not None:
3404
+ return x.sum(axis) / (x != 0).sum(axis)
3405
+ return x.sum() / (x != 0).sum()
3406
+
3407
+ def loss_mean(x):
3408
+ return x.sum() / (x != 0).sum()
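# Illustrative sketch (not part of the commit): nonzero_mean / loss_mean above average only
# over non-zero entries, so padded or unrewarded positions do not dilute the mean.
import torch
x = torch.tensor([2.0, 0.0, 4.0, 0.0])
print(x.mean())                  # tensor(1.5000) — a plain mean counts the zeros
print(x.sum() / (x != 0).sum())  # tensor(3.) — mean over the two non-zero values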
3409
+
3410
+ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
3411
+ _tied_weights_keys = ["lm_head.weight"]
3412
+
3413
+ def __init__(self, config):
3414
+ super().__init__(config)
3415
+ self.model = QuietModel(config)
3416
+ self.vocab_size = config.vocab_size
3417
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
3418
+ # self.router_aux_loss_coef = config.router_aux_loss_coef
3419
+ # self.num_experts = config.num_experts
3420
+ # self.num_experts_per_tok = config.num_experts_per_tok
3421
+ self.max_thoughts = config.max_thoughts
3422
+ self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
3423
+ self.use_concat_talk_head = config.use_concat_talk_head
3424
+ self.use_shallow_talk = config.use_shallow_talk
3425
+ self.use_complex_talk_head = config.use_complex_talk_head
3426
+ self.use_weighted_talk_head = config.use_weighted_talk_head
3427
+ # the weighted head will output a single value, so it can't be passed to the lm head
3428
+ assert not (self.use_weighted_talk_head and self.use_shallow_talk)
3429
+
3430
+ self.n_ahead = 1
3431
+ self.n_ahead_talk = 1
3432
+ self.n_passes = 1
3433
+ self.n_tokens_print = 1
3434
+ self.gradient_accumulation_steps = 1
3435
+ self.training_steps = 0
3436
+ self.tokenizer = AutoTokenizer.from_pretrained("LeroyDyer/Mixtral_AI_Cyber_Q")
3437
+ self.start_token_id = None
3438
+ self.end_token_id = None
3439
+ self.rm_initialized = False
3440
+ self.residual_talk_head = True
3441
+ self.thought_init_std_scale = 1e-2
3442
+
3443
+ self.final_only_mode = False
3444
+ self.first_and_last_mode = True
3445
+ self.first_only = False
3446
+ self.original_loss_weight = 0.5
3447
+
3448
+ self.cumulative_residual = False
3449
+ self.clever_residual = False
3450
+ self.skip_residual = False
3451
+ self.no_residual = True
3452
+
3453
+ self.optimize_lm_head_only_at_start = False
3454
+ self.optimize_model_only_at_start = False
3455
+
3456
+ if self.optimize_model_only_at_start:
3457
+ raise NotImplementedError
3458
+ self.train_only_thinking_embedding = False
3459
+ self.weighted_embeddings = False
3460
+ self.use_start_thought_token = True
3461
+ self.use_end_thought_token = True
3462
+ self.initialize_thought_embedding_to_normal = False
3463
+ self.initial_start_token = "---"
3464
+ self.initial_end_token = "---"
3465
+ self.output_logits_at_the_end = True
3466
+
3467
+ self.wandb_enabled = False
3468
+ self.gumbel_temperature = 0.001
3469
+
3470
+ self.use_policy_loss = True
3471
+ self.include_policy_loss = True
3472
+ self.trice_mode = True
3473
+ self.remove_negative_rewards = True
3474
+ self.use_policy_loss_for_end_thought = True
3475
+
3476
+ self.base_original_mode = False
3477
+ self.original_mode = False
3478
+
3479
+ self.thought_prefix = "(Let's think step by step"
3480
+ self.tokenized_thought_prefix = None
3481
+ self.log_dict = defaultdict(int)
3482
+ self.eval_log_dict = defaultdict(int)
3483
+ self.loss_mean = loss_mean
3484
+
3485
+ self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
3486
+ self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
3487
+
3488
+ self.policy_loss_beta = 1e6
3489
+ self.embedding_scale = 1e2
3490
+ self.temperature = nn.Parameter(torch.ones(1))
3491
+ self.max_temperature = config.max_temperature
3492
+ self.reinforce_temperature = 3
3493
+ self.base_loss_beta = 1
3494
+ self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
3495
+ self.thinking_threshold = 0.5
3496
+ self.thinking_usefulness_loss_weight = 1e-2
3497
+
3498
+ # Not used in the paper:
3499
+ self.use_thought_prefix = False
3500
+ self.use_reparam_for_thought_embeddings = False
3501
+ self.use_upper_triangular = False
3502
+ self.subtract_mean_reward = False
3503
+ self.comparison_mode = False
3504
+ self.gumbel_detach = False
3505
+
3506
+ # For visualization
3507
+ self.eval_mode = False
3508
+
3509
+ num_talk = 1
3510
+ talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
3511
+ if self.use_weighted_talk_head:
3512
+ talk_output_dim = 1
3513
+ else:
3514
+ talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size
3515
+
3516
+ if not self.merged_lm_and_talk_heads:
3517
+ if self.use_complex_talk_head:
3518
+ self.talk_head = nn.ModuleList([nn.Sequential(
3519
+ nn.Linear(talk_input_dim, config.hidden_size),
3520
+ nn.ReLU(),
3521
+ nn.Linear(config.hidden_size, config.hidden_size),
3522
+ nn.ReLU(),
3523
+ nn.Linear(config.hidden_size, talk_output_dim, bias=False)
3524
+ )])
3525
+ else:
3526
+ self.talk_head = nn.ModuleList([nn.Sequential(
3527
+ nn.Linear(talk_input_dim, talk_output_dim, bias=False)
3528
+ )])
3529
+
3530
+ self.apply(self._init_weights)
3531
+
3532
+ # Add dropout regularization
3533
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
3534
+
3535
+ # Initialize weights and apply final processing
3536
+ self.post_init()
3537
+
3538
+ def get_input_embeddings(self):
3539
+ return self.model.embed_tokens
3540
+
3541
+ def set_input_embeddings(self, value):
3542
+ self.model.embed_tokens = value
3543
+
3544
+ def get_output_embeddings(self):
3545
+ return self.lm_head
3546
+
3547
+ def set_output_embeddings(self, new_embeddings):
3548
+ self.lm_head = new_embeddings
3549
+
3550
+ def set_decoder(self, decoder):
3551
+ self.model = decoder
3552
+
3553
+ def get_decoder(self):
3554
+ return self.model
3555
+
3556
+ def _init_weights(self, module):
3557
+ if isinstance(module, nn.Linear):
3558
+ nn.init.xavier_uniform_(module.weight)
3559
+ if module.bias is not None:
3560
+ nn.init.constant_(module.bias, 0)
3561
+ elif isinstance(module, nn.Embedding):
3562
+ nn.init.xavier_uniform_(module.weight)
3563
+
3564
+ @torch.no_grad()
3565
+ def infer(
3566
+ self,
3567
+ input_ids: torch.LongTensor,
3568
+ attention_mask: Optional[torch.Tensor] = None,
3569
+ position_ids: Optional[torch.LongTensor] = None,
3570
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
3571
+ inputs_embeds: Optional[torch.FloatTensor] = None,
3572
+ use_cache: Optional[bool] = None,
3573
+ output_attentions: Optional[bool] = None,
3574
+ output_hidden_states: Optional[bool] = None,
3575
+ return_dict: Optional[bool] = None,
3576
+ ):
3577
+ batch_size, seq_len = input_ids.shape
3578
+
3579
+ # Save the original input_ids and attention_mask for later use
3580
+ original_input_ids = input_ids.clone()
3581
+ original_attention_mask = attention_mask.clone() if attention_mask is not None else None
3582
+
3583
+ # Append the start thought token to the input sequence
3584
+ start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
3585
+ input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
3586
+ seq_len += 1
3587
+
3588
+ # Update the attention mask
3589
+ if attention_mask is not None:
3590
+ attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
3591
+
3592
+ # Generate the continuation
3593
+ continuation_length = self.n_ahead - 2
3594
+ new_key_values = past_key_values
3595
+
3596
+ # Initialize next_token_id with a default value
3597
+ next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
3598
+
3599
+ start_time = time.time()
3600
+ for continuation_idx in range(continuation_length):
3601
+ outputs = self.model(
3602
+ input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
3603
+ attention_mask=attention_mask,
3604
+ position_ids=position_ids,
3605
+ past_key_values=new_key_values,
3606
+ inputs_embeds=inputs_embeds,
3607
+ use_cache=True,
3608
+ output_attentions=output_attentions,
3609
+ output_hidden_states=output_hidden_states,
3610
+ return_dict=return_dict,
3611
+ )
3612
+ new_key_values = outputs.past_key_values
3613
+
3614
+ hidden_states = outputs[0]
3615
+
3616
+ logits = self.lm_head(hidden_states)
3617
+ logits = logits[:, -1, :] # Only consider the last token
3618
+
3619
+ # Apply Gumbel-Softmax to the logits
3620
+ next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
3621
+ next_token_id = torch.argmax(next_token_logits, dim=-1)
3622
+
3623
+ # Append the generated token to the input sequence
3624
+ # input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
3625
+ seq_len += 1
3626
+
3627
+ # Update the attention mask
3628
+ if attention_mask is not None:
3629
+ attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
3630
+
3631
+ # Append the end thought token to the input sequence
3632
+ end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
3633
+ input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
3634
+ seq_len += 1
3635
+
3636
+ # Update the attention mask
3637
+ if attention_mask is not None:
3638
+ attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
3639
+
3640
+ # Get the hidden states before and after the thought
3641
+ outputs_before = self.model(
3642
+ input_ids=original_input_ids,
3643
+ attention_mask=original_attention_mask,
3644
+ position_ids=position_ids,
3645
+ past_key_values=past_key_values,
3646
+ inputs_embeds=inputs_embeds,
3647
+ use_cache=use_cache,
3648
+ output_attentions=output_attentions,
3649
+ output_hidden_states=output_hidden_states,
3650
+ return_dict=return_dict,
3651
+ )
3652
+ hidden_states_before = outputs_before[0][:, -1:, :]
3653
+
3654
+ # two new tokens: last continuation token and end thought token
3655
+ outputs_after = self.model(
3656
+ input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
3657
+ attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
3658
+ position_ids=position_ids,
3659
+ past_key_values=new_key_values,
3660
+ inputs_embeds=inputs_embeds,
3661
+ use_cache=use_cache,
3662
+ output_attentions=output_attentions,
3663
+ output_hidden_states=output_hidden_states,
3664
+ return_dict=return_dict,
3665
+ )
3666
+ hidden_states_after = outputs_after[0][:, -1:, :]
3667
+
3668
+ # Apply the talk head to get the mixing weight
3669
+ mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
3670
+
3671
+ # Apply the mixing weight to the hidden states
3672
+ mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
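+         # The talk head emits a per-position mixing weight: 0 keeps the pre-thought hidden
+         # state, 1 fully adopts the post-thought hidden state, and intermediate values
+         # interpolate between the two before the language-model head is applied.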
+
3674
+ # Apply the language model head to get the final logits
3675
+ logits = self.lm_head(mixed_hidden_states)
3676
+ return logits
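+
+     # Minimal usage sketch (illustrative only; it assumes `model.tokenizer` has been attached
+     # and already contains the <|startthought|> / <|endthought|> special tokens):
+     #
+     #     inputs = model.tokenizer("2 + 2 =", return_tensors="pt")
+     #     logits = model.infer(inputs.input_ids, attention_mask=inputs.attention_mask)
+     #     next_token = logits[:, -1, :].argmax(dim=-1)
+     #
+     # `infer` performs a single hidden-thought rollout and returns logits for the next
+     # position only; use `generate` below for multi-token decoding.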
+
+ @torch.no_grad()
3679
+ def generate(
3680
+ self,
3681
+ input_ids: torch.LongTensor = torch.LongTensor(),
3682
+ attention_mask: Optional[torch.Tensor] = None,
3683
+ max_new_tokens: Optional[int] = None,
3684
+ temperature: float = 1.1,
3685
+ **kwargs,
3686
+ ):
3687
+ if isinstance(input_ids, str):
3688
+ input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
3689
+
3690
+ if attention_mask is None:
3691
+ # Create a default attention mask if not provided
3692
+ attention_mask = torch.ones_like(input_ids)
3693
+
3694
+ from .generate import generate
3695
+ return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
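+
+     # Usage sketch (illustrative; sampling arguments are passed through to the bundled
+     # `.generate` helper, whose return format is defined in that module):
+     #
+     #     prompt_ids = model.tokenizer("Hello, how are you?", return_tensors="pt").input_ids
+     #     output = model.generate(prompt_ids, max_new_tokens=32, temperature=1.1)
+     #
+     # A plain string is also accepted; it is tokenized internally before generation.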
+
+ @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
3698
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
3699
+ def forward(
3700
+ self,
3701
+ input_ids: torch.LongTensor = None,
3702
+ attention_mask: Optional[torch.Tensor] = None,
3703
+ position_ids: Optional[torch.LongTensor] = None,
3704
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
3705
+ inputs_embeds: Optional[torch.FloatTensor] = None,
3706
+ labels: Optional[torch.LongTensor] = None,
3707
+ use_cache: Optional[bool] = None,
3708
+ output_attentions: Optional[bool] = None,
3709
+ output_hidden_states: Optional[bool] = None,
3710
+ return_dict: Optional[bool] = None,
3711
+ max_new_tokens: Optional[int] = None,
3712
+ temperature: Optional[float] = None,
3713
+ temperature_last: Optional[float] = None,
3714
+ dynamic_temperature: Optional[float] = None,
3715
+ dynatemp_low: Optional[float] = None,
3716
+ dynatemp_high: Optional[float] = None,
3717
+ dynatemp_exponent: Optional[float] = None,
3718
+ smoothing_factor: Optional[float] = None,
3719
+ smoothing_curve: Optional[str] = None,
3720
+ top_p: Optional[float] = None,
3721
+ min_p: Optional[float] = None,
3722
+ top_k: Optional[int] = None,
3723
+ repetition_penalty: Optional[float] = None,
3724
+ presence_penalty: Optional[float] = None,
3725
+ frequency_penalty: Optional[float] = None,
3726
+ repetition_penalty_range: Optional[int] = None,
3727
+ typical_p: Optional[float] = None,
3728
+ tfs: Optional[float] = None,
3729
+ top_a: Optional[float] = None,
3730
+ guidance_scale: Optional[float] = None,
3731
+ penalty_alpha: Optional[float] = None,
3732
+ mirostat_mode: Optional[int] = None,
3733
+ mirostat_tau: Optional[float] = None,
3734
+ mirostat_eta: Optional[float] = None,
3735
+ do_sample: Optional[bool] = None,
3736
+ encoder_repetition_penalty: Optional[float] = None,
3737
+ no_repeat_ngram_size: Optional[int] = None,
3738
+ sampler_priority: Optional[List[str]] = None,
3739
+ negative_prompt_ids: Optional[List[int]] = None,
3740
+ prompt_lookup_num_tokens: Optional[int] = None,
3741
+ epsilon_cutoff: Optional[float] = None,
3742
+ eta_cutoff: Optional[float] = None,
3743
+ max_length: Optional[int] = None,
3744
+ suppress_tokens: Optional[List[int]] = None,
3745
+ synced_gpus: Optional[bool] = None,
3746
+ eos_token_id: Optional[List[int]] = None,
3747
+ stopping_criteria: Optional[transformers.StoppingCriteriaList] = None,
3748
+ logits_processor: Optional[transformers.LogitsProcessorList] = None,
3749
+ inputs: Optional[torch.LongTensor] = None,
3750
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
3751
+ r"""
3752
+ Args:
3753
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
3754
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
3755
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
3756
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
3757
+ Returns:
3758
+ Example:
3759
+ ```python
3760
+ >>> from transformers import AutoTokenizer, QuietForCausalLM
3761
+ >>> model = QuietForCausalLM.from_pretrained("quietai/Quiet-7B-v0.1")
3762
+ >>> tokenizer = AutoTokenizer.from_pretrained("quietai/Quiet-7B-v0.1")
3763
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
3764
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
3765
+ >>> # Generate
3766
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
3767
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
3768
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
3769
+ ```"""
3770
+
3771
+ if not self.training:
3772
+ n_ahead_talk_to_restore = self.n_ahead_talk
3773
+ n_passes_to_restore = self.n_passes
3774
+ self.n_ahead_talk = 1
3775
+ self.n_passes = 1
3776
+
3777
+ # aux_loss = None
3778
+ # output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits
3779
+ # if output_router_logits:
3780
+ # router_logits = outputs.router_logits if return_dict else outputs[-1]
3781
+ # if router_logits is not None:
3782
+ # aux_loss = load_balancing_loss_func(
3783
+ # router_logits,
3784
+ # self.num_experts,
3785
+ # self.num_experts_per_tok,
3786
+ # attention_mask,
3787
+ # )
3788
+ # if labels is not None:
3789
+ # loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
3790
+
3791
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
3792
+ output_hidden_states = (
3793
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
3794
+ )
3795
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
3796
+
3797
+ assert self.cumulative_residual or self.clever_residual or self.skip_residual or self.no_residual
3798
+ assert not (self.skip_residual and self.use_policy_loss)
3799
+
3800
+ if self.tokenized_thought_prefix is None and self.use_thought_prefix:
3801
+ self.tokenized_thought_prefix = self.tokenizer(self.thought_prefix, return_tensors="pt", add_special_tokens=False)["input_ids"]
3802
+
3803
+ def apply_head(head, states, detach=False):
3804
+ if detach:
3805
+ head_weight = head.weight.detach()
3806
+ else:
3807
+ head_weight = head.weight
3808
+ head_weight = head_weight.to(states.device)
3809
+ return (head_weight @ states.transpose(-1, -2)).transpose(-1, -2).contiguous()
3810
+
3811
+ def idx_if_sequential(head, idx=0):
3812
+ if isinstance(head, nn.Sequential) or isinstance(head, nn.ModuleList):
3813
+ return idx_if_sequential(head[idx], idx=idx)
3814
+ return head
3815
+
3816
+ def none_repeat_interleave(x, n):
3817
+ if x is None:
3818
+ return x
3819
+ return x.repeat_interleave(n, dim=0)
3820
+
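+         # n_passes > 1 repeats every example so that several independent thought rollouts are
+         # drawn per input; trice_mode later subtracts the per-input mean reward across these
+         # passes as a baseline for the policy gradient.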
+ if self.n_passes > 1:
3822
+ input_ids = none_repeat_interleave(input_ids, self.n_passes)
3823
+ attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
3824
+ position_ids = none_repeat_interleave(position_ids, self.n_passes)
3825
+ inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
3826
+ labels = none_repeat_interleave(labels, self.n_passes)
3827
+ if past_key_values is not None:
3828
+ past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
3829
+ cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
3830
+
3831
+ self.tokenizer_has_start_thought_token = True
3832
+ self.tokenizer_has_end_thought_token = True
3833
+ if self.start_token_id is None:
3834
+ self.start_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
3835
+ if self.start_token_id == 0:
3836
+ self.start_token_id = self.tokenizer.bos_token_id
3837
+ self.tokenizer_has_start_thought_token = False
3838
+ elif self.use_start_thought_token:
3839
+ # base_start_id = self.tokenizer.convert_tokens_to_ids(self.initial_start_token)
3840
+ base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
3841
+ if self.initialize_thought_embedding_to_normal:
3842
+ self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
3843
+ else:
3844
+ self.start_embedding.data[0] = self.model.embed_tokens.weight.data[base_start_id].clone().detach() / self.embedding_scale
3845
+ self.start_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale)
3846
+ if self.end_token_id is None:
3847
+ self.end_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
3848
+ if self.end_token_id == 0:
3849
+ self.end_token_id = self.tokenizer.eos_token_id
3850
+ self.tokenizer_has_end_thought_token = False
3851
+ elif self.use_end_thought_token:
3852
+ # base_end_id = self.tokenizer.convert_tokens_to_ids(self.initial_end_token)
3853
+ base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
3854
+ if self.initialize_thought_embedding_to_normal:
3855
+ self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
3856
+ else:
3857
+ self.end_embedding.data[0] = self.model.embed_tokens.weight.data[base_end_id].clone().detach() / self.embedding_scale
3858
+ self.end_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale)
3859
+
3860
+ if not self.rm_initialized and (self.n_ahead > 1 or not self.base_original_mode):
3861
+ self.rm_initialized = True
3862
+ if not self.use_shallow_talk:
3863
+ head = self.talk_head[0]
3864
+ cur_head = head[-1] if isinstance(head, nn.Sequential) else head
3865
+ talk_input_dim = cur_head.weight.data.shape[1]
3866
+ talk_output_dim = 1 if self.use_weighted_talk_head else self.lm_head.weight.data.shape[0]
3867
+ cur_head.weight.data = torch.zeros(talk_output_dim, talk_input_dim, device=cur_head.weight.device, dtype=cur_head.weight.dtype)
3868
+ else:
3869
+ # convert to identity transform
3870
+ def lambda_transform(cur_head):
3871
+ # pdb.set_trace()
3872
+ if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
3873
+ return torch.cat([
3874
+ torch.eye(
3875
+ cur_head.weight.data.shape[0],
3876
+ device=cur_head.weight.device,
3877
+ dtype=cur_head.weight.dtype
3878
+ ),
3879
+ torch.zeros(
3880
+ cur_head.weight.data.shape[0],
3881
+ cur_head.weight.data.shape[1] - cur_head.weight.data.shape[0],
3882
+ device=cur_head.weight.device,
3883
+ dtype=cur_head.weight.dtype
3884
+ )], dim=1)
3885
+ return torch.eye(
3886
+ cur_head.weight.data.shape[0],
3887
+ device=cur_head.weight.device,
3888
+ dtype=cur_head.weight.dtype
3889
+ )
3890
+ if isinstance(self.talk_head[0], nn.Sequential):
3891
+ for cur_head in self.talk_head[0]:
3892
+ # if it has weights
3893
+ if hasattr(cur_head, "weight"):
3894
+ cur_head.weight.data = lambda_transform(cur_head)
3895
+ else:
3896
+ self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
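+             # Zero-initializing the deep talk head makes its contribution start at zero, while
+             # the identity initialization above lets the shallow talk head begin as a
+             # pass-through, so talk outputs initially match the base language-model head.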
+
3898
+ loss = None
3899
+ prev_rm_tokens = None
3900
+ cur_rm_tokens = None
3901
+ prev_rm_logits = None
3902
+ prev_sample_probs = None
3903
+ did_skip_sampling = None
3904
+ skip_sampling = None
3905
+ sample_probs = None
3906
+ hidden_states = None
3907
+ logits = None
3908
+ talk_kl_penalty = None
3909
+ rm_logits = None
3910
+ residual_logits = None
3911
+ probabilities_2d = None
3912
+ prev_probabilities_2d = None
3913
+ policy_reward = None
3914
+ logits_to_output = None
3915
+ batch_size, seq_len = input_ids.shape
3916
+ base_input_ids = input_ids.clone()
3917
+ loss_list = []
3918
+ dqn_loss_list = []
3919
+ sampled_token_history = []
3920
+ sample_probs_history = []
3921
+ action_loglikelihoods_list = []
3922
+
3923
+ temperature = self.temperature
3924
+
3925
+ if self.use_end_thought_token or self.use_start_thought_token:
3926
+ if not self.use_reparam_for_thought_embeddings:
3927
+ start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
3928
+ end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
3929
+ else:
3930
+ start_embedding = self.start_embedding * self.embedding_scale * temperature
3931
+ end_embedding = self.end_embedding * self.embedding_scale * temperature
3932
+ base_embeddings = self.model.embed_tokens.weight
3933
+ if self.train_only_thinking_embedding:
3934
+ base_embeddings = base_embeddings.detach()
3935
+
3936
+ # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
3937
+ fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
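+         # One outer iteration per "ahead" position: the n_ahead thought-rollout steps and the
+         # n_ahead_talk talk-ahead steps overlap by a single position, giving
+         # n_ahead + n_ahead_talk - 1 forward passes (or just one in original_mode).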
+ for ahead_idx in range(fwd_iters):
3939
+ past_key_values_length = 0
3940
+ if past_key_values is not None:
3941
+ use_legacy_cache = not isinstance(past_key_values, Cache)
3942
+ if use_legacy_cache:
3943
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
3944
+ past_key_values_length = past_key_values.get_usable_length(seq_len)
3945
+
3946
+ if position_ids is None:
3947
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
3948
+ position_ids = torch.arange(
3949
+ past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
3950
+ )
3951
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
3952
+ else:
3953
+ position_ids = position_ids.view(-1, seq_len).long()
3954
+
3955
+ if inputs_embeds is None:
3956
+ contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
3957
+ contains_end = self.use_end_thought_token and (input_ids == self.end_token_id).any()
3958
+ contains_thought = contains_start or contains_end
3959
+ if contains_thought:
3960
+ thought_id = self.start_token_id if contains_start else self.end_token_id
3961
+ cur_thought_embedding = start_embedding if contains_start else end_embedding
3962
+ if self.use_reparam_for_thought_embeddings:
3963
+ inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
3964
+ inputs_embeds = inputs_embeds.detach() * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
3965
+ if contains_start:
3966
+ sampled_start = inputs_embeds.clone().detach()
3967
+ if contains_end:
3968
+ sampled_end = inputs_embeds.clone().detach()
3969
+ else:
3970
+ inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
3971
+ else:
3972
+ with torch.set_grad_enabled(not self.train_only_thinking_embedding):
3973
+ inputs_embeds = self.model.embed_tokens(input_ids)
3974
+
3975
+ if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
3976
+ if attention_mask is None:
3977
+ base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
3978
+ base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
3979
+ base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
3980
+ attention_mask = base_attention_mask
3981
+ # breakpoint()
3982
+ elif attention_mask.dim() == 2:
3983
+ if seq_len + past_key_values_length != attention_mask.shape[-1]:
3984
+ # breakpoint()
3985
+ attention_mask = torch.cat(
3986
+ [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
3987
+ dim=-1
3988
+ )
3989
+ # # if the attention mask
3990
+ attention_mask = _prepare_4d_causal_attention_mask(
3991
+ attention_mask,
3992
+ (batch_size, seq_len),
3993
+ inputs_embeds,
3994
+ past_key_values_length,
3995
+ sliding_window=self.config.sliding_window,
3996
+ )
3997
+
3998
+ outputs = self.model(
3999
+ # input_ids=input_ids,
4000
+ attention_mask=attention_mask,
4001
+ position_ids=position_ids,
4002
+ past_key_values=past_key_values,
4003
+ inputs_embeds=inputs_embeds,
4004
+ use_cache=use_cache,
4005
+ output_attentions=output_attentions,
4006
+ output_hidden_states=output_hidden_states,
4007
+ # output_router_logits=output_router_logits,
4008
+ return_dict=return_dict,
4009
+ )
4010
+
4011
+ prev_hidden_states = hidden_states
4012
+ hidden_states = outputs[0]
4013
+ prev_rm_logits = rm_logits # for policy gradient
4014
+ prev_rm_tokens = cur_rm_tokens # for policy gradient
4015
+
4016
+ if ahead_idx == 0:
4017
+ hidden_states_lm = hidden_states
4018
+ logits = self.lm_head(hidden_states_lm)
4019
+ base_hidden_states = hidden_states.clone()
4020
+ initial_loss_logits = logits.clone()
4021
+ if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
4022
+ logits = logits.detach()
4023
+ base_hidden_states = base_hidden_states.detach()
4024
+ if self.optimize_model_only_at_start:
4025
+ hidden_states = hidden_states.detach()
4026
+ base_logits = logits.clone()
4027
+ else:
4028
+ talk_hidden_states = hidden_states
4029
+ if self.merged_lm_and_talk_heads:
4030
+ assert self.no_residual
4031
+ residual_logits = self.lm_head(hidden_states)
4032
+ talk_hidden_states = hidden_states
4033
+ else:
4034
+ if ahead_idx > self.n_ahead - 1:
4035
+ cur_base_hidden = torch.cat([
4036
+ base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
4037
+ base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
4038
+ ], dim=-2)
4039
+ else:
4040
+ cur_base_hidden = base_hidden_states
4041
+
4042
+ if self.use_concat_talk_head:
4043
+ # concatenate the hidden states with the original hidden states
4044
+ head_input_hidden_states = torch.cat([cur_base_hidden, talk_hidden_states], dim=-1)
4045
+ else:
4046
+ head_input_hidden_states = talk_hidden_states
4047
+
4048
+ residual_logits = self.talk_head[0](head_input_hidden_states)
4049
+ if self.use_shallow_talk:
4050
+ residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
4051
+ residual_logits = residual_logits.to(logits.device)
4052
+ if self.use_weighted_talk_head:
4053
+ # combine the cur_base_hidden with the talk_hidden_states according to the weighted head
4054
+ residual_logits = cur_base_hidden * (1 - residual_logits) + talk_hidden_states * residual_logits
4055
+ residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
4056
+
4057
+ assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
4058
+ if self.clever_residual:
4059
+ if ahead_idx >= self.n_ahead - 1:
4060
+ # get the logits shifted according to the current talk ahead
4061
+ cur_base_logits = torch.cat([
4062
+ base_logits[..., ahead_idx - self.n_ahead + 1:, :],
4063
+ base_logits[..., :ahead_idx - self.n_ahead + 1, :]
4064
+ ], dim=-2)
4065
+ if self.optimize_lm_head_only_at_start:
4066
+ cur_base_logits = cur_base_logits.detach()
4067
+ logits = cur_base_logits + residual_logits
4068
+ else:
4069
+ logits += residual_logits / self.n_ahead
4070
+ elif self.cumulative_residual:
4071
+ if self.residual_talk_head:
4072
+ if ahead_idx < self.n_ahead:
4073
+ logits += residual_logits
4074
+ else:
4075
+ # get the logits shifted according to the current talk ahead
4076
+ cur_base_logits = torch.cat([
4077
+ base_logits[..., ahead_idx - self.n_ahead + 1:, :],
4078
+ base_logits[..., :ahead_idx - self.n_ahead + 1, :]
4079
+ ], dim=-2)
4080
+ if self.optimize_lm_head_only_at_start:
4081
+ cur_base_logits = cur_base_logits.detach()
4082
+ logits = cur_base_logits + residual_logits
4083
+ else:
4084
+ if ahead_idx < self.n_ahead:
4085
+ logits += residual_logits
4086
+ else:
4087
+ logits = residual_logits
4088
+ elif self.skip_residual:
4089
+ if ahead_idx >= self.n_ahead:
4090
+ # get the logits shifted according to the current talk ahead
4091
+ cur_base_logits = torch.cat([
4092
+ base_logits[..., ahead_idx - self.n_ahead + 1:, :],
4093
+ base_logits[..., :ahead_idx - self.n_ahead + 1, :]
4094
+ ], dim=-2)
4095
+ if self.optimize_lm_head_only_at_start:
4096
+ cur_base_logits = cur_base_logits.detach()
4097
+ logits = cur_base_logits
4098
+ elif self.no_residual:
4099
+ logits = residual_logits
4100
+ else:
4101
+ logits = base_logits + residual_logits
4102
+
4103
+ attempted = False
4104
+ talk_loss_list = []
4105
+ if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):# or (self.optimize_lm_head_only_at_start and ahead_idx == 0):
4106
+ loss = None
4107
+ attempted = True
4108
+
4109
+ if labels is not None:
4110
+ for shift_amount in range(self.n_ahead_talk):
4111
+ # Shift so that tokens < n predict n
4112
+ # ab[cde]f
4113
+ # abc[def]
4114
+ if ahead_idx == 0 and self.optimize_lm_head_only_at_start:
4115
+ loss_logits = initial_loss_logits
4116
+ else:
4117
+ loss_logits = logits
4118
+ shift_logits = loss_logits[..., shift_amount:-1, :].contiguous()
4119
+ shift_labels = labels[..., 1 + shift_amount:].contiguous()
4120
+ # Flatten the tokens
4121
+ loss_fct = CrossEntropyLoss(reduction="none")
4122
+ # print("Shift logits before:", shift_logits)
4123
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
4124
+ shift_labels = shift_labels.view(-1).clone()
4125
+ # print("shift logits after:", shift_logits)
4126
+ # Enable model parallelism
4127
+ shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
4128
+ shift_labels = shift_labels.to(shift_logits.device)
4129
+ loss = loss_fct(shift_logits, shift_labels)
4130
+ if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
4131
+ loss_list.append(loss)
4132
+ talk_loss_list.append(nonzero_mean(loss).detach())
4133
+
4134
+ if not attempted or self.comparison_mode:
4135
+ rm_hidden_states = hidden_states
4136
+ # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
4137
+ rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
4138
+
4139
+ # don't allow it to predict the thinking token
4140
+ if self.tokenizer_has_start_thought_token:
4141
+ rm_logits[..., self.start_token_id] = -1e10
4142
+ if self.tokenizer_has_end_thought_token:
4143
+ rm_logits[..., self.end_token_id] = -1e10
4144
+ probabilities = rm_logits
4145
+ if probabilities_2d is not None:
4146
+ prev_probabilities_2d = probabilities_2d.clone()
4147
+ probabilities_2d = probabilities.view(-1, probabilities.size(-1))
4148
+
4149
+ did_skip_sampling = skip_sampling
4150
+ skip_sampling = False
4151
+ if ahead_idx == 0 and self.use_start_thought_token:
4152
+ override_token = self.start_token_id
4153
+ elif self.use_thought_prefix and ahead_idx < self.tokenized_thought_prefix.shape[-1]:
4154
+ override_token = self.tokenized_thought_prefix[..., ahead_idx]
4155
+ elif ahead_idx == self.n_ahead - 2 and self.use_end_thought_token:
4156
+ override_token = self.end_token_id
4157
+ else:
4158
+ override_token = None
4159
+ if override_token is not None and self.n_ahead > 1:
4160
+ # always start with the start token
4161
+ probabilities_2d = torch.zeros_like(probabilities_2d)
4162
+ probabilities_2d[:, override_token] = 1.0
4163
+ skip_sampling = True
4164
+ elif ahead_idx >= self.n_ahead - 1:
4165
+ if labels is not None: # we're in the talk phase
4166
+ cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
4167
+ # print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
4168
+ shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
4169
+ padding = torch.full_like(
4170
+ labels[..., :cur_talk_n],
4171
+ self.tokenizer.pad_token_id,
4172
+ dtype=torch.long,
4173
+ device=shift_labels.device
4174
+ )
4175
+ new_rm_tokens = torch.cat(
4176
+ [shift_labels, padding],
4177
+ dim=-1
4178
+ )
4179
+
4180
+ # print((new_rm_tokens > self.vocab_size - 1).any().item())
4181
+ new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
4182
+
4183
+ # Now safely convert rm tokens to one-hot
4184
+ probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
4185
+ else:
4186
+ continue
4187
+ temperature = self.gumbel_temperature if self.training else 0.001
4188
+ prev_sample_probs = sample_probs
4189
+ sample_probs = probabilities_2d
4190
+ if ahead_idx < self.n_ahead - 1 and not skip_sampling:
4191
+ probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1)
4192
+ if self.gumbel_detach:
4193
+ probabilities_2d = probabilities_2d.detach()
4194
+ sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu())
4195
+ # convert rm logits directly to embeddings
4196
+ contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0)
4197
+ contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
4198
+ contains_thought = contains_start or contains_end
4199
+
4200
+
4201
+ if not contains_thought:
4202
+ with torch.set_grad_enabled(not self.train_only_thinking_embedding):
4203
+ inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype) * temperature)
4204
+ else:
4205
+ thought_id = self.start_token_id if contains_start else self.end_token_id
4206
+ cur_thought_embedding = start_embedding if contains_start else end_embedding
4207
+ if self.use_reparam_for_thought_embeddings:
4208
+ inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
4209
+ inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
4210
+ if contains_start:
4211
+ sampled_start = inputs_embeds.clone().detach()
4212
+ else:
4213
+ sampled_end = inputs_embeds.clone().detach()
4214
+ else:
4215
+ inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
4216
+ inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
4217
+ inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
4218
+
4219
+ # Predict the usefulness of thinking at each token position
4220
+ thinking_usefulness = self.thinking_usefulness_head(hidden_states).squeeze(-1)
4221
+
4222
+ # Apply a threshold to decide where to generate thoughts
4223
+ generate_thought_mask = thinking_usefulness > self.thinking_threshold
4224
+
4225
+ # Compute the regularization loss for thinking usefulness prediction
4226
+ thinking_usefulness_loss = torch.mean(thinking_usefulness * (1 - generate_thought_mask.float()))
4227
+
4228
+ # Add the regularization loss to the total loss
4229
+ if loss is not None:
4230
+ loss = loss + self.thinking_usefulness_loss_weight * thinking_usefulness_loss
4231
+ else:
4232
+ loss = self.thinking_usefulness_loss_weight * thinking_usefulness_loss
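+             # This regularizer only accumulates usefulness scores that fall below
+             # thinking_threshold, pushing them toward zero so that thought generation is
+             # concentrated on positions the usefulness head rates above the threshold.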
+
4234
+
4235
+ if len(attention_mask.shape) == 2:
4236
+ breakpoint()
4237
+ else:
4238
+ original_attention = attention_mask[..., :attention_mask.shape[-2]]
4239
+ if self.use_upper_triangular:
4240
+ new_attention = original_attention
4241
+ else:
4242
+ original_attention = original_attention == attention_mask.max()
4243
+ # because eye isn't implemented for BF16, we need to handle the case
4244
+ if not attention_mask.dtype == torch.bfloat16:
4245
+ new_attention = torch.eye(
4246
+ seq_len, dtype=attention_mask.dtype, device=attention_mask.device
4247
+ )
4248
+ else:
4249
+ new_attention = torch.eye(
4250
+ seq_len, dtype=torch.float32, device=attention_mask.device
4251
+ ).to(attention_mask.dtype)
4252
+
4253
+ new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
4254
+ new_attention = new_attention * original_attention
4255
+ new_attention[new_attention == 0] = attention_mask.min()
4256
+ new_attention[new_attention == 1] = attention_mask.max()
4257
+ attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
4258
+ past_key_values = outputs.past_key_values
4259
+ position_ids = position_ids + 1
4260
+
4261
+ if labels is not None and (self.n_ahead > 1 or not self.base_original_mode):
4262
+ # Shift so that tokens < n predict n
4263
+ # logits: abcdef -> bcdef? -> cdef??
4264
+ # labels: abcdef -> ?bcdef -> ??cdef
4265
+ if ahead_idx == 0 and self.optimize_lm_head_only_at_start:
4266
+ loss_logits = initial_loss_logits
4267
+ else:
4268
+ loss_logits = logits
4269
+ shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
4270
+ shift_logits = loss_logits[..., :-shift_idx, :].contiguous()
4271
+ shift_labels = labels[..., shift_idx:].contiguous()
4272
+ # Flatten the tokens
4273
+ loss_fct = CrossEntropyLoss(reduction="none")
4274
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
4275
+ shift_labels = shift_labels.view(-1)
4276
+ # Enable model parallelism
4277
+ shift_labels = shift_labels.to(shift_logits.device)
4278
+ # if shift_labels.min() == self.tokenizer.pad_token_id:
4279
+ shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
4280
+ unreduced_loss = loss_fct(shift_logits, shift_labels)
4281
+ # print("Loss:", unreduced_loss.item()) # Print the loss before checking for NaN values
4282
+ if torch.any(unreduced_loss != unreduced_loss):
4283
+ # pdb.set_trace()
4284
+ raise ValueError("NaN loss")
4285
+ unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
4286
+ loss_list.append(unreduced_loss)
4287
+
4288
+
4289
+ if self.use_policy_loss and ahead_idx > 0 and (ahead_idx > 1 or not self.use_start_thought_token):
4290
+ # we treat the change in loss as the reward
4291
+ previous_loss = loss_list[-2]
4292
+ # for example, suppose n_ahead = 3 and n_ahead_talk = 2
4293
+ # note that we end at self.n_ahead + self.n_ahead_talk - 2
4294
+ # in this case, 5 - 2 = 3, so we end at ahead_idx = 3
4295
+ # we also predict the next token at ahead_idx = 2
4296
+ # when we get to ahead_idx = 2, we predict ahead
4297
+ # so we shift by 1
4298
+ # note that this is ahead_idx = n_ahead - 1
4299
+ # when we get to ahead_idx = 3, we predict ahead
4300
+ # so we shift by 2
4301
+ # note that this is ahead_idx = n_ahead
4302
+ if ahead_idx < self.n_ahead - 1:
4303
+ shift_amount = 0
4304
+ reward_scale = 1.0
4305
+ original_dqn_reward = torch.sign(previous_loss - unreduced_loss).detach() * reward_scale
4306
+ if self.first_and_last_mode:
4307
+ original_dqn_reward = original_dqn_reward * 0.0
4308
+ else:
4309
+ # logits vs cur_policy_shift_logits
4310
+ # let's look at rm_logits and prev_rm_logits
4311
+ shift_amount = max(0, ahead_idx - (self.n_ahead - 1))
4312
+ # let's say shift_amount = 2
4313
+ # abcdefg -> bcdefg? -> cdefg??
4314
+ # logits = [a b]c d e f[g]
4315
+ # labels = [a b c]d e f g
4316
+ cur_policy_shift_logits = initial_loss_logits[..., shift_amount:-1, :].contiguous().detach()
4317
+ cur_policy_shift_labels = labels[..., 1 + shift_amount:].contiguous()
4318
+ # Flatten the tokens
4319
+ cur_policy_loss_fct = CrossEntropyLoss(reduction="none")
4320
+ cur_policy_shift_logits = cur_policy_shift_logits.view(-1, self.config.vocab_size)
4321
+ cur_policy_shift_labels = cur_policy_shift_labels.view(-1).clone()
4322
+ # Enable model parallelism
4323
+ cur_policy_shift_labels[cur_policy_shift_labels == self.tokenizer.pad_token_id] = -100
+                         cur_policy_shift_labels = cur_policy_shift_labels.to(cur_policy_shift_logits.device)
+                         cur_policy_reward_base_loss = cur_policy_loss_fct(
+                             cur_policy_shift_logits, cur_policy_shift_labels
+ ).reshape(logits.shape[0], -1)
4328
+ original_dqn_reward = cur_policy_reward_base_loss.detach() - unreduced_loss
4329
+
4330
+ if not did_skip_sampling:
4331
+ nonzero_indices = prev_probabilities_2d.nonzero()
4332
+ action_loglikelihoods = F.log_softmax(prev_sample_probs / self.reinforce_temperature, dim=-1)[nonzero_indices[:, 0], nonzero_indices[:, 1]]
4333
+ action_loglikelihoods_2d = action_loglikelihoods.reshape(batch_size, -1)[:, :-1 - shift_amount]
4334
+ action_loglikelihoods_list.append(action_loglikelihoods_2d)
4335
+ if policy_reward is None:
4336
+ policy_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)]
4337
+ else:
4338
+ if self.n_ahead_talk > shift_amount:
4339
+ added_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)]
4340
+ else:
4341
+ added_reward = original_dqn_reward
4342
+ policy_reward += added_reward
4343
+
4344
+ for action_loglikelihoods_2d in action_loglikelihoods_list:
4345
+ train_policy_reward = policy_reward
4346
+
4347
+ # discard rewards below the mean
4348
+ if self.trice_mode and self.n_passes > 1:
4349
+ batched_policy_reward = train_policy_reward.reshape(-1, self.n_passes, train_policy_reward.shape[-1])
4350
+ # average over the passes
4351
+ train_policy_reward = batched_policy_reward - batched_policy_reward.mean(dim=1, keepdim=True)
4352
+ train_policy_reward = train_policy_reward.reshape(-1, train_policy_reward.shape[-1])
4353
+
4354
+ if self.subtract_mean_reward:
4355
+ train_policy_reward = train_policy_reward - train_policy_reward.mean()
4356
+ if self.remove_negative_rewards:
4357
+ fixed_policy_reward = train_policy_reward.detach().clamp(min=0)
4358
+ else:
4359
+ fixed_policy_reward = train_policy_reward.detach()
4360
+
4361
+ # Normalize rewards
4362
+ fixed_policy_reward = (fixed_policy_reward - fixed_policy_reward.mean()) / (fixed_policy_reward.std() + 1e-8)
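+                     # REINFORCE-style objective: the normalized advantage (change in LM loss
+                     # attributable to the sampled thought) weights the negative log-likelihood
+                     # of the thought tokens, so thoughts that reduced the loss are reinforced.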
+ actor_loss = -fixed_policy_reward * action_loglikelihoods_2d[:, :policy_reward.shape[-1]].to(policy_reward.device)
4364
+ if action_loglikelihoods_2d.mean() < -1e4 and not self.use_policy_loss_just_for_thoughts:
4365
+ # This will only happen when we force the next token to be the end of thought token
4366
+ break
4367
+ dqn_loss_list.append(actor_loss.mean())
4368
+
4369
+ if loss_list:
4370
+ if self.first_and_last_mode:
4371
+ loss = sum(
4372
+ self.loss_mean(loss_list[-(i + 1)]) for i in range(self.n_ahead_talk)
4373
+ ) * (1 - self.original_loss_weight) / self.n_ahead_talk
4374
+ loss = loss + self.loss_mean(loss_list[0]) * self.original_loss_weight
4375
+ # Let's NaN out the others
4376
+ # e.g. if n_ahead_talk = 2 and the list is 5 long, we want to NaN out 1, 2 but keep 0, 3, 4
4377
+ for i in range(1, len(loss_list) - self.n_ahead_talk):
4378
+ loss_list[i] = loss_list[i] * math.nan
4379
+ elif self.first_only:
4380
+ loss = self.loss_mean(loss_list[0])
4381
+ elif self.final_only_mode:
4382
+ loss = sum(
4383
+ self.loss_mean(loss_list[-i]) for i in range(1, self.n_ahead_talk + 1)
4384
+ ) / self.n_ahead_talk
4385
+ else:
4386
+ loss = None
4387
+ for i in range(len(loss_list)):
4388
+ cur_loss = self.loss_mean(loss_list[i])
4389
+ if loss is not None:
4390
+ loss = loss + cur_loss.to(loss.device)
4391
+ else:
4392
+ loss = cur_loss
4393
+ loss = loss / len(loss_list)
4394
+ loss = loss + thinking_usefulness_loss
4395
+
4396
+ base_loss_scale = 0.6
4397
+ policy_loss_scale = 0.03
4398
+
4399
+ loss = loss * base_loss_scale
4400
+
4401
+ if dqn_loss_list:
4402
+ dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list)
4403
+ if self.include_policy_loss:
4404
+ if loss is not None:
4405
+ loss += dqn_loss * policy_loss_scale
4406
+ else:
4407
+ loss = dqn_loss * self.policy_loss_beta
4408
+
4409
+ if not return_dict:
4410
+ output = (logits,) + outputs[1:]
4411
+ return (loss,) + output if loss is not None else output
4412
+
4413
+ base_log_dict = {
4414
+ f"loss_{i}": nonzero_mean(loss_list[i]) for i in range(len(loss_list))
4415
+ }
4416
+
4417
+ if loss is not None:
4418
+ base_log_dict["loss_train"] = loss.item()
4419
+
4420
+ if not self.training:
4421
+ self.n_ahead_talk = n_ahead_talk_to_restore
4422
+ self.n_passes = n_passes_to_restore
4423
+
+         if self.use_start_thought_token or self.use_end_thought_token:
+             del start_embedding
+             del end_embedding
+         torch.cuda.empty_cache()
+
4428
+
4429
+ return CausalLMOutputWithPast(
+             loss=loss,
+ logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
4432
+ past_key_values=outputs.past_key_values,
4433
+ hidden_states=outputs.hidden_states,
4434
+ attentions=outputs.attentions,
4435
+ )
4436
+
4437
+
4438
+
4439
+ def prepare_inputs_for_generation(
4440
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
4441
+ ):
4442
+ # Omit tokens covered by past_key_values
4443
+ if past_key_values is not None:
4444
+ if isinstance(past_key_values, Cache):
4445
+ cache_length = past_key_values.get_seq_length()
4446
+ past_length = past_key_values.seen_tokens
4447
+ max_cache_length = past_key_values.get_max_length()
4448
+ else:
4449
+ cache_length = past_length = past_key_values[0][0].shape[2]
4450
+ max_cache_length = None
4451
+
4452
+ # Keep only the unprocessed tokens:
4453
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
4454
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
4455
+ # input)
4456
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
4457
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
4458
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
4459
+ # input_ids based on the past_length.
4460
+ elif past_length < input_ids.shape[1]:
4461
+ input_ids = input_ids[:, past_length:]
4462
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
4463
+
4464
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
4465
+ if (
4466
+ max_cache_length is not None
4467
+ and attention_mask is not None
4468
+ and cache_length + input_ids.shape[1] > max_cache_length
4469
+ ):
4470
+ attention_mask = attention_mask[:, -max_cache_length:]
4471
+
4472
+ position_ids = kwargs.get("position_ids", None)
4473
+ if attention_mask is not None and position_ids is None:
4474
+ # create position_ids on the fly for batch generation
4475
+ position_ids = attention_mask.long().cumsum(-1) - 1
4476
+ position_ids.masked_fill_(attention_mask == 0, 1)
4477
+ if past_key_values:
4478
+ position_ids = position_ids[:, -input_ids.shape[1] :]
4479
+
4480
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
4481
+ if inputs_embeds is not None and past_key_values is None:
4482
+ model_inputs = {"inputs_embeds": inputs_embeds}
4483
+ else:
4484
+ model_inputs = {"input_ids": input_ids}
4485
+
4486
+ model_inputs.update(
4487
+ {
4488
+ "position_ids": position_ids,
4489
+ "past_key_values": past_key_values,
4490
+ "use_cache": kwargs.get("use_cache"),
4491
+ "attention_mask": attention_mask,
4492
+ }
4493
+ )
4494
+ return model_inputs
4495
+
4496
+ @staticmethod
4497
+ def _reorder_cache(past_key_values, beam_idx):
4498
+ reordered_past = ()
4499
+ for layer_past in past_key_values:
4500
+ reordered_past += (
4501
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
4502
+ )
4503
+ return reordered_past
4504
+
4505
+
4506
+
4507
+
4508
+ @add_start_docstrings(
4509
+ """
4510
+ The Quiet Model transformer with a sequence classification head on top (linear layer).
4511
+ [`QuietForSequenceClassification`] uses the last token in order to do the classification, as other causal models
4512
+ (e.g. GPT-2) do.
4513
+ Since it does classification on the last token, it requires to know the position of the last token. If a
4514
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
4515
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
4516
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
4517
+ each row of the batch).
4518
+ """,
4519
+ QUIET_START_DOCSTRING,
4520
+ )
4521
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Quiet, LLAMA->QUIET
4522
+ class QuietForSequenceClassification(QuietPreTrainedModel):
4523
+ def __init__(self, config):
4524
+ super().__init__(config)
4525
+ self.num_labels = config.num_labels
4526
+ self.model = QuietModel(config)
4527
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
4528
+
4529
+ # Initialize weights and apply final processing
4530
+ self.post_init()
4531
+
4532
+ def get_input_embeddings(self):
4533
+ return self.model.embed_tokens
4534
+
4535
+ def set_input_embeddings(self, value):
4536
+ self.model.embed_tokens = value
4537
+
4538
+ @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
4539
+ def forward(
4540
+ self,
4541
+ input_ids: torch.LongTensor = None,
4542
+ attention_mask: Optional[torch.Tensor] = None,
4543
+ position_ids: Optional[torch.LongTensor] = None,
4544
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
4545
+ inputs_embeds: Optional[torch.FloatTensor] = None,
4546
+ labels: Optional[torch.LongTensor] = None,
4547
+ use_cache: Optional[bool] = None,
4548
+ output_attentions: Optional[bool] = None,
4549
+ output_hidden_states: Optional[bool] = None,
4550
+ return_dict: Optional[bool] = None,
4551
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
4552
+ r"""
4553
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
4554
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
4555
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
4556
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
4557
+ """
4558
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
4559
+
4560
+ transformer_outputs = self.model(
4561
+ input_ids,
4562
+ attention_mask=attention_mask,
4563
+ position_ids=position_ids,
4564
+ past_key_values=past_key_values,
4565
+ inputs_embeds=inputs_embeds,
4566
+ use_cache=use_cache,
4567
+ output_attentions=output_attentions,
4568
+ output_hidden_states=output_hidden_states,
4569
+ return_dict=return_dict,
4570
+ )
4571
+ hidden_states = transformer_outputs[0]
4572
+ logits = self.score(hidden_states)
4573
+
4574
+ if input_ids is not None:
4575
+ batch_size = input_ids.shape[0]
4576
+ else:
4577
+ batch_size = inputs_embeds.shape[0]
4578
+
4579
+ if self.config.pad_token_id is None and batch_size != 1:
4580
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
4581
+ if self.config.pad_token_id is None:
4582
+ sequence_lengths = -1
4583
+ else:
4584
+ if input_ids is not None:
4585
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
4586
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
4587
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
4588
+ sequence_lengths = sequence_lengths.to(logits.device)
4589
+ else:
4590
+ sequence_lengths = -1
4591
+
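+         # Pool the classification logits at the last non-padding token of each sequence
+         # (falling back to the final position when no pad token id is configured).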
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
4593
+
4594
+ loss = None
4595
+ if labels is not None:
4596
+ labels = labels.to(logits.device)
4597
+ if self.config.problem_type is None:
4598
+ if self.num_labels == 1:
4599
+ self.config.problem_type = "regression"
4600
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
4601
+ self.config.problem_type = "single_label_classification"
4602
+ else:
4603
+ self.config.problem_type = "multi_label_classification"
4604
+
4605
+ if self.config.problem_type == "regression":
4606
+ loss_fct = MSELoss()
4607
+ if self.num_labels == 1:
4608
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
4609
+ else:
4610
+ loss = loss_fct(pooled_logits, labels)
4611
+ elif self.config.problem_type == "single_label_classification":
4612
+ loss_fct = CrossEntropyLoss()
4613
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
4614
+ elif self.config.problem_type == "multi_label_classification":
4615
+ loss_fct = BCEWithLogitsLoss()
4616
+ loss = loss_fct(pooled_logits, labels)
4617
+ if not return_dict:
4618
+ output = (pooled_logits,) + transformer_outputs[1:]
4619
+ return ((loss,) + output) if loss is not None else output
4620
+
4621
+ return SequenceClassifierOutputWithPast(
4622
+ loss=loss,
4623
+ logits=pooled_logits,
4624
+ past_key_values=transformer_outputs.past_key_values,
4625
+ hidden_states=transformer_outputs.hidden_states,
4626
+ attentions=transformer_outputs.attentions,
4627
+ )