Update modeling_mixtral_clex.py
modeling_mixtral_clex.py CHANGED (+9 -9)
@@ -51,7 +51,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-from .configuration_mixtral_clex import MixtralConfig
+from .configuration_mixtral_clex import CLEXMixtralConfig
 from .clex_layer import CLEXScalingRotaryEmbedding
 
 if is_flash_attn_2_available():
@@ -71,7 +71,7 @@ if is_torch_fx_available():
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "MixtralConfig"
+_CONFIG_FOR_DOC = "CLEXMixtralConfig"
 
 
 def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
@@ -254,7 +254,7 @@ class MixtralAttention(nn.Module):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: CLEXMixtralConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -847,7 +847,7 @@ MIXTRAL_ATTENTION_CLASSES = {
 
 
 class MixtralBLockSparseTop2MLP(nn.Module):
-    def __init__(self, config: MixtralConfig):
+    def __init__(self, config: CLEXMixtralConfig):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
@@ -935,7 +935,7 @@ class MixtralSparseMoeBlock(nn.Module):
 
 
 class MixtralDecoderLayer(nn.Module):
-    def __init__(self, config: MixtralConfig, layer_idx: int):
+    def __init__(self, config: CLEXMixtralConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -1024,7 +1024,7 @@ MIXTRAL_START_DOCSTRING = r"""
         and behavior.
 
     Parameters:
-        config ([`MixtralConfig`]):
+        config ([`CLEXMixtralConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1037,7 +1037,7 @@ MIXTRAL_START_DOCSTRING = r"""
 )
 # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
 class MixtralPreTrainedModel(PreTrainedModel):
-    config_class = MixtralConfig
+    config_class = CLEXMixtralConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["MixtralDecoderLayer", "CLEXScalingRotaryEmbedding"]
@@ -1135,10 +1135,10 @@ class MixtralModel(MixtralPreTrainedModel):
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
 
     Args:
-        config: MixtralConfig
+        config: CLEXMixtralConfig
     """
 
-    def __init__(self, config: MixtralConfig):
+    def __init__(self, config: CLEXMixtralConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size