Update modeling_Llamoe.py
modeling_Llamoe.py  +4 -4
@@ -467,7 +467,7 @@ class LlamoeAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-class LlamoeFlashAttention2(
+class LlamoeFlashAttention2(LlamoeAttention):
     """
     Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
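The docstring above is why the one-line fix matters: `LlamoeFlashAttention2` defines no parameters of its own and only overrides the forward pass, so it has to actually inherit from `LlamoeAttention` for checkpoints to load into either implementation. A minimal sketch of that pattern, using illustrative toy names rather than the real module:

```python
# Toy illustration (not code from modeling_Llamoe.py): the flash-attention subclass
# adds no parameters, so its state_dict keys match the base attention class exactly.
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

class ToyFlashAttention(ToyAttention):
    # In the real module only forward() is overridden to call the flash-attention API;
    # the projections above are inherited untouched.
    pass

assert ToyAttention().state_dict().keys() == ToyFlashAttention().state_dict().keys()
```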
@@ -662,7 +662,7 @@ class LlamoeFlashAttention2(LlamaAttention):
         )
 
 
-class LlamoeSdpaAttention(
+class LlamoeSdpaAttention(LlamoeAttention):
     """
     Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
     `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
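For the SDPA variant, the forward-pass change amounts to routing the attention computation through `torch.nn.functional.scaled_dot_product_attention`. A self-contained sketch of that call and its equivalence to the eager matmul/softmax path (standard PyTorch 2.x API, not code from this repository):

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 8)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 4, 16, 8)
v = torch.randn(1, 4, 16, 8)

# Fused path: what an SDPA attention class calls inside forward().
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Eager reference path: explicit scores, causal mask, softmax.
causal_mask = torch.full((16, 16), float("-inf")).triu(diagonal=1)
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + causal_mask
out_eager = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out_sdpa, out_eager, atol=1e-5))   # expected: True
```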
@@ -970,7 +970,7 @@ GEMMOE_INPUTS_DOCSTRING = r"""
     GEMMOE_START_DOCSTRING,
 )
 
-class LlamoeModel(
+class LlamoeModel(LlammoePreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]
     Args:
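The `LlamoeModel` docstring describes a plain stack of `config.num_hidden_layers` decoder layers. A rough, hypothetical sketch of just that shape (the real class also handles embeddings, caching, and attention masks):

```python
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.ff = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return hidden_states + self.ff(hidden_states)

class ToyDecoderModel(nn.Module):
    def __init__(self, hidden_size=8, num_hidden_layers=4):
        super().__init__()
        # One decoder layer per config.num_hidden_layers, applied in sequence.
        self.layers = nn.ModuleList(ToyDecoderLayer(hidden_size) for _ in range(num_hidden_layers))

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
```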
@@ -1180,7 +1180,7 @@ class LlamoeModel(GemmoePreTrainedModel):
 
         return causal_mask
 
-class LlamoeForCausalLM(
+class LlamoeForCausalLM(LlammoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
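In the transformers codebase, `_tied_weights_keys = ["lm_head.weight"]` typically marks weights that may be tied to the input embeddings and therefore omitted from the checkpoint when tying is enabled. A toy illustration of the tying itself (not the actual `LlamoeForCausalLM` code):

```python
import torch.nn as nn

vocab_size, hidden_size = 100, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie the LM head to the input embedding: both names point at the same Parameter,
# so only one copy needs to be stored in the checkpoint.
lm_head.weight = embed_tokens.weight

assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```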