visheratin committed
Commit: ed8f61a
Parent: 304a0a4

Update model files

Files changed (1): modeling_llava.py (+13, -118)
modeling_llava.py CHANGED
@@ -9,8 +9,7 @@ from torch import nn
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import ModelOutput
 
-from modeling_phi import PhiForCausalLM, InferenceParams
-from processing_llava import OpenCLIPImageProcessor
+from modeling_phi import PhiForCausalLM
 from configuration_llava import LlavaConfig
 from open_clip import create_model
 
@@ -22,7 +21,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
     past_key_values: Optional[List[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
-    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    image_features: Optional[torch.FloatTensor] = None
 
 
 class LlavaMultiModalProjector(nn.Module):
@@ -214,14 +213,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
     def forward(
         self,
         input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
+        image_features: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -242,14 +238,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         )
 
         if inputs_embeds is None:
-            # 1. Extra the input embeddings
             inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            # 2. Merge text and images
-            if pixel_values is not None and input_ids.shape[1] != 1:
-                image_outputs = self.vision_model(pixel_values)
-
-                image_features = self.multi_modal_projector(image_outputs)
+            if image_features is not None and input_ids.shape[1] != 1:
                 (
                     inputs_embeds,
                     attention_mask,
@@ -261,46 +251,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
                     attention_mask,
                     position_ids,
                 )
-                # if labels is None:
-                #     labels = torch.full_like(
-                #         attention_mask, self.config.ignore_index
-                #     ).to(torch.long)
-        else:
-            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
-            # generation with cache
-            if (
-                past_key_values is not None
-                and pixel_values is not None
-                and input_ids.shape[1] == 1
-            ):
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(
-                    first_layer_past_key_value.float().sum(-2) == 0
-                )
-
-                # Get the target length
-                target_seqlen = first_layer_past_key_value.shape[-1] + 1
-
-                extended_attention_mask = torch.ones(
-                    (
-                        attention_mask.shape[0],
-                        target_seqlen - attention_mask.shape[1],
-                    ),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-
-                # Zero-out the places where we don't need to attend
-                extended_attention_mask[batch_index, non_attended_tokens] = 0
-
-                attention_mask = torch.cat(
-                    (attention_mask, extended_attention_mask), dim=1
-                )
-                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
 
         outputs = self.language_model(
             input_ids=None,
@@ -316,37 +266,17 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
 
         logits = outputs[0]
 
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
-                shift_logits = logits[..., :-1, :][
-                    shift_attention_mask.to(logits.device) != 0
-                ].contiguous()
-                shift_labels = labels[..., 1:][
-                    shift_attention_mask.to(labels.device) != 0
-                ].contiguous()
-            else:
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(
-                shift_logits.view(-1, shift_logits.size(-1)),
-                shift_labels.view(-1).to(shift_logits.device),
-            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+            return output
 
         return LlavaCausalLMOutputWithPast(
-            loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            image_features=image_features,
         )
 
     def prepare_inputs_for_generation(
@@ -354,49 +284,15 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         input_ids,
         past_key_values=None,
         inputs_embeds=None,
-        pixel_values=None,
         attention_mask=None,
+        image_features=None,
         **kwargs,
     ):
-        if past_key_values is not None:
-            if isinstance(past_key_values, InferenceParams):
-                cache_length = past_key_values.max_seqlen
-                past_length = past_key_values.seqlen_offset
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
-            #     input)
-            if (
-                attention_mask is not None
-                and attention_mask.shape[1] > input_ids.shape[1]
-            ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif self.config.image_token_index in input_ids:
-                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[
-                    :, -(cache_length + input_ids.shape[1]) :
-                ]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        res = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, attention_mask, **kwargs)
+        input_ids = res["input_ids"]
+        past_key_values = res["past_key_values"]
+        attention_mask = res["attention_mask"]
+
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
@@ -404,11 +300,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
 
         model_inputs.update(
            {
-                "position_ids": position_ids,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
+                "image_features": image_features,
            }
        )
        return model_inputs
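
Summary of the change, with a usage sketch: forward() now takes a precomputed image_features tensor instead of pixel_values, the labels/loss path is removed, and prepare_inputs_for_generation delegates cache handling to the wrapped PhiForCausalLM. The block below is a minimal sketch of the resulting calling convention, not code from this commit: the checkpoint id and the preprocessing step are placeholders, and the vision_model / multi_modal_projector attribute names come from the code that forward() previously called.

# Sketch only (see note above): checkpoint id and preprocessing are placeholders.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)

input_ids = ...       # tokenized prompt containing the image placeholder token(s)
attention_mask = ...  # matching attention mask
pixel_values = ...    # image tensor preprocessed as in processing_llava

with torch.no_grad():
    # Image encoding now happens outside forward(): run the vision tower and the
    # multimodal projector, then pass the projected features to the model.
    image_features = model.multi_modal_projector(model.vision_model(pixel_values))
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        image_features=image_features,  # new argument introduced by this commit
    )

print(outputs.logits.shape)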
 