Guanzheng commited on
Commit
74e9804
1 Parent(s): d1b9f30

Update modeling_mixtral_clex.py

Browse files
Files changed (1) hide show
  1. modeling_mixtral_clex.py +10 -9
modeling_mixtral_clex.py CHANGED
@@ -51,7 +51,7 @@ from transformers.utils import (
51
  replace_return_docstrings,
52
  )
53
  from transformers.utils.import_utils import is_torch_fx_available
54
- from .configuration_mixtral_clex import MixtralConfig
55
  from .clex_layer import CLEXScalingRotaryEmbedding
56
 
57
  if is_flash_attn_2_available():
@@ -71,7 +71,7 @@ if is_torch_fx_available():
71
 
72
  logger = logging.get_logger(__name__)
73
 
74
- _CONFIG_FOR_DOC = "MixtralConfig"
75
 
76
 
77
  def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
@@ -254,7 +254,7 @@ class MixtralAttention(nn.Module):
254
  and "Generating Long Sequences with Sparse Transformers".
255
  """
256
 
257
- def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
258
  super().__init__()
259
  self.config = config
260
  self.layer_idx = layer_idx
@@ -847,7 +847,7 @@ MIXTRAL_ATTENTION_CLASSES = {
847
 
848
 
849
  class MixtralBLockSparseTop2MLP(nn.Module):
850
- def __init__(self, config: MixtralConfig):
851
  super().__init__()
852
  self.ffn_dim = config.intermediate_size
853
  self.hidden_dim = config.hidden_size
@@ -935,7 +935,7 @@ class MixtralSparseMoeBlock(nn.Module):
935
 
936
 
937
  class MixtralDecoderLayer(nn.Module):
938
- def __init__(self, config: MixtralConfig, layer_idx: int):
939
  super().__init__()
940
  self.hidden_size = config.hidden_size
941
 
@@ -1024,7 +1024,7 @@ MIXTRAL_START_DOCSTRING = r"""
1024
  and behavior.
1025
 
1026
  Parameters:
1027
- config ([`MixtralConfig`]):
1028
  Model configuration class with all the parameters of the model. Initializing with a config file does not
1029
  load the weights associated with the model, only the configuration. Check out the
1030
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1037,7 +1037,7 @@ MIXTRAL_START_DOCSTRING = r"""
1037
  )
1038
  # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
1039
  class MixtralPreTrainedModel(PreTrainedModel):
1040
- config_class = MixtralConfig
1041
  base_model_prefix = "model"
1042
  supports_gradient_checkpointing = True
1043
  _no_split_modules = ["MixtralDecoderLayer", "CLEXScalingRotaryEmbedding"]
@@ -1135,10 +1135,10 @@ class MixtralModel(MixtralPreTrainedModel):
1135
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
1136
 
1137
  Args:
1138
- config: MixtralConfig
1139
  """
1140
 
1141
- def __init__(self, config: MixtralConfig):
1142
  super().__init__(config)
1143
  self.padding_idx = config.pad_token_id
1144
  self.vocab_size = config.vocab_size
@@ -1410,6 +1410,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
1410
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1411
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1412
  ```"""
 
1413
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1414
  output_router_logits = (
1415
  output_router_logits if output_router_logits is not None else self.config.output_router_logits
 
51
  replace_return_docstrings,
52
  )
53
  from transformers.utils.import_utils import is_torch_fx_available
54
+ from .configuration_mixtral_clex import CLEXMixtralConfig
55
  from .clex_layer import CLEXScalingRotaryEmbedding
56
 
57
  if is_flash_attn_2_available():
 
71
 
72
  logger = logging.get_logger(__name__)
73
 
74
+ _CONFIG_FOR_DOC = "CLEXMixtralConfig"
75
 
76
 
77
  def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:
 
254
  and "Generating Long Sequences with Sparse Transformers".
255
  """
256
 
257
+ def __init__(self, config: CLEXMixtralConfig, layer_idx: Optional[int] = None):
258
  super().__init__()
259
  self.config = config
260
  self.layer_idx = layer_idx
 
847
 
848
 
849
  class MixtralBLockSparseTop2MLP(nn.Module):
850
+ def __init__(self, config: CLEXMixtralConfig):
851
  super().__init__()
852
  self.ffn_dim = config.intermediate_size
853
  self.hidden_dim = config.hidden_size
 
935
 
936
 
937
  class MixtralDecoderLayer(nn.Module):
938
+ def __init__(self, config: CLEXMixtralConfig, layer_idx: int):
939
  super().__init__()
940
  self.hidden_size = config.hidden_size
941
 
 
1024
  and behavior.
1025
 
1026
  Parameters:
1027
+ config ([`CLEXMixtralConfig`]):
1028
  Model configuration class with all the parameters of the model. Initializing with a config file does not
1029
  load the weights associated with the model, only the configuration. Check out the
1030
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
1037
  )
1038
  # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
1039
  class MixtralPreTrainedModel(PreTrainedModel):
1040
+ config_class = CLEXMixtralConfig
1041
  base_model_prefix = "model"
1042
  supports_gradient_checkpointing = True
1043
  _no_split_modules = ["MixtralDecoderLayer", "CLEXScalingRotaryEmbedding"]
 
1135
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
1136
 
1137
  Args:
1138
+ config: CLEXMixtralConfig
1139
  """
1140
 
1141
+ def __init__(self, config: CLEXMixtralConfig):
1142
  super().__init__(config)
1143
  self.padding_idx = config.pad_token_id
1144
  self.vocab_size = config.vocab_size
 
1410
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1411
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1412
  ```"""
1413
+ print(input_ids[0,20:30].tolist())
1414
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1415
  output_router_logits = (
1416
  output_router_logits if output_router_logits is not None else self.config.output_router_logits