Abhaykoul committed on
Commit ce9ff91, 1 Parent(s): fc837bd

Update modeling_llava.py

Files changed (1)
  1. modeling_llava.py +1498 -17
modeling_llava.py CHANGED
@@ -1,17 +1,1365 @@
1
  # coding=utf-8
 
2
  from dataclasses import dataclass
3
  from typing import List, Optional, Tuple, Union
4
 
5
  import torch
 
6
  import torch.utils.checkpoint
7
  from torch import nn
 
8
 
9
- from transformers import PreTrainedModel
10
- from transformers.modeling_outputs import ModelOutput
11
 
12
- from modeling_phi import PhiForCausalLM
13
- from configuration_llava import LlavaConfig
14
- from open_clip import create_model
15
 
16
 
17
  @dataclass
@@ -24,22 +1372,86 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
24
  image_features: Optional[torch.FloatTensor] = None
25
 
26
 
27
  class LlavaMultiModalProjector(nn.Module):
28
  def __init__(self, config: LlavaConfig):
29
  super().__init__()
30
 
31
  self.linear_1 = nn.Linear(
32
  config.vision_embed_dim,
33
- config.text_config.n_embd * config.projector_tokens_num,
34
  bias=True,
35
  )
36
  self.act = nn.GELU()
37
  self.linear_2 = nn.Linear(
38
- config.text_config.n_embd * 5,
39
- config.text_config.n_embd,
40
  bias=True,
41
  )
42
- self.projector_tokens_num = config.projector_tokens_num
43
 
44
  def forward(self, image_features):
45
  hidden_states = self.linear_1(image_features)
@@ -71,11 +1483,10 @@ class LlavaPreTrainedModel(PreTrainedModel):
71
  return self.language_model._supports_sdpa
72
 
73
 
74
- class LlavaForConditionalGeneration(LlavaPreTrainedModel):
75
  def __init__(self, config: LlavaConfig):
76
  super().__init__(config)
77
- clip_model = create_model(config.vision_tower_name)
78
- self.vision_model = clip_model.visual
79
 
80
  self.multi_modal_projector = LlavaMultiModalProjector(config)
81
  self.vocab_size = config.vocab_size
@@ -246,6 +1657,38 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
246
  attention_mask,
247
  position_ids,
248
  )
249
 
250
  outputs = self.language_model(
251
  input_ids=None,
@@ -261,7 +1704,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
261
 
262
  logits = outputs[0]
263
 
264
-
265
  if not return_dict:
266
  output = (logits,) + outputs[1:]
267
  return output
@@ -283,11 +1725,49 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
283
  image_features=None,
284
  **kwargs,
285
  ):
286
- res = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, attention_mask, **kwargs)
287
- input_ids = res["input_ids"]
288
- past_key_values = res["past_key_values"]
289
- attention_mask = res["attention_mask"]
290
 
 
291
  if inputs_embeds is not None and past_key_values is None:
292
  model_inputs = {"inputs_embeds": inputs_embeds}
293
  else:
@@ -295,6 +1775,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
295
 
296
  model_inputs.update(
297
  {
 
298
  "past_key_values": past_key_values,
299
  "use_cache": kwargs.get("use_cache"),
300
  "attention_mask": attention_mask,
 
1
  # coding=utf-8
2
+ import math
3
  from dataclasses import dataclass
4
  from typing import List, Optional, Tuple, Union
5
 
6
  import torch
7
+ import torch.nn.functional as F
8
  import torch.utils.checkpoint
9
  from torch import nn
10
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
+ from transformers import PreTrainedModel, SiglipVisionModel
12
+ from transformers.activations import ACT2FN
13
+ from transformers.cache_utils import Cache, DynamicCache
14
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast,
18
+ ModelOutput,
19
+ SequenceClassifierOutputWithPast,
20
+ TokenClassifierOutput,
21
+ )
22
+ from transformers.utils import (
23
+ is_flash_attn_2_available,
24
+ is_flash_attn_greater_or_equal_2_10,
25
+ logging,
26
+ )
27
 
28
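+ # flash_attn is optional: if this import fails the error is only printed, and the model falls back to the eager attention path (PhiDecoderLayer selects flash_attention_2 only when is_flash_attn_2_available() returns True).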
+ try:
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+ except Exception as exp:
32
+ print(exp)
33
 
34
+
35
+ from transformers.configuration_utils import PretrainedConfig
36
+ from transformers import SiglipVisionConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class PhiConfig(PretrainedConfig):
43
+ model_type = "phi"
44
+ keys_to_ignore_at_inference = ["past_key_values"]
45
+
46
+ def __init__(
47
+ self,
48
+ vocab_size=51200,
49
+ hidden_size=2048,
50
+ intermediate_size=8192,
51
+ num_hidden_layers=24,
52
+ num_attention_heads=32,
53
+ num_key_value_heads=None,
54
+ resid_pdrop=0.0,
55
+ embd_pdrop=0.0,
56
+ attention_dropout=0.0,
57
+ hidden_act="gelu_new",
58
+ max_position_embeddings=2048,
59
+ initializer_range=0.02,
60
+ layer_norm_eps=1e-5,
61
+ use_cache=True,
62
+ tie_word_embeddings=False,
63
+ rope_theta=10000.0,
64
+ rope_scaling=None,
65
+ partial_rotary_factor=0.5,
66
+ qk_layernorm=False,
67
+ bos_token_id=1,
68
+ eos_token_id=2,
69
+ **kwargs,
70
+ ):
71
+ self.vocab_size = vocab_size
72
+ self.hidden_size = hidden_size
73
+ self.intermediate_size = intermediate_size
74
+ self.num_hidden_layers = num_hidden_layers
75
+ self.num_attention_heads = num_attention_heads
76
+
77
+ if num_key_value_heads is None:
78
+ num_key_value_heads = num_attention_heads
79
+
80
+ self.num_key_value_heads = num_key_value_heads
81
+ self.resid_pdrop = resid_pdrop
82
+ self.embd_pdrop = embd_pdrop
83
+ self.attention_dropout = attention_dropout
84
+ self.hidden_act = hidden_act
85
+ self.max_position_embeddings = max_position_embeddings
86
+ self.initializer_range = initializer_range
87
+ self.layer_norm_eps = layer_norm_eps
88
+ self.use_cache = use_cache
89
+ self.rope_theta = rope_theta
90
+ self.rope_scaling = rope_scaling
91
+ self.partial_rotary_factor = partial_rotary_factor
92
+ self.qk_layernorm = qk_layernorm
93
+ self._rope_scaling_validation()
94
+
95
+ super().__init__(
96
+ bos_token_id=bos_token_id,
97
+ eos_token_id=eos_token_id,
98
+ tie_word_embeddings=tie_word_embeddings,
99
+ **kwargs,
100
+ )
101
+
102
+ def _rope_scaling_validation(self):
103
+ """
104
+ Validate the `rope_scaling` configuration.
105
+ """
106
+ if self.rope_scaling is None:
107
+ return
108
+
109
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
110
+ raise ValueError(
111
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
112
+ f"got {self.rope_scaling}"
113
+ )
114
+ rope_scaling_type = self.rope_scaling.get("type", None)
115
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
116
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
117
+ raise ValueError(
118
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
119
+ )
120
+ if (
121
+ rope_scaling_factor is None
122
+ or not isinstance(rope_scaling_factor, float)
123
+ or rope_scaling_factor <= 1.0
124
+ ):
125
+ raise ValueError(
126
+ f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
127
+ )
128
+
129
+
130
+ class LlavaConfig(PretrainedConfig):
131
+ model_type = "mc-llava"
132
+ is_composition = False
133
+
134
+ def __init__(
135
+ self,
136
+ text_config=None,
137
+ vision_config=None,
138
+ ignore_index=-100,
139
+ image_token_index=50297,
140
+ projector_hidden_act="gelu",
141
+ projector_tokens_num=1,
142
+ vocab_size=51200,
143
+ **kwargs,
144
+ ):
145
+ self.ignore_index = ignore_index
146
+ self.image_token_index = image_token_index
147
+ self.projector_hidden_act = projector_hidden_act
148
+ self.projector_tokens_num = projector_tokens_num
149
+ self.vocab_size = vocab_size
150
+
151
+ self.text_config = text_config
152
+ if isinstance(self.text_config, dict):
153
+ text_config["model_type"] = (
154
+ text_config["model_type"] if "model_type" in text_config else "phi"
155
+ )
156
+ self.text_config = PhiConfig(**text_config)
157
+ self.vocab_size = self.text_config.vocab_size
158
+
159
+ self.vision_config = vision_config
160
+ if isinstance(self.vision_config, dict):
161
+ self.vision_config = SiglipVisionConfig(**vision_config)
162
+ self.vision_embed_dim = self.vision_config.hidden_size
163
+
164
+ super().__init__(**kwargs)
165
+
166
+
167
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
168
+ def _get_unpad_data(attention_mask):
169
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
170
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
171
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
172
+ cu_seqlens = F.pad(
173
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
174
+ )
175
+ return (
176
+ indices,
177
+ cu_seqlens,
178
+ max_seqlen_in_batch,
179
+ )
180
+
181
+
182
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
183
+ class PhiRotaryEmbedding(nn.Module):
184
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
185
+ super().__init__()
186
+
187
+ self.dim = dim
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.base = base
190
+ inv_freq = 1.0 / (
191
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
192
+ )
193
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
194
+
195
+ # Build here to make `torch.jit.trace` work.
196
+ self._set_cos_sin_cache(
197
+ seq_len=max_position_embeddings,
198
+ device=self.inv_freq.device,
199
+ dtype=torch.get_default_dtype(),
200
+ )
201
+
202
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
203
+ self.max_seq_len_cached = seq_len
204
+ t = torch.arange(
205
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
206
+ )
207
+
208
+ freqs = torch.outer(t, self.inv_freq)
209
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
210
+ emb = torch.cat((freqs, freqs), dim=-1)
211
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
212
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
213
+
214
+ def forward(self, x, seq_len=None):
215
+ # x: [bs, num_attention_heads, seq_len, head_size]
216
+ if seq_len > self.max_seq_len_cached:
217
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
218
+
219
+ return (
220
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
221
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
222
+ )
223
+
224
+
225
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
226
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
227
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
228
+
229
+ def __init__(
230
+ self,
231
+ dim,
232
+ max_position_embeddings=2048,
233
+ base=10000,
234
+ device=None,
235
+ scaling_factor=1.0,
236
+ ):
237
+ self.scaling_factor = scaling_factor
238
+ super().__init__(dim, max_position_embeddings, base, device)
239
+
240
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
241
+ self.max_seq_len_cached = seq_len
242
+ t = torch.arange(
243
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
244
+ )
245
+ t = t / self.scaling_factor
246
+
247
+ freqs = torch.outer(t, self.inv_freq)
248
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
249
+ emb = torch.cat((freqs, freqs), dim=-1)
250
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
251
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
252
+
253
+
254
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
255
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
256
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
257
+
258
+ def __init__(
259
+ self,
260
+ dim,
261
+ max_position_embeddings=2048,
262
+ base=10000,
263
+ device=None,
264
+ scaling_factor=1.0,
265
+ ):
266
+ self.scaling_factor = scaling_factor
267
+ super().__init__(dim, max_position_embeddings, base, device)
268
+
269
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
270
+ self.max_seq_len_cached = seq_len
271
+
272
+ if seq_len > self.max_position_embeddings:
273
+ base = self.base * (
274
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
275
+ - (self.scaling_factor - 1)
276
+ ) ** (self.dim / (self.dim - 2))
277
+ inv_freq = 1.0 / (
278
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
279
+ )
280
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
281
+
282
+ t = torch.arange(
283
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
284
+ )
285
+
286
+ freqs = torch.outer(t, self.inv_freq)
287
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
288
+ emb = torch.cat((freqs, freqs), dim=-1)
289
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
290
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
291
+
292
+
293
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
294
+ def rotate_half(x):
295
+ """Rotates half the hidden dims of the input."""
296
+ x1 = x[..., : x.shape[-1] // 2]
297
+ x2 = x[..., x.shape[-1] // 2 :]
298
+ return torch.cat((-x2, x1), dim=-1)
299
+
300
+
301
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
302
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
303
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
304
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
305
+ q_embed = (q * cos) + (rotate_half(q) * sin)
306
+ k_embed = (k * cos) + (rotate_half(k) * sin)
307
+ return q_embed, k_embed
308
+
309
+
310
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
311
+ class PhiMLP(nn.Module):
312
+ def __init__(self, config):
313
+ super().__init__()
314
+ self.config = config
315
+ self.activation_fn = ACT2FN[config.hidden_act]
316
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
317
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
318
+
319
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
320
+ hidden_states = self.fc1(hidden_states)
321
+ hidden_states = self.activation_fn(hidden_states)
322
+ hidden_states = self.fc2(hidden_states)
323
+ return hidden_states
324
+
325
+
326
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
327
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
328
+ """
329
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
330
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
331
+ """
332
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
333
+ if n_rep == 1:
334
+ return hidden_states
335
+ hidden_states = hidden_states[:, :, None, :, :].expand(
336
+ batch, num_key_value_heads, n_rep, slen, head_dim
337
+ )
338
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
339
+
340
+
341
+ class PhiAttention(nn.Module):
342
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
343
+
344
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
345
+ super().__init__()
346
+ self.config = config
347
+ self.layer_idx = layer_idx
348
+ if layer_idx is None:
349
+ logger.warning_once(
350
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
351
+ "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
352
+ "when creating this class."
353
+ )
354
+
355
+ self.attention_dropout = config.attention_dropout
356
+ self.hidden_size = config.hidden_size
357
+ self.num_heads = config.num_attention_heads
358
+ self.head_dim = self.hidden_size // self.num_heads
359
+ self.num_key_value_heads = config.num_key_value_heads
360
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
361
+ self.max_position_embeddings = config.max_position_embeddings
362
+ self.rope_theta = config.rope_theta
363
+ self.partial_rotary_factor = config.partial_rotary_factor
364
+ self.is_causal = True
365
+
366
+ if (self.head_dim * self.num_heads) != self.hidden_size:
367
+ raise ValueError(
368
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
369
+ f" and `num_heads`: {self.num_heads})."
370
+ )
371
+
372
+ self.q_proj = nn.Linear(
373
+ self.hidden_size, self.num_heads * self.head_dim, bias=True
374
+ )
375
+ self.k_proj = nn.Linear(
376
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
377
+ )
378
+ self.v_proj = nn.Linear(
379
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
380
+ )
381
+ self.dense = nn.Linear(
382
+ self.num_heads * self.head_dim, self.hidden_size, bias=True
383
+ )
384
+
385
+ self.qk_layernorm = config.qk_layernorm
386
+ if self.qk_layernorm:
387
+ self.q_layernorm = nn.LayerNorm(
388
+ config.hidden_size // self.num_heads,
389
+ eps=config.layer_norm_eps,
390
+ elementwise_affine=True,
391
+ )
392
+ self.k_layernorm = nn.LayerNorm(
393
+ config.hidden_size // self.num_heads,
394
+ eps=config.layer_norm_eps,
395
+ elementwise_affine=True,
396
+ )
397
+
398
+ self._init_rope()
399
+
400
+ def _init_rope(self):
401
+ if self.config.rope_scaling is None:
402
+ self.rotary_emb = PhiRotaryEmbedding(
403
+ int(self.partial_rotary_factor * self.head_dim),
404
+ max_position_embeddings=self.max_position_embeddings,
405
+ base=self.rope_theta,
406
+ )
407
+ else:
408
+ scaling_type = self.config.rope_scaling["type"]
409
+ scaling_factor = self.config.rope_scaling["factor"]
410
+ if scaling_type == "linear":
411
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
412
+ int(self.partial_rotary_factor * self.head_dim),
413
+ max_position_embeddings=self.max_position_embeddings,
414
+ scaling_factor=scaling_factor,
415
+ base=self.rope_theta,
416
+ )
417
+ elif scaling_type == "dynamic":
418
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
419
+ int(self.partial_rotary_factor * self.head_dim),
420
+ max_position_embeddings=self.max_position_embeddings,
421
+ scaling_factor=scaling_factor,
422
+ base=self.rope_theta,
423
+ )
424
+ else:
425
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
426
+
427
+ # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
428
+ @torch.autocast("cpu", enabled=False)
429
+ @torch.autocast("cuda", enabled=False)
430
+ def forward(
431
+ self,
432
+ hidden_states: torch.Tensor,
433
+ attention_mask: Optional[torch.Tensor] = None,
434
+ position_ids: Optional[torch.LongTensor] = None,
435
+ past_key_value: Optional[Cache] = None,
436
+ output_attentions: bool = False,
437
+ use_cache: bool = False,
438
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
439
+ bsz, q_len, _ = hidden_states.size()
440
+
441
+ query_states = self.q_proj(hidden_states)
442
+ key_states = self.k_proj(hidden_states)
443
+ value_states = self.v_proj(hidden_states)
444
+
445
+ if self.qk_layernorm:
446
+ query_states = self.q_layernorm(query_states)
447
+ key_states = self.k_layernorm(key_states)
448
+
449
+ query_states = query_states.view(
450
+ bsz, q_len, self.num_heads, self.head_dim
451
+ ).transpose(1, 2)
452
+ key_states = key_states.view(
453
+ bsz, q_len, self.num_key_value_heads, self.head_dim
454
+ ).transpose(1, 2)
455
+ value_states = value_states.view(
456
+ bsz, q_len, self.num_key_value_heads, self.head_dim
457
+ ).transpose(1, 2)
458
+
459
+ kv_seq_len = key_states.shape[-2]
460
+ if past_key_value is not None:
461
+ if self.layer_idx is None:
462
+ raise ValueError(
463
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
464
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
465
+ "with a layer index."
466
+ )
467
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
468
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
469
+
470
+ # Partial rotary embedding
471
+ query_rot, query_pass = (
472
+ query_states[..., : self.rotary_emb.dim],
473
+ query_states[..., self.rotary_emb.dim :],
474
+ )
475
+ key_rot, key_pass = (
476
+ key_states[..., : self.rotary_emb.dim],
477
+ key_states[..., self.rotary_emb.dim :],
478
+ )
479
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
480
+ query_rot, key_rot = apply_rotary_pos_emb(
481
+ query_rot, key_rot, cos, sin, position_ids
482
+ )
483
+
484
+ # [batch_size, seq_length, num_heads, head_dim]
485
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
486
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
487
+
488
+ if past_key_value is not None:
489
+ cache_kwargs = {
490
+ "sin": sin,
491
+ "cos": cos,
492
+ "partial_rotation_size": self.rotary_emb.dim,
493
+ }
494
+ key_states, value_states = past_key_value.update(
495
+ key_states, value_states, self.layer_idx, cache_kwargs
496
+ )
497
+
498
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
499
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
500
+
501
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
502
+ attn_weights = torch.matmul(
503
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
504
+ ) / math.sqrt(self.head_dim)
505
+
506
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
507
+ raise ValueError(
508
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
509
+ f" {attn_weights.size()}"
510
+ )
511
+
512
+ if attention_mask is not None:
513
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
514
+ raise ValueError(
515
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
516
+ )
517
+ attn_weights = attn_weights + attention_mask
518
+
519
+ # upcast attention to fp32
520
+ attn_weights = nn.functional.softmax(
521
+ attn_weights, dim=-1, dtype=torch.float32
522
+ ).to(value_states.dtype)
523
+ attn_weights = nn.functional.dropout(
524
+ attn_weights, p=self.attention_dropout, training=self.training
525
+ )
526
+
527
+ attn_output = torch.matmul(attn_weights, value_states)
528
+
529
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
530
+ raise ValueError(
531
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
532
+ f" {attn_output.size()}"
533
+ )
534
+
535
+ attn_output = attn_output.transpose(1, 2).contiguous()
536
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
537
+
538
+ attn_output = self.dense(attn_output)
539
+
540
+ if not output_attentions:
541
+ attn_weights = None
542
+
543
+ return attn_output, attn_weights, past_key_value
544
+
545
+
546
+ class PhiFlashAttention2(PhiAttention):
547
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
548
+ def __init__(self, *args, **kwargs):
549
+ super().__init__(*args, **kwargs)
550
+
551
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
552
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
553
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
554
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
555
+
556
+ def forward(
557
+ self,
558
+ hidden_states: torch.Tensor,
559
+ attention_mask: Optional[torch.LongTensor] = None,
560
+ position_ids: Optional[torch.LongTensor] = None,
561
+ past_key_value: Optional[Cache] = None,
562
+ output_attentions: bool = False,
563
+ use_cache: bool = False,
564
+ **kwargs,
565
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
566
+ # PhiFlashAttention2 attention does not support output_attentions
567
+
568
+ output_attentions = False
569
+
570
+ bsz, q_len, _ = hidden_states.size()
571
+
572
+ query_states = self.q_proj(hidden_states)
573
+ key_states = self.k_proj(hidden_states)
574
+ value_states = self.v_proj(hidden_states)
575
+
576
+ if self.qk_layernorm:
577
+ query_states = self.q_layernorm(query_states)
578
+ key_states = self.k_layernorm(key_states)
579
+
580
+ # Flash attention requires the input to have the shape
581
+ # batch_size x seq_length x head_dim x hidden_dim
582
+ # therefore we just need to keep the original shape
583
+ query_states = query_states.view(
584
+ bsz, q_len, self.num_heads, self.head_dim
585
+ ).transpose(1, 2)
586
+ key_states = key_states.view(
587
+ bsz, q_len, self.num_key_value_heads, self.head_dim
588
+ ).transpose(1, 2)
589
+ value_states = value_states.view(
590
+ bsz, q_len, self.num_key_value_heads, self.head_dim
591
+ ).transpose(1, 2)
592
+
593
+ kv_seq_len = key_states.shape[-2]
594
+ if past_key_value is not None:
595
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
596
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
597
+
598
+ # Partial rotary embedding
599
+ query_rot, query_pass = (
600
+ query_states[..., : self.rotary_emb.dim],
601
+ query_states[..., self.rotary_emb.dim :],
602
+ )
603
+ key_rot, key_pass = (
604
+ key_states[..., : self.rotary_emb.dim],
605
+ key_states[..., self.rotary_emb.dim :],
606
+ )
607
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
608
+ query_rot, key_rot = apply_rotary_pos_emb(
609
+ query_rot, key_rot, cos, sin, position_ids
610
+ )
611
+
612
+ # [batch_size, seq_length, num_heads, head_dim]
613
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
614
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
615
+
616
+ if past_key_value is not None:
617
+ cache_kwargs = {
618
+ "sin": sin,
619
+ "cos": cos,
620
+ "partial_rotation_size": self.rotary_emb.dim,
621
+ }
622
+ key_states, value_states = past_key_value.update(
623
+ key_states, value_states, self.layer_idx, cache_kwargs
624
+ )
625
+
626
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
627
+ # to be able to avoid many of these transpose/reshape/view.
628
+ query_states = query_states.transpose(1, 2)
629
+ key_states = key_states.transpose(1, 2)
630
+ value_states = value_states.transpose(1, 2)
631
+
632
+ attn_dropout = self.attention_dropout if self.training else 0.0
633
+
634
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
635
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
636
+ # cast them back in the correct dtype just to be sure everything works as expected.
637
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
638
+ # in fp32.
639
+
640
+ if query_states.dtype == torch.float32:
641
+ if torch.is_autocast_enabled():
642
+ target_dtype = torch.get_autocast_gpu_dtype()
643
+ # Handle the case where the model is quantized
644
+ elif hasattr(self.config, "_pre_quantization_dtype"):
645
+ target_dtype = self.config._pre_quantization_dtype
646
+ else:
647
+ target_dtype = self.q_proj.weight.dtype
648
+
649
+ logger.warning_once(
650
+ f"The input hidden states seem to have been silently cast to float32; this might be related to"
651
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast back the input to"
652
+ f" {target_dtype}."
653
+ )
654
+
655
+ query_states = query_states.to(target_dtype)
656
+ key_states = key_states.to(target_dtype)
657
+ value_states = value_states.to(target_dtype)
658
+
659
+ attn_output = self._flash_attention_forward(
660
+ query_states,
661
+ key_states,
662
+ value_states,
663
+ attention_mask,
664
+ q_len,
665
+ dropout=attn_dropout,
666
+ softmax_scale=None,
667
+ )
668
+
669
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
670
+ attn_output = self.dense(attn_output)
671
+
672
+ if not output_attentions:
673
+ attn_weights = None
674
+
675
+ return attn_output, attn_weights, past_key_value
676
+
677
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
678
+ def _flash_attention_forward(
679
+ self,
680
+ query_states,
681
+ key_states,
682
+ value_states,
683
+ attention_mask,
684
+ query_length,
685
+ dropout=0.0,
686
+ softmax_scale=None,
687
+ ):
688
+ if not self._flash_attn_uses_top_left_mask:
689
+ causal = self.is_causal
690
+ else:
691
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
692
+ causal = self.is_causal and query_length != 1
693
+
694
+ # Contains at least one padding token in the sequence
695
+ if attention_mask is not None:
696
+ batch_size = query_states.shape[0]
697
+ (
698
+ query_states,
699
+ key_states,
700
+ value_states,
701
+ indices_q,
702
+ cu_seq_lens,
703
+ max_seq_lens,
704
+ ) = self._upad_input(
705
+ query_states, key_states, value_states, attention_mask, query_length
706
+ )
707
+
708
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
709
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
710
+
711
+ attn_output_unpad = flash_attn_varlen_func(
712
+ query_states,
713
+ key_states,
714
+ value_states,
715
+ cu_seqlens_q=cu_seqlens_q,
716
+ cu_seqlens_k=cu_seqlens_k,
717
+ max_seqlen_q=max_seqlen_in_batch_q,
718
+ max_seqlen_k=max_seqlen_in_batch_k,
719
+ dropout_p=dropout,
720
+ softmax_scale=softmax_scale,
721
+ causal=causal,
722
+ )
723
+
724
+ attn_output = pad_input(
725
+ attn_output_unpad, indices_q, batch_size, query_length
726
+ )
727
+ else:
728
+ attn_output = flash_attn_func(
729
+ query_states,
730
+ key_states,
731
+ value_states,
732
+ dropout,
733
+ softmax_scale=softmax_scale,
734
+ causal=causal,
735
+ )
736
+
737
+ return attn_output
738
+
739
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
740
+ def _upad_input(
741
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
742
+ ):
743
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
744
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
745
+
746
+ key_layer = index_first_axis(
747
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
748
+ indices_k,
749
+ )
750
+ value_layer = index_first_axis(
751
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
752
+ indices_k,
753
+ )
754
+ if query_length == kv_seq_len:
755
+ query_layer = index_first_axis(
756
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
757
+ indices_k,
758
+ )
759
+ cu_seqlens_q = cu_seqlens_k
760
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
761
+ indices_q = indices_k
762
+ elif query_length == 1:
763
+ max_seqlen_in_batch_q = 1
764
+ cu_seqlens_q = torch.arange(
765
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
766
+ ) # There is a memcpy here, that is very bad.
767
+ indices_q = cu_seqlens_q[:-1]
768
+ query_layer = query_layer.squeeze(1)
769
+ else:
770
+ # The -q_len: slice assumes left padding.
771
+ attention_mask = attention_mask[:, -query_length:]
772
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
773
+ query_layer, attention_mask
774
+ )
775
+
776
+ return (
777
+ query_layer,
778
+ key_layer,
779
+ value_layer,
780
+ indices_q,
781
+ (cu_seqlens_q, cu_seqlens_k),
782
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
783
+ )
784
+
785
+
786
+ PHI_ATTENTION_CLASSES = {
787
+ "flash_attention_2": PhiFlashAttention2,
788
+ "eager": PhiAttention,
789
+ }
790
+
791
+
792
+ class PhiDecoderLayer(nn.Module):
793
+ def __init__(self, config: PhiConfig, layer_idx: int):
794
+ super().__init__()
795
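+ # Prefer FlashAttention-2 whenever the package is installed; otherwise the configured (eager) attention implementation is used.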
+ if is_flash_attn_2_available():
796
+ config._attn_implementation = "flash_attention_2"
797
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](
798
+ config, layer_idx=layer_idx
799
+ )
800
+ self.mlp = PhiMLP(config)
801
+ self.input_layernorm = nn.LayerNorm(
802
+ config.hidden_size, eps=config.layer_norm_eps
803
+ )
804
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
805
+
806
+ def forward(
807
+ self,
808
+ hidden_states: torch.Tensor,
809
+ attention_mask: Optional[torch.Tensor] = None,
810
+ position_ids: Optional[torch.LongTensor] = None,
811
+ output_attentions: Optional[bool] = False,
812
+ use_cache: Optional[bool] = False,
813
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
814
+ ) -> Tuple[
815
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
816
+ ]:
817
+ residual = hidden_states
818
+
819
+ hidden_states = self.input_layernorm(hidden_states)
820
+
821
+ # Self Attention
822
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
823
+ hidden_states=hidden_states,
824
+ attention_mask=attention_mask,
825
+ position_ids=position_ids,
826
+ past_key_value=past_key_value,
827
+ output_attentions=output_attentions,
828
+ use_cache=use_cache,
829
+ )
830
+ attn_outputs = self.resid_dropout(attn_outputs)
831
+
832
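+ # Parallel residual block: the MLP consumes the same layer-normed input as attention, and both outputs are summed with the residual.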
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
833
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
834
+ outputs = (hidden_states,)
835
+
836
+ if output_attentions:
837
+ outputs += (self_attn_weights,)
838
+
839
+ if use_cache:
840
+ outputs += (present_key_value,)
841
+
842
+ return outputs
843
+
844
+
845
+ class PhiPreTrainedModel(PreTrainedModel):
846
+ config_class = PhiConfig
847
+ base_model_prefix = "model"
848
+ supports_gradient_checkpointing = True
849
+ _no_split_modules = ["PhiDecoderLayer"]
850
+ _skip_keys_device_placement = "past_key_values"
851
+ _supports_flash_attn_2 = True
852
+ _supports_cache_class = True
853
+
854
+
855
+ class PhiModel(PhiPreTrainedModel):
856
+ """
857
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
858
+ Args:
859
+ config: PhiConfig
860
+ """
861
+
862
+ def __init__(self, config: PhiConfig):
863
+ super().__init__(config)
864
+ self.padding_idx = config.pad_token_id
865
+ self.vocab_size = config.vocab_size
866
+
867
+ self.embed_tokens = nn.Embedding(
868
+ config.vocab_size, config.hidden_size, self.padding_idx
869
+ )
870
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
871
+ self.layers = nn.ModuleList(
872
+ [
873
+ PhiDecoderLayer(config, layer_idx)
874
+ for layer_idx in range(config.num_hidden_layers)
875
+ ]
876
+ )
877
+ self.final_layernorm = nn.LayerNorm(
878
+ config.hidden_size, eps=config.layer_norm_eps
879
+ )
880
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
881
+
882
+ self.gradient_checkpointing = False
883
+ # Initialize weights and apply final processing
884
+ self.post_init()
885
+
886
+ def get_input_embeddings(self):
887
+ return self.embed_tokens
888
+
889
+ def set_input_embeddings(self, value):
890
+ self.embed_tokens = value
891
+
892
+ def forward(
893
+ self,
894
+ input_ids: torch.LongTensor = None,
895
+ attention_mask: Optional[torch.Tensor] = None,
896
+ position_ids: Optional[torch.LongTensor] = None,
897
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
898
+ inputs_embeds: Optional[torch.FloatTensor] = None,
899
+ use_cache: Optional[bool] = None,
900
+ output_attentions: Optional[bool] = None,
901
+ output_hidden_states: Optional[bool] = None,
902
+ return_dict: Optional[bool] = None,
903
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
904
+ output_attentions = (
905
+ output_attentions
906
+ if output_attentions is not None
907
+ else self.config.output_attentions
908
+ )
909
+ output_hidden_states = (
910
+ output_hidden_states
911
+ if output_hidden_states is not None
912
+ else self.config.output_hidden_states
913
+ )
914
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
915
+
916
+ return_dict = (
917
+ return_dict if return_dict is not None else self.config.use_return_dict
918
+ )
919
+
920
+ # retrieve input_ids and inputs_embeds
921
+ if input_ids is not None and inputs_embeds is not None:
922
+ raise ValueError(
923
+ "You cannot specify both input_ids and inputs_embeds at the same time"
924
+ )
925
+ elif input_ids is not None:
926
+ batch_size, seq_length = input_ids.shape[:2]
927
+ elif inputs_embeds is not None:
928
+ batch_size, seq_length = inputs_embeds.shape[:2]
929
+ else:
930
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
931
+
932
+ past_key_values_length = 0
933
+
934
+ if self.gradient_checkpointing and self.training:
935
+ if use_cache:
936
+ logger.warning_once(
937
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
938
+ )
939
+ use_cache = False
940
+
941
+ if use_cache:
942
+ use_legacy_cache = not isinstance(past_key_values, Cache)
943
+ if use_legacy_cache:
944
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
945
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
946
+
947
+ if position_ids is None:
948
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
949
+ position_ids = torch.arange(
950
+ past_key_values_length,
951
+ seq_length + past_key_values_length,
952
+ dtype=torch.long,
953
+ device=device,
954
+ )
955
+ position_ids = position_ids.unsqueeze(0)
956
+
957
+ if inputs_embeds is None:
958
+ inputs_embeds = self.embed_tokens(input_ids)
959
+
960
+ inputs_embeds = self.embed_dropout(inputs_embeds)
961
+
962
+ # Attention mask.
963
+ if self._use_flash_attention_2:
964
+ # 2d mask is passed through the layers
965
+ attention_mask = (
966
+ attention_mask
967
+ if (attention_mask is not None and 0 in attention_mask)
968
+ else None
969
+ )
970
+ else:
971
+ # 4d mask is passed through the layers
972
+ attention_mask = _prepare_4d_causal_attention_mask(
973
+ attention_mask,
974
+ (batch_size, seq_length),
975
+ inputs_embeds,
976
+ past_key_values_length,
977
+ )
978
+
979
+ hidden_states = inputs_embeds
980
+
981
+ # decoder layers
982
+ all_hidden_states = () if output_hidden_states else None
983
+ all_self_attns = () if output_attentions else None
984
+ next_decoder_cache = None
985
+
986
+ for decoder_layer in self.layers:
987
+ if output_hidden_states:
988
+ all_hidden_states += (hidden_states,)
989
+
990
+ if self.gradient_checkpointing and self.training:
991
+ layer_outputs = self._gradient_checkpointing_func(
992
+ decoder_layer.__call__,
993
+ hidden_states,
994
+ attention_mask,
995
+ position_ids,
996
+ past_key_values,
997
+ output_attentions,
998
+ )
999
+ else:
1000
+ layer_outputs = decoder_layer(
1001
+ hidden_states,
1002
+ attention_mask=attention_mask,
1003
+ position_ids=position_ids,
1004
+ past_key_value=past_key_values,
1005
+ output_attentions=output_attentions,
1006
+ use_cache=use_cache,
1007
+ )
1008
+
1009
+ hidden_states = layer_outputs[0]
1010
+
1011
+ if use_cache:
1012
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1013
+
1014
+ if output_attentions:
1015
+ all_self_attns += (layer_outputs[1],)
1016
+
1017
+ hidden_states = self.final_layernorm(hidden_states)
1018
+
1019
+ # add hidden states from the last decoder layer
1020
+ if output_hidden_states:
1021
+ all_hidden_states += (hidden_states,)
1022
+
1023
+ next_cache = None
1024
+ if use_cache:
1025
+ next_cache = (
1026
+ next_decoder_cache.to_legacy_cache()
1027
+ if use_legacy_cache
1028
+ else next_decoder_cache
1029
+ )
1030
+ if not return_dict:
1031
+ return tuple(
1032
+ v
1033
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1034
+ if v is not None
1035
+ )
1036
+ return BaseModelOutputWithPast(
1037
+ last_hidden_state=hidden_states,
1038
+ past_key_values=next_cache,
1039
+ hidden_states=all_hidden_states,
1040
+ attentions=all_self_attns,
1041
+ )
1042
+
1043
+
1044
+ class PhiForCausalLM(PhiPreTrainedModel):
1045
+ _tied_weights_keys = ["lm_head.weight"]
1046
+
1047
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1048
+ def __init__(self, config):
1049
+ super().__init__(config)
1050
+ self.model = PhiModel(config)
1051
+ self.vocab_size = config.vocab_size
1052
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
1053
+
1054
+ # Initialize weights and apply final processing
1055
+ self.post_init()
1056
+
1057
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1058
+ def get_input_embeddings(self):
1059
+ return self.model.embed_tokens
1060
+
1061
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1062
+ def set_input_embeddings(self, value):
1063
+ self.model.embed_tokens = value
1064
+
1065
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1066
+ def get_output_embeddings(self):
1067
+ return self.lm_head
1068
+
1069
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1070
+ def set_output_embeddings(self, new_embeddings):
1071
+ self.lm_head = new_embeddings
1072
+
1073
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1074
+ def set_decoder(self, decoder):
1075
+ self.model = decoder
1076
+
1077
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1078
+ def get_decoder(self):
1079
+ return self.model
1080
+
1081
+ def forward(
1082
+ self,
1083
+ input_ids: torch.LongTensor = None,
1084
+ attention_mask: Optional[torch.Tensor] = None,
1085
+ position_ids: Optional[torch.LongTensor] = None,
1086
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1087
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1088
+ labels: Optional[torch.LongTensor] = None,
1089
+ use_cache: Optional[bool] = None,
1090
+ output_attentions: Optional[bool] = None,
1091
+ output_hidden_states: Optional[bool] = None,
1092
+ return_dict: Optional[bool] = None,
1093
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1094
+ output_attentions = (
1095
+ output_attentions
1096
+ if output_attentions is not None
1097
+ else self.config.output_attentions
1098
+ )
1099
+ output_hidden_states = (
1100
+ output_hidden_states
1101
+ if output_hidden_states is not None
1102
+ else self.config.output_hidden_states
1103
+ )
1104
+ return_dict = (
1105
+ return_dict if return_dict is not None else self.config.use_return_dict
1106
+ )
1107
+
1108
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1109
+ outputs = self.model(
1110
+ input_ids=input_ids,
1111
+ attention_mask=attention_mask,
1112
+ position_ids=position_ids,
1113
+ past_key_values=past_key_values,
1114
+ inputs_embeds=inputs_embeds,
1115
+ use_cache=use_cache,
1116
+ output_attentions=output_attentions,
1117
+ output_hidden_states=output_hidden_states,
1118
+ return_dict=return_dict,
1119
+ )
1120
+
1121
+ hidden_states = outputs[0]
1122
+ logits = self.lm_head(hidden_states)
1123
+ logits = logits.float()
1124
+
1125
+ loss = None
1126
+ if labels is not None:
1127
+ # Shift so that tokens < n predict n
1128
+ shift_logits = logits[..., :-1, :].contiguous()
1129
+ shift_labels = labels[..., 1:].contiguous()
1130
+ # Flatten the tokens
1131
+ loss_fct = CrossEntropyLoss()
1132
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1133
+ shift_labels = shift_labels.view(-1)
1134
+ # Enable model parallelism
1135
+ shift_labels = shift_labels.to(shift_logits.device)
1136
+ loss = loss_fct(shift_logits, shift_labels)
1137
+
1138
+ if not return_dict:
1139
+ output = (logits,) + outputs[1:]
1140
+ return (loss,) + output if loss is not None else output
1141
+
1142
+ return CausalLMOutputWithPast(
1143
+ loss=loss,
1144
+ logits=logits,
1145
+ past_key_values=outputs.past_key_values,
1146
+ hidden_states=outputs.hidden_states,
1147
+ attentions=outputs.attentions,
1148
+ )
1149
+
1150
+ @staticmethod
1151
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1152
+ def _reorder_cache(past_key_values, beam_idx):
1153
+ reordered_past = ()
1154
+ for layer_past in past_key_values:
1155
+ reordered_past += (
1156
+ tuple(
1157
+ past_state.index_select(0, beam_idx.to(past_state.device))
1158
+ for past_state in layer_past
1159
+ ),
1160
+ )
1161
+ return reordered_past
1162
+
1163
+
1164
+ class PhiForSequenceClassification(PhiPreTrainedModel):
1165
+ def __init__(self, config):
1166
+ super().__init__(config)
1167
+ self.num_labels = config.num_labels
1168
+ self.model = PhiModel(config)
1169
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1170
+
1171
+ # Initialize weights and apply final processing
1172
+ self.post_init()
1173
+
1174
+ def get_input_embeddings(self):
1175
+ return self.model.embed_tokens
1176
+
1177
+ def set_input_embeddings(self, value):
1178
+ self.model.embed_tokens = value
1179
+
1180
+ def forward(
1181
+ self,
1182
+ input_ids: torch.LongTensor = None,
1183
+ attention_mask: Optional[torch.Tensor] = None,
1184
+ position_ids: Optional[torch.LongTensor] = None,
1185
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1186
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1187
+ labels: Optional[torch.LongTensor] = None,
1188
+ use_cache: Optional[bool] = None,
1189
+ output_attentions: Optional[bool] = None,
1190
+ output_hidden_states: Optional[bool] = None,
1191
+ return_dict: Optional[bool] = None,
1192
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1193
+ r"""
1194
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1195
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1196
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1197
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1198
+ """
1199
+ return_dict = (
1200
+ return_dict if return_dict is not None else self.config.use_return_dict
1201
+ )
1202
+
1203
+ model_outputs = self.model(
1204
+ input_ids,
1205
+ attention_mask=attention_mask,
1206
+ position_ids=position_ids,
1207
+ past_key_values=past_key_values,
1208
+ inputs_embeds=inputs_embeds,
1209
+ use_cache=use_cache,
1210
+ output_attentions=output_attentions,
1211
+ output_hidden_states=output_hidden_states,
1212
+ return_dict=return_dict,
1213
+ )
1214
+ hidden_states = model_outputs[0]
1215
+ logits = self.score(hidden_states)
1216
+
1217
+ if input_ids is not None:
1218
+ batch_size = input_ids.shape[0]
1219
+ else:
1220
+ batch_size = inputs_embeds.shape[0]
1221
+
1222
+ if self.config.pad_token_id is None and batch_size != 1:
1223
+ raise ValueError(
1224
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1225
+ )
1226
+ if self.config.pad_token_id is None:
1227
+ sequence_lengths = -1
1228
+ else:
1229
+ if input_ids is not None:
1230
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1231
+ sequence_lengths = (
1232
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1233
+ )
1234
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1235
+ sequence_lengths = sequence_lengths.to(logits.device)
1236
+ else:
1237
+ sequence_lengths = -1
1238
+
1239
+ pooled_logits = logits[
1240
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1241
+ ]
1242
+
1243
+ loss = None
1244
+ if labels is not None:
1245
+ labels = labels.to(logits.device)
1246
+ if self.config.problem_type is None:
1247
+ if self.num_labels == 1:
1248
+ self.config.problem_type = "regression"
1249
+ elif self.num_labels > 1 and (
1250
+ labels.dtype == torch.long or labels.dtype == torch.int
1251
+ ):
1252
+ self.config.problem_type = "single_label_classification"
1253
+ else:
1254
+ self.config.problem_type = "multi_label_classification"
1255
+
1256
+ if self.config.problem_type == "regression":
1257
+ loss_fct = MSELoss()
1258
+ if self.num_labels == 1:
1259
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1260
+ else:
1261
+ loss = loss_fct(pooled_logits, labels)
1262
+ elif self.config.problem_type == "single_label_classification":
1263
+ loss_fct = CrossEntropyLoss()
1264
+ loss = loss_fct(
1265
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1266
+ )
1267
+ elif self.config.problem_type == "multi_label_classification":
1268
+ loss_fct = BCEWithLogitsLoss()
1269
+ loss = loss_fct(pooled_logits, labels)
1270
+ if not return_dict:
1271
+ output = (pooled_logits,) + model_outputs[1:]
1272
+ return ((loss,) + output) if loss is not None else output
1273
+
1274
+ return SequenceClassifierOutputWithPast(
1275
+ loss=loss,
1276
+ logits=pooled_logits,
1277
+ past_key_values=model_outputs.past_key_values,
1278
+ hidden_states=model_outputs.hidden_states,
1279
+ attentions=model_outputs.attentions,
1280
+ )
1281
+
1282
+
1283
+ class PhiForTokenClassification(PhiPreTrainedModel):
1284
+ def __init__(self, config: PhiConfig):
1285
+ super().__init__(config)
1286
+ self.num_labels = config.num_labels
1287
+
1288
+ self.model = PhiModel(config)
1289
+ if (
1290
+ hasattr(config, "classifier_dropout")
1291
+ and config.classifier_dropout is not None
1292
+ ):
1293
+ classifier_dropout = config.classifier_dropout
1294
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1295
+ classifier_dropout = config.hidden_dropout
1296
+ else:
1297
+ classifier_dropout = 0.1
1298
+ self.dropout = nn.Dropout(classifier_dropout)
1299
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1300
+
1301
+ # Initialize weights and apply final processing
1302
+ self.post_init()
1303
+
1304
+ def forward(
1305
+ self,
1306
+ input_ids: Optional[torch.LongTensor] = None,
1307
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1308
+ attention_mask: Optional[torch.Tensor] = None,
1309
+ inputs_embeds: Optional[torch.Tensor] = None,
1310
+ labels: Optional[torch.Tensor] = None,
1311
+ use_cache: Optional[bool] = None,
1312
+ output_attentions: Optional[bool] = None,
1313
+ output_hidden_states: Optional[bool] = None,
1314
+ return_dict: Optional[bool] = None,
1315
+ **deprecated_arguments,
1316
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1317
+ r"""
1318
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1319
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1320
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1321
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1322
+ """
1323
+ return_dict = (
1324
+ return_dict if return_dict is not None else self.config.use_return_dict
1325
+ )
1326
+
1327
+ model_outputs = self.model(
1328
+ input_ids,
1329
+ past_key_values=past_key_values,
1330
+ attention_mask=attention_mask,
1331
+ inputs_embeds=inputs_embeds,
1332
+ use_cache=use_cache,
1333
+ output_attentions=output_attentions,
1334
+ output_hidden_states=output_hidden_states,
1335
+ return_dict=return_dict,
1336
+ )
1337
+
1338
+ hidden_states = model_outputs[0]
1339
+ hidden_states = self.dropout(hidden_states)
1340
+ logits = self.classifier(hidden_states)
1341
+
1342
+ loss = None
1343
+ if labels is not None:
1344
+ # move labels to correct device to enable model parallelism
1345
+ labels = labels.to(logits.device)
1346
+ batch_size, seq_length = labels.shape
1347
+ loss_fct = CrossEntropyLoss()
1348
+ loss = loss_fct(
1349
+ logits.view(batch_size * seq_length, self.num_labels),
1350
+ labels.view(batch_size * seq_length),
1351
+ )
1352
+
1353
+ if not return_dict:
1354
+ output = (logits,) + model_outputs[2:]
1355
+ return ((loss,) + output) if loss is not None else output
1356
+
1357
+ return TokenClassifierOutput(
1358
+ loss=loss,
1359
+ logits=logits,
1360
+ hidden_states=model_outputs.hidden_states,
1361
+ attentions=model_outputs.attentions,
1362
+ )
1363
 
1364
 
1365
  @dataclass
 
1372
  image_features: Optional[torch.FloatTensor] = None
1373
 
1374
 
1375
+ class SiglipVisionEncoder(nn.Module):
1376
+ def __init__(self, config: LlavaConfig):
1377
+ super().__init__()
1378
+ self.vision_tower = SiglipVisionModel(config.vision_config)
1379
+
1380
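+ # Small MLP that embeds each crop's 2-dim input (presumably normalized crop coordinates, given the name) into the vision embedding space so it can be added to that crop's visual tokens.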
+ self.coord_embed = nn.Sequential(
1381
+ nn.Linear(2, config.vision_embed_dim),
1382
+ nn.GELU(),
1383
+ nn.Linear(config.vision_embed_dim, config.vision_embed_dim),
1384
+ )
1385
+
1386
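+ # Default total number of visual tokens budgeted per image across all of its crops (see feature_select below).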
+ self.num_tokens = 728
1387
+
1388
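+ # Drops the first token of each crop, splits the num_tokens budget across the crops, keeps the trailing tokens of each crop, adds that crop's coordinate embedding, and concatenates the result.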
+ def feature_select(self, image_forward_outs, coord_feature, num_tokens=None):
1389
+ image_features = image_forward_outs
1390
+ image_features = image_features[:, 1:]
1391
+ if num_tokens is None:
1392
+ num_tokens = self.num_tokens
1393
+ split_size = int(num_tokens / image_features.shape[0])
1394
+ total = 0
1395
+ output_list = []
1396
+ for i in range(image_features.shape[0]):
1397
+ if i == image_features.shape[0] - 1:
1398
+ size = num_tokens - total
1399
+ else:
1400
+ size = split_size
1401
+ total += size
1402
+ chunk_output = image_features[i, -size:, :]
1403
+ chunk_output = chunk_output + coord_feature[i]
1404
+ output_list.append(chunk_output)
1405
+ image_features = torch.cat(output_list)
1406
+ return image_features
1407
+
1408
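+ # Runs the vision tower over the crops in batches of at most 50 to limit peak memory, takes the last hidden state, and combines it with the coordinate embeddings via feature_select.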
+ def process_image_chunks(self, image_tensor, coord_tensor, num_tokens=None):
1409
+ if image_tensor.shape[0] > 50:
1410
+ image_forward_out = []
1411
+ for i in range(0, image_tensor.shape[0], 50):
1412
+ part_forward_out = self.vision_tower(
1413
+ image_tensor[i : i + 50], output_hidden_states=True
1414
+ ).hidden_states[-1]
1415
+ image_forward_out.append(part_forward_out)
1416
+ image_forward_out = torch.cat(image_forward_out, dim=0)
1417
+ else:
1418
+ image_forward_out = self.vision_tower(
1419
+ image_tensor, output_hidden_states=True
1420
+ ).hidden_states[-1]
1421
+ coord_feature = self.coord_embed(coord_tensor)
1422
+ if len(coord_feature.shape) == 1:
1423
+ coord_feature = coord_feature.unsqueeze(0)
1424
+ image_feature = self.feature_select(
1425
+ image_forward_out, coord_feature, num_tokens
1426
+ ).to(image_tensor.dtype)
1427
+ return image_feature
1428
+
1429
+ def forward(
1430
+ self, images: List[torch.Tensor], coords: List[torch.Tensor], num_tokens=None
1431
+ ):
1432
+ image_features = []
1433
+ for i, image in enumerate(images):
1434
+ image_feature = self.process_image_chunks(image, coords[i], num_tokens)
1435
+ image_features.append(image_feature)
1436
+ image_features = torch.stack(image_features)
1437
+ return image_features
1438
+
1439
+
1440
  class LlavaMultiModalProjector(nn.Module):
1441
  def __init__(self, config: LlavaConfig):
1442
  super().__init__()
1443
 
1444
  self.linear_1 = nn.Linear(
1445
  config.vision_embed_dim,
1446
+ config.text_config.hidden_size,
1447
  bias=True,
1448
  )
1449
  self.act = nn.GELU()
1450
  self.linear_2 = nn.Linear(
1451
+ config.text_config.hidden_size,
1452
+ config.text_config.hidden_size,
1453
  bias=True,
1454
  )
 
1455
 
1456
  def forward(self, image_features):
1457
  hidden_states = self.linear_1(image_features)
 
1483
  return self.language_model._supports_sdpa
1484
 
1485
 
1486
+ class LlavaForCausalLM(LlavaPreTrainedModel):
1487
  def __init__(self, config: LlavaConfig):
1488
  super().__init__(config)
1489
+ self.vision_model = SiglipVisionEncoder(config)
 
1490
 
1491
  self.multi_modal_projector = LlavaMultiModalProjector(config)
1492
  self.vocab_size = config.vocab_size
 
1657
  attention_mask,
1658
  position_ids,
1659
  )
1660
+ else:
1661
+ # In case input_ids.shape[1] == 1 & image_features != None & past_key_values != None, we are in the case of
1662
+ # generation with cache
1663
+ if past_key_values is not None and image_features is not None and input_ids.shape[1] == 1:
1664
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
1665
+ # that are set to 0
1666
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
1667
+
1668
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
1669
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
1670
+
1671
+ # Get the target length
1672
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
1673
+
1674
+ extended_attention_mask = torch.ones(
1675
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
1676
+ dtype=attention_mask.dtype,
1677
+ device=attention_mask.device,
1678
+ )
1679
+
1680
+ # Filter out only the tokens that can be un-attended, this can happen
1681
+ # if one uses Llava + Fused modules where the cache on the
1682
+ # first iteration is already big enough, or if one passes custom cache
1683
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
1684
+ new_batch_index = batch_index[valid_indices]
1685
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
1686
+
1687
+ # Zero-out the places where we don't need to attend
1688
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
1689
+
1690
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
1691
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1692
 
1693
  outputs = self.language_model(
1694
  input_ids=None,
 
1704
 
1705
  logits = outputs[0]
1706
 
 
1707
  if not return_dict:
1708
  output = (logits,) + outputs[1:]
1709
  return output
 
1725
  image_features=None,
1726
  **kwargs,
1727
  ):
1728
+ if past_key_values is not None:
1729
+ if isinstance(past_key_values, Cache):
1730
+ cache_length = past_key_values.get_seq_length()
1731
+ past_length = past_key_values.seen_tokens
1732
+ max_cache_length = past_key_values.get_max_length()
1733
+ else:
1734
+ cache_length = past_length = past_key_values[0][0].shape[2]
1735
+ max_cache_length = None
1736
+
1737
+ # Keep only the unprocessed tokens:
1738
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1739
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1740
+ # input)
1741
+ if (
1742
+ attention_mask is not None
1743
+ and attention_mask.shape[1] > input_ids.shape[1]
1744
+ ):
1745
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1746
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1747
+ # input_ids based on the past_length.
1748
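+ # The cache length includes the expanded image-feature positions, while input_ids still holds only the single image placeholder token, hence the image_features.shape[1] - 1 correction below.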
+ elif past_length < input_ids.shape[1] + image_features.shape[1] - 1:
1749
+ past_length -= image_features.shape[1] - 1
1750
+ input_ids = input_ids[:, past_length:]
1751
+ attention_mask = attention_mask[:, past_length:]
1752
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1753
+
1754
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1755
+ if (
1756
+ max_cache_length is not None
1757
+ and attention_mask is not None
1758
+ and cache_length + input_ids.shape[1] > max_cache_length
1759
+ ):
1760
+ attention_mask = attention_mask[:, -max_cache_length:]
1761
+
1762
+ position_ids = kwargs.get("position_ids", None)
1763
+ if attention_mask is not None and position_ids is None:
1764
+ # create position_ids on the fly for batch generation
1765
+ position_ids = attention_mask.long().cumsum(-1) - 1
1766
+ position_ids.masked_fill_(attention_mask == 0, 1)
1767
+ if past_key_values:
1768
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1769
 
1770
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1771
  if inputs_embeds is not None and past_key_values is None:
1772
  model_inputs = {"inputs_embeds": inputs_embeds}
1773
  else:
 
1775
 
1776
  model_inputs.update(
1777
  {
1778
+ "position_ids": position_ids,
1779
  "past_key_values": past_key_values,
1780
  "use_cache": kwargs.get("use_cache"),
1781
  "attention_mask": attention_mask,