Update modeling_bunny_llama.py
Browse files
- modeling_bunny_llama.py +6 -0
modeling_bunny_llama.py
CHANGED
@@ -702,11 +702,17 @@ class BunnyMetaForCausalLM(ABC):
|
|
702 |
if labels is None:
|
703 |
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
704 |
|
|
|
|
|
705 |
# remove the padding using attention_mask -- TODO: double check
|
706 |
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
707 |
zip(input_ids, attention_mask)]
|
708 |
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
709 |
|
|
|
|
|
|
|
|
|
710 |
new_input_embeds = []
|
711 |
new_labels = []
|
712 |
cur_image_idx = 0
|
|
|
702 |
if labels is None:
|
703 |
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
704 |
|
705 |
+
input_ids_temp = input_ids # points to the actual input_ids tensor
|
706 |
+
|
707 |
# remove the padding using attention_mask -- TODO: double check
|
708 |
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
709 |
zip(input_ids, attention_mask)]
|
710 |
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
711 |
|
712 |
+
# TODO: is there a better implementation than overwriting the image placeholders in input_ids in place?
|
713 |
+
# replace IMAGE_TOKEN_INDEX (-200) with 0 in the original input_ids tensor (in place, through the input_ids_temp alias) to be compatible with repetition penalty
|
714 |
+
input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
|
715 |
+
|
716 |
new_input_embeds = []
|
717 |
new_labels = []
|
718 |
cur_image_idx = 0
|