x54-729 committed
Commit cea5d14
1 Parent(s): 9f7e25e

Fix batch generation

Files changed (1)
  1. modeling_internlm.py +88 -49
modeling_internlm.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )

+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig

 logger = logging.get_logger(__name__)
@@ -82,6 +86,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class InternLMRMSNorm(nn.Module):
     """RMSNorm implemention."""

@@ -113,6 +128,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -124,8 +140,8 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(torch.float32), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(torch.float32), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -136,11 +152,11 @@ class InternLMRotaryEmbedding(torch.nn.Module):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+            self.register_buffer("cos_cached", emb.cos(), persistent=False)
+            self.register_buffer("sin_cached", emb.sin(), persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )


@@ -158,7 +174,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.dim = dim
         self.base = base
         self.scaling_factor = scaling_factor
@@ -170,8 +186,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def _update_cached(self, x, seq_len=None):
         self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
@@ -185,8 +201,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -199,8 +215,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
             self._update_cached(x, seq_len)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )


@@ -210,23 +226,23 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)

-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    if q.size(2) == 1:
-        q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
+    if position_ids.size(1) == 1:
+        q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
+        q_sin = sin[position_ids].unsqueeze(1).expand(q.shape)
+        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
+
+        position_ids = position_ids.flatten() + 1
+        max_length = max(position_ids)
+        position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
+        k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
+        k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
+        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
     else:
+        cos = cos[position_ids].unsqueeze(1).expand(q.shape)
+        sin = sin[position_ids].unsqueeze(1).expand(q.shape)
         q_embed = (q * cos) + (rotate_half(q) * sin)
-
-    if k.size(2) == 1:
-        k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
-    else:
         k_embed = (k * cos) + (rotate_half(k) * sin)
-
     return q_embed, k_embed


@@ -256,6 +272,8 @@ class InternLMAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings

         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -264,27 +282,30 @@ class InternLMAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()

     def _init_rope(self):
-        if self.config.rotary["type"] == "origin":
+        if self.config.rope_scaling is None:
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-            )
-        elif self.config.rotary["type"] == "dynamic":
-            self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-                scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
+                base=self.config.rope_theta,
             )
         else:
-            raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "dynamic":
+                self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    base=self.config.rope_theta,
+                    scaling_factor=scaling_factor,
+                )
+            else:
+                raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
         return self.rotary_emb

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -302,21 +323,27 @@ class InternLMAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = (
+            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )

         if past_key_value is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)

-        # print(use_cache)
         past_key_value = (key_states, value_states) if use_cache else None

         kv_seq_len = key_states.shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -851,12 +878,16 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
-
-    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
+
+    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
         prompt = ""
+        if meta_instruction:
+            prompt += f"""<s><|System|>:{meta_instruction}\n"""
+        else:
+            prompt += "<s>"
         for record in history:
-            prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
-        prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
+            prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}<eoa>\n"""
+        prompt += f"""<|User|>:{query}\n<|Bot|>:"""
         return tokenizer([prompt], return_tensors="pt")

     @torch.no_grad()
@@ -870,9 +901,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         do_sample: bool = True,
         temperature: float = 0.8,
         top_p: float = 0.8,
+        meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
         **kwargs,
     ):
-        inputs = self.build_inputs(tokenizer, query, history)
+        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
         outputs = self.generate(
             **inputs,
@@ -907,6 +941,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
        """
+        if BaseStreamer is None:
+            raise ModuleNotFoundError(
+                "The version of `transformers` is too low. Please make sure "
+                "that you have installed `transformers>=4.28.0`."
+            )

         response_queue = queue.Queue(maxsize=20)

@@ -1083,4 +1122,4 @@ class InternLMForSequenceClassification(InternLMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
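
For context, the sketch below shows one way the batched path touched by this commit could be exercised: left-padded batched generation, which goes through the per-sequence position_ids handling in apply_rotary_pos_emb above. This is a minimal, illustrative sketch, not part of the commit; the checkpoint name internlm/internlm-chat-7b, the CUDA device, the pad-token fallback, and the prompt strings are assumptions, and only the <|User|>/<|Bot|> markers come from build_inputs.

# Illustrative only (not part of this commit): batched, left-padded generation.
# Assumptions: internlm/internlm-chat-7b checkpoint, a CUDA device, and an
# eos-token fallback when the tokenizer defines no pad token.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "internlm/internlm-chat-7b"  # assumed checkpoint using this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumed fallback

model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16)
model = model.cuda().eval()

# Two prompts of different lengths; left padding keeps the last tokens aligned for decoding.
prompts = [
    "<|User|>:Hello\n<|Bot|>:",
    "<|User|>:Please introduce yourself in one sentence.\n<|Bot|>:",
]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

with torch.no_grad():
    output_ids = model.generate(**batch, max_new_tokens=64, do_sample=False)

for text in tokenizer.batch_decode(output_ids, skip_special_tokens=True):
    print(text)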