update modeling file
modeling_bamboo.py  CHANGED  +49 -46
@@ -1,5 +1,6 @@
 # coding=utf-8
 # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 SJTU-IPADS AI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -72,11 +73,11 @@ def _get_unpad_data(attention_mask):
     )


-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
-class MistralRMSNorm(nn.Module):
+# Copied from transformers.models.mistral.modeling_mistral.MistralRMSNorm with Mistral->Bamboo
+class BambooRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
-        MistralRMSNorm is equivalent to T5LayerNorm
+        BambooRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -91,8 +92,9 @@ class MistralRMSNorm(nn.Module):


 # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
+# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Bamboo
 # TODO @Arthur no longer copied from LLama after static cache
-class MistralRotaryEmbedding(nn.Module):
+class BambooRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()

@@ -166,7 +168,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     return q_embed, k_embed


-class MistralMLP(nn.Module):
+class BambooMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -194,7 +196,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class MistralAttention(nn.Module):
+# Copied from transformers.models.mistral.modeling_mistral.MistralAttention
+class BambooAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
     and "Generating Long Sequences with Sparse Transformers".
@@ -231,7 +234,7 @@ class MistralAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

-        self.rotary_emb = MistralRotaryEmbedding(
+        self.rotary_emb = BambooRotaryEmbedding(
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
@@ -322,9 +325,9 @@ class MistralAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-class MistralFlashAttention2(MistralAttention):
+class BambooFlashAttention2(BambooAttention):
     """
-    Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
+    BAMBOO flash attention module. This module inherits from `BambooAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
     flash attention and deal with padding tokens in case the input contains any of them.
     """
@@ -618,14 +621,14 @@ class MistralFlashAttention2(MistralAttention):

 # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
 # TODO @Arthur no longer copied from LLama after static cache
-class MistralSdpaAttention(MistralAttention):
+class BambooSdpaAttention(BambooAttention):
     """
-    Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    Bamboo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `BambooAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """

-    # Adapted from MistralAttention.forward
+    # Adapted from BambooAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -638,7 +641,7 @@ class MistralSdpaAttention(MistralAttention):
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
-                "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "BambooModel is using BambooSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
             return super().forward(
@@ -705,23 +708,23 @@ class MistralSdpaAttention(MistralAttention):
         return attn_output, None, past_key_value


-MISTRAL_ATTENTION_CLASSES = {
-    "eager": MistralAttention,
-    "flash_attention_2": MistralFlashAttention2,
-    "sdpa": MistralSdpaAttention,
+BAMBOO_ATTENTION_CLASSES = {
+    "eager": BambooAttention,
+    "flash_attention_2": BambooFlashAttention2,
+    "sdpa": BambooSdpaAttention,
 }


-class MistralDecoderLayer(nn.Module):
+class BambooDecoderLayer(nn.Module):
     def __init__(self, config: BambooConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.self_attn = BAMBOO_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

-        self.mlp = MistralMLP(config)
-        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = BambooMLP(config)
+        self.input_layernorm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
@@ -783,7 +786,7 @@ class MistralDecoderLayer(nn.Module):
         return outputs


-MISTRAL_START_DOCSTRING = r"""
+BAMBOO_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
@@ -801,14 +804,14 @@ MISTRAL_START_DOCSTRING = r"""


 @add_start_docstrings(
-    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
-    MISTRAL_START_DOCSTRING,
+    "The bare Bamboo Model outputting raw hidden-states without any specific head on top.",
+    BAMBOO_START_DOCSTRING,
 )
-class MistralPreTrainedModel(PreTrainedModel):
+class BambooPreTrainedModel(PreTrainedModel):
     config_class = BambooConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["MistralDecoderLayer"]
+    _no_split_modules = ["BambooDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
@@ -826,7 +829,7 @@ class MistralPreTrainedModel(PreTrainedModel):
                 module.weight.data[module.padding_idx].zero_()


-MISTRAL_INPUTS_DOCSTRING = r"""
+BAMBOO_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -897,12 +900,12 @@ MISTRAL_INPUTS_DOCSTRING = r"""


 @add_start_docstrings(
-    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
-    MISTRAL_START_DOCSTRING,
+    "The bare Bamboo Model outputting raw hidden-states without any specific head on top.",
+    BAMBOO_START_DOCSTRING,
 )
-class MistralModel(MistralPreTrainedModel):
+class BambooModel(BambooPreTrainedModel):
     """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambooDecoderLayer`]

     Args:
         config: BambooConfig
@@ -915,10 +918,10 @@ class MistralModel(MistralPreTrainedModel):

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
-            [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+            [BambooDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self._attn_implementation = config._attn_implementation
-        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = BambooRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -930,7 +933,7 @@ class MistralModel(MistralPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -993,7 +996,7 @@ class MistralModel(MistralPreTrainedModel):
             if is_padding_right:
                 raise ValueError(
                     "You are attempting to perform batched generation with padding_side='right'"
-                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " this may lead to unexpected behaviour for Flash Attention version of Bamboo. Make sure to "
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )

@@ -1078,12 +1081,12 @@ class MistralModel(MistralPreTrainedModel):
         )


-class BambooForCausalLM(MistralPreTrainedModel):
+class BambooForCausalLM(BambooPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
-        self.model = MistralModel(config)
+        self.model = BambooModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

@@ -1108,7 +1111,7 @@ class BambooForCausalLM(MistralPreTrainedModel):
     def get_decoder(self):
         return self.model

-    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1266,9 +1269,9 @@ class BambooForCausalLM(MistralPreTrainedModel):

 @add_start_docstrings(
     """
-    The Mistral Model transformer with a sequence classification head on top (linear layer).
+    The Bamboo Model transformer with a sequence classification head on top (linear layer).

-    [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    [`BambooForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.

     Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -1277,14 +1280,14 @@ class BambooForCausalLM(MistralPreTrainedModel):
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    MISTRAL_START_DOCSTRING,
+    BAMBOO_START_DOCSTRING,
 )
 # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
-class MistralForSequenceClassification(MistralPreTrainedModel):
+class BambooForSequenceClassification(BambooPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = MistralModel(config)
+        self.model = BambooModel(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

         # Initialize weights and apply final processing
@@ -1296,7 +1299,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(BAMBOO_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
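For context, a minimal usage sketch (not part of this commit; the repository id below is a placeholder): because modeling_bamboo.py ships with the checkpoint as remote code, a repo that bundles it is typically loaded through the standard transformers Auto classes with trust_remote_code=True, and the attn_implementation argument selects one of the attention classes registered in BAMBOO_ATTENTION_CLASSES ("eager", "flash_attention_2", or "sdpa").

    # Hypothetical usage sketch; the repo id is a placeholder, not taken from this commit.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "some-org/bamboo-checkpoint"  # placeholder

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,       # loads modeling_bamboo.py from the repository
        attn_implementation="sdpa",   # resolved to BambooSdpaAttention via BAMBOO_ATTENTION_CLASSES
    )

    inputs = tokenizer("Hello, world", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))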