BoyaWu10 committed
Commit 8d9c671, 1 parent: 45417f7

Update modeling_bunny_llama.py

Files changed (1):
  1. modeling_bunny_llama.py (+6 -0)
modeling_bunny_llama.py CHANGED
@@ -702,11 +702,17 @@ class BunnyMetaForCausalLM(ABC):
         if labels is None:
             labels = torch.full_like(input_ids, IGNORE_INDEX)
 
+        input_ids_temp = input_ids  # points to the actual input_ids tensor
+
         # remove the padding using attention_mask -- TODO: double check
         input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
                      zip(input_ids, attention_mask)]
         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
 
+        # -- TODO: better implementation?
+        # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
+        input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
+
         new_input_embeds = []
         new_labels = []
         cur_image_idx = 0
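
For context, a minimal standalone sketch of why the added lines help (not part of the commit): transformers' RepetitionPenaltyLogitsProcessor gathers scores at the token ids in input_ids, so the negative IMAGE_TOKEN_INDEX placeholder (-200) is not a valid vocabulary index for that lookup; mapping it to 0 first, as the diff does through the input_ids_temp alias, keeps the penalty computation valid. Because input_ids_temp aliases the same tensor, the later rebinding of the local name input_ids to a list of unpadded tensors does not undo the replacement on the caller's tensor. The constant value, vocabulary size, and penalty setting below are illustrative assumptions, not values taken from the repository.

import torch
from transformers import RepetitionPenaltyLogitsProcessor

IMAGE_TOKEN_INDEX = -200   # assumed to match the constant used in modeling_bunny_llama.py
vocab_size = 32000         # illustrative vocabulary size

# A batch with one image placeholder token among regular token ids.
input_ids = torch.tensor([[1, 306, IMAGE_TOKEN_INDEX, 3148]])
scores = torch.randn(1, vocab_size)

# Mirror the committed fix: input_ids_temp aliases the same tensor, so the
# in-place replacement also rewrites input_ids before the penalty is applied.
input_ids_temp = input_ids
input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0

# With -200 still present, the gather inside the processor would hit an
# out-of-range index; with 0 it runs cleanly.
processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
penalized_scores = processor(input_ids, scores)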