yixinsong commited on
Commit
3909b06
1 Parent(s): f94ad98
README.md CHANGED
@@ -4,8 +4,8 @@ language:
4
  - en
5
  ---
6
 
7
- # Model Card for SuperSparse-Mixtral
8
- The SuperSparse-Mixtral Large Language Model (LLM) is an sparsified version of the Mixtral.
9
 
10
  <img src="takeaway.png" alt="avatar" width="300" height="200"/>
11
 
@@ -13,7 +13,7 @@ The average performance is evaluated using benchmarks from the OpenLLM Leaderboa
13
 
14
  ## Inference
15
 
16
- Our code for accelerating SuperSparse-Mixtral is currently being refined. Stay tuned! Now you can run this model like dense model.
17
 
18
  ## Chat-Template
19
 
@@ -25,7 +25,7 @@ We take ChatML as our chat template:
25
 
26
  ## Allow Finetuning
27
 
28
- As we merged the predictors for FFN neurons in models, you can finetune SuperSparse-Mixtral with any framework and algorithm.
29
 
30
  ## License
31
 
 
4
  - en
5
  ---
6
 
7
+ # Model Card for TurboSparse-Mixtral
8
+ The TurboSparse-Mixtral Large Language Model (LLM) is an sparsified version of the Mixtral.
9
 
10
  <img src="takeaway.png" alt="avatar" width="300" height="200"/>
11
 
 
13
 
14
  ## Inference
15
 
16
+ Our code for accelerating TurboSparse-Mixtral is currently being refined. Stay tuned! Now you can run this model like dense model.
17
 
18
  ## Chat-Template
19
 
 
25
 
26
  ## Allow Finetuning
27
 
28
+ As we merged the predictors for FFN neurons in models, you can finetune TurboSparse-Mixtral with any framework and algorithm.
29
 
30
  ## License
31
 
config.json CHANGED
@@ -3,9 +3,9 @@
3
  "TurboSparseMixtralForCausalLM"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "configuration_supersparsemixtral.SuperSparseMixtralConfig",
7
- "AutoModel": "modeling_supersparsemixtral.SuperSparseMixtralForCausalLM",
8
- "AutoModelForCausalLM": "modeling_supersparsemixtral.SuperSparseMixtralForCausalLM"
9
  },
10
  "attention_dropout": 0.0,
11
  "bos_token_id": 1,
@@ -15,7 +15,7 @@
15
  "initializer_range": 0.02,
16
  "intermediate_size": 14336,
17
  "max_position_embeddings": 32768,
18
- "model_type": "trubosparsemixtral",
19
  "num_attention_heads": 32,
20
  "num_experts_per_tok": 2,
21
  "num_hidden_layers": 32,
 
3
  "TurboSparseMixtralForCausalLM"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_turbosparsemixtral.TurboSparseMixtralConfig",
7
+ "AutoModel": "modeling_turbosparsemixtral.TurboSparseMixtralForCausalLM",
8
+ "AutoModelForCausalLM": "modeling_turbosparsemixtral.TurboSparseMixtralForCausalLM"
9
  },
10
  "attention_dropout": 0.0,
11
  "bos_token_id": 1,
 
15
  "initializer_range": 0.02,
16
  "intermediate_size": 14336,
17
  "max_position_embeddings": 32768,
18
+ "model_type": "turbosparsemixtral",
19
  "num_attention_heads": 32,
20
  "num_experts_per_tok": 2,
21
  "num_hidden_layers": 32,
configuration_supersparsemixtral.py → configuration_turbosparsemixtral.py RENAMED
@@ -22,7 +22,7 @@ from transformers.utils import logging
22
 
23
  logger = logging.get_logger(__name__)
24
 
25
- class SuperSparseMixtralConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
28
  Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -106,7 +106,7 @@ class SuperSparseMixtralConfig(PretrainedConfig):
106
  >>> configuration = model.config
107
  ```"""
108
 
109
- model_type = "mixtral"
110
  keys_to_ignore_at_inference = ["past_key_values"]
111
 
112
  def __init__(
 
22
 
23
  logger = logging.get_logger(__name__)
24
 
25
+ class TurboSparseMixtralConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
28
  Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
 
106
  >>> configuration = model.config
107
  ```"""
108
 
109
+ model_type = "turbosparsemixtral"
110
  keys_to_ignore_at_inference = ["past_key_values"]
111
 
112
  def __init__(
modeling_supersparsemixtral.py → modeling_turbosparsemixtral.py RENAMED
@@ -54,7 +54,7 @@ from transformers.utils import (
54
  replace_return_docstrings,
55
  is_torch_fx_available,
56
  )
57
- from .configuration_supersparsemixtral import SuperSparseMixtralConfig
58
  @dataclass
59
  class AttentionMaskConverter:
60
  """
@@ -634,7 +634,7 @@ def _get_unpad_data(attention_mask):
634
 
635
 
636
  # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
637
- class SuperSparseMixtralRMSNorm(nn.Module):
638
  def __init__(self, hidden_size, eps=1e-6):
639
  """
640
  MixtralRMSNorm is equivalent to T5LayerNorm
@@ -653,7 +653,7 @@ class SuperSparseMixtralRMSNorm(nn.Module):
653
 
654
  # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
655
  # TODO @longjie no longer copied from Mistral after static cache
656
- class SuperSparseMixtralRotaryEmbedding(nn.Module):
657
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
658
  super().__init__()
659
 
@@ -742,13 +742,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
742
 
743
  # copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
744
  # TODO @longjie no longer copied from Mistral after static cache
745
- class SuperSparseMixtralAttention(nn.Module):
746
  """
747
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
748
  and "Generating Long Sequences with Sparse Transformers".
749
  """
750
 
751
- def __init__(self, config: SuperSparseMixtralConfig, layer_idx: Optional[int] = None):
752
  super().__init__()
753
  self.config = config
754
  self.layer_idx = layer_idx
@@ -779,7 +779,7 @@ class SuperSparseMixtralAttention(nn.Module):
779
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
780
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
781
 
782
- self.rotary_emb = SuperSparseMixtralRotaryEmbedding(
783
  self.head_dim,
784
  max_position_embeddings=self.max_position_embeddings,
785
  base=self.rope_theta,
@@ -867,7 +867,7 @@ class SuperSparseMixtralAttention(nn.Module):
867
 
868
  # copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
869
  # TODO @longjie no longer copied from Mistral after static cache
870
- class SuperSparseMixtralFlashAttention2(SuperSparseMixtralAttention):
871
  """
872
  Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
873
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -1154,7 +1154,7 @@ class SuperSparseMixtralFlashAttention2(SuperSparseMixtralAttention):
1154
 
1155
  # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
1156
  # TODO @longjie no longer copied from Mistral after static cache
1157
- class SuperSparseMixtralSdpaAttention(SuperSparseMixtralAttention):
1158
  """
1159
  Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1160
  `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
@@ -1246,9 +1246,9 @@ class SuperSparseMixtralSdpaAttention(SuperSparseMixtralAttention):
1246
 
1247
 
1248
  MIXTRAL_ATTENTION_CLASSES = {
1249
- "eager": SuperSparseMixtralAttention,
1250
- "flash_attention_2": SuperSparseMixtralFlashAttention2,
1251
- "sdpa": SuperSparseMixtralSdpaAttention,
1252
  }
1253
 
1254
  class MLP(nn.Module):
@@ -1264,8 +1264,8 @@ class MLP(nn.Module):
1264
  x = self.fc2(x)
1265
  x = x.sigmoid()
1266
  return x
1267
- class SuperSparseMixtralBlockSparseTop2MLP(nn.Module):
1268
- def __init__(self, config: SuperSparseMixtralConfig, layer_id):
1269
  super().__init__()
1270
  self.ffn_dim = config.intermediate_size
1271
  self.hidden_dim = config.hidden_size
@@ -1288,7 +1288,7 @@ class SuperSparseMixtralBlockSparseTop2MLP(nn.Module):
1288
  return current_hidden_states
1289
 
1290
 
1291
- class SuperSparseMixtralSparseMoeBlock(nn.Module):
1292
  """
1293
  This implementation is
1294
  strictly equivalent to standard MoE with full capacity (no
@@ -1310,7 +1310,7 @@ class SuperSparseMixtralSparseMoeBlock(nn.Module):
1310
  # gating
1311
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
1312
 
1313
- self.experts = nn.ModuleList([SuperSparseMixtralBlockSparseTop2MLP(config, layer_id) for _ in range(self.num_experts)])
1314
 
1315
  # Jitter parameters
1316
  self.jitter_noise = config.router_jitter_noise
@@ -1356,16 +1356,16 @@ class SuperSparseMixtralSparseMoeBlock(nn.Module):
1356
  return final_hidden_states, router_logits
1357
 
1358
 
1359
- class SuperSparseMixtralDecoderLayer(nn.Module):
1360
- def __init__(self, config: SuperSparseMixtralConfig, layer_idx: int):
1361
  super().__init__()
1362
  self.hidden_size = config.hidden_size
1363
 
1364
  self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1365
 
1366
- self.block_sparse_moe = SuperSparseMixtralSparseMoeBlock(config, layer_idx)
1367
- self.input_layernorm = SuperSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1368
- self.post_attention_layernorm = SuperSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1369
 
1370
  def forward(
1371
  self,
@@ -1451,11 +1451,11 @@ MIXTRAL_START_DOCSTRING = r"""
1451
  MIXTRAL_START_DOCSTRING,
1452
  )
1453
  # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
1454
- class SuperSparseMixtralPreTrainedModel(PreTrainedModel):
1455
- config_class = SuperSparseMixtralConfig
1456
  base_model_prefix = "model"
1457
  supports_gradient_checkpointing = True
1458
- _no_split_modules = ["SuperSparseMixtralDecoderLayer"]
1459
  _skip_keys_device_placement = "past_key_values"
1460
  _supports_flash_attn_2 = True
1461
  _supports_sdpa = True
@@ -1546,7 +1546,7 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
1546
  )
1547
  # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
1548
  # TODO @longjie no longer copied from Mistral after static cache
1549
- class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
1550
  """
1551
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
1552
 
@@ -1554,17 +1554,17 @@ class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
1554
  config: MixtralConfig
1555
  """
1556
 
1557
- def __init__(self, config: SuperSparseMixtralConfig):
1558
  super().__init__(config)
1559
  self.padding_idx = config.pad_token_id
1560
  self.vocab_size = config.vocab_size
1561
 
1562
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1563
  self.layers = nn.ModuleList(
1564
- [SuperSparseMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1565
  )
1566
  self._attn_implementation = config._attn_implementation
1567
- self.norm = SuperSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1568
 
1569
  self.gradient_checkpointing = False
1570
  # Initialize weights and apply final processing
@@ -1741,12 +1741,12 @@ class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
1741
  )
1742
 
1743
 
1744
- class SuperSparseMixtralForCausalLM(SuperSparseMixtralPreTrainedModel):
1745
  _tied_weights_keys = ["lm_head.weight"]
1746
 
1747
  def __init__(self, config):
1748
  super().__init__(config)
1749
- self.model = SuperSparseMixtralModel(config)
1750
  self.vocab_size = config.vocab_size
1751
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1752
  self.router_aux_loss_coef = config.router_aux_loss_coef
@@ -1974,11 +1974,11 @@ class SuperSparseMixtralForCausalLM(SuperSparseMixtralPreTrainedModel):
1974
  MIXTRAL_START_DOCSTRING,
1975
  )
1976
  # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
1977
- class SuperSparseMixtralForSequenceClassification(SuperSparseMixtralPreTrainedModel):
1978
  def __init__(self, config):
1979
  super().__init__(config)
1980
  self.num_labels = config.num_labels
1981
- self.model = SuperSparseMixtralModel(config)
1982
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1983
 
1984
  # Initialize weights and apply final processing
@@ -2090,11 +2090,11 @@ class SuperSparseMixtralForSequenceClassification(SuperSparseMixtralPreTrainedMo
2090
  MIXTRAL_START_DOCSTRING,
2091
  )
2092
  # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
2093
- class SuperSparseMixtralForTokenClassification(SuperSparseMixtralPreTrainedModel):
2094
  def __init__(self, config):
2095
  super().__init__(config)
2096
  self.num_labels = config.num_labels
2097
- self.model = SuperSparseMixtralModel(config)
2098
  if getattr(config, "classifier_dropout", None) is not None:
2099
  classifier_dropout = config.classifier_dropout
2100
  elif getattr(config, "hidden_dropout", None) is not None:
 
54
  replace_return_docstrings,
55
  is_torch_fx_available,
56
  )
57
+ from .configuration_turbosparsemixtral import TurboSparseMixtralConfig
58
  @dataclass
59
  class AttentionMaskConverter:
60
  """
 
634
 
635
 
636
  # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
637
+ class TurboSparseMixtralRMSNorm(nn.Module):
638
  def __init__(self, hidden_size, eps=1e-6):
639
  """
640
  MixtralRMSNorm is equivalent to T5LayerNorm
 
653
 
654
  # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
655
  # TODO @longjie no longer copied from Mistral after static cache
656
+ class TurboSparseMixtralRotaryEmbedding(nn.Module):
657
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
658
  super().__init__()
659
 
 
742
 
743
  # copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
744
  # TODO @longjie no longer copied from Mistral after static cache
745
+ class TurboSparseMixtralAttention(nn.Module):
746
  """
747
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
748
  and "Generating Long Sequences with Sparse Transformers".
749
  """
750
 
751
+ def __init__(self, config: TurboSparseMixtralConfig, layer_idx: Optional[int] = None):
752
  super().__init__()
753
  self.config = config
754
  self.layer_idx = layer_idx
 
779
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
780
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
781
 
782
+ self.rotary_emb = TurboSparseMixtralRotaryEmbedding(
783
  self.head_dim,
784
  max_position_embeddings=self.max_position_embeddings,
785
  base=self.rope_theta,
 
867
 
868
  # copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
869
  # TODO @longjie no longer copied from Mistral after static cache
870
+ class TurboSparseMixtralFlashAttention2(TurboSparseMixtralAttention):
871
  """
872
  Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
873
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
 
1154
 
1155
  # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
1156
  # TODO @longjie no longer copied from Mistral after static cache
1157
+ class TurboSparseMixtralSdpaAttention(TurboSparseMixtralAttention):
1158
  """
1159
  Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1160
  `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
 
1246
 
1247
 
1248
  MIXTRAL_ATTENTION_CLASSES = {
1249
+ "eager": TurboSparseMixtralAttention,
1250
+ "flash_attention_2": TurboSparseMixtralFlashAttention2,
1251
+ "sdpa": TurboSparseMixtralSdpaAttention,
1252
  }
1253
 
1254
  class MLP(nn.Module):
 
1264
  x = self.fc2(x)
1265
  x = x.sigmoid()
1266
  return x
1267
+ class TurboSparseMixtralBlockSparseTop2MLP(nn.Module):
1268
+ def __init__(self, config: TurboSparseMixtralConfig, layer_id):
1269
  super().__init__()
1270
  self.ffn_dim = config.intermediate_size
1271
  self.hidden_dim = config.hidden_size
 
1288
  return current_hidden_states
1289
 
1290
 
1291
+ class TurboSparseMixtralSparseMoeBlock(nn.Module):
1292
  """
1293
  This implementation is
1294
  strictly equivalent to standard MoE with full capacity (no
 
1310
  # gating
1311
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
1312
 
1313
+ self.experts = nn.ModuleList([TurboSparseMixtralBlockSparseTop2MLP(config, layer_id) for _ in range(self.num_experts)])
1314
 
1315
  # Jitter parameters
1316
  self.jitter_noise = config.router_jitter_noise
 
1356
  return final_hidden_states, router_logits
1357
 
1358
 
1359
+ class TurboSparseMixtralDecoderLayer(nn.Module):
1360
+ def __init__(self, config: TurboSparseMixtralConfig, layer_idx: int):
1361
  super().__init__()
1362
  self.hidden_size = config.hidden_size
1363
 
1364
  self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1365
 
1366
+ self.block_sparse_moe = TurboSparseMixtralSparseMoeBlock(config, layer_idx)
1367
+ self.input_layernorm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1368
+ self.post_attention_layernorm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1369
 
1370
  def forward(
1371
  self,
 
1451
  MIXTRAL_START_DOCSTRING,
1452
  )
1453
  # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
1454
+ class TurboSparseMixtralPreTrainedModel(PreTrainedModel):
1455
+ config_class = TurboSparseMixtralConfig
1456
  base_model_prefix = "model"
1457
  supports_gradient_checkpointing = True
1458
+ _no_split_modules = ["TurboSparseMixtralDecoderLayer"]
1459
  _skip_keys_device_placement = "past_key_values"
1460
  _supports_flash_attn_2 = True
1461
  _supports_sdpa = True
 
1546
  )
1547
  # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
1548
  # TODO @longjie no longer copied from Mistral after static cache
1549
+ class TurboSparseMixtralModel(TurboSparseMixtralPreTrainedModel):
1550
  """
1551
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
1552
 
 
1554
  config: MixtralConfig
1555
  """
1556
 
1557
+ def __init__(self, config: TurboSparseMixtralConfig):
1558
  super().__init__(config)
1559
  self.padding_idx = config.pad_token_id
1560
  self.vocab_size = config.vocab_size
1561
 
1562
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1563
  self.layers = nn.ModuleList(
1564
+ [TurboSparseMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1565
  )
1566
  self._attn_implementation = config._attn_implementation
1567
+ self.norm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1568
 
1569
  self.gradient_checkpointing = False
1570
  # Initialize weights and apply final processing
 
1741
  )
1742
 
1743
 
1744
+ class TurboSparseMixtralForCausalLM(TurboSparseMixtralPreTrainedModel):
1745
  _tied_weights_keys = ["lm_head.weight"]
1746
 
1747
  def __init__(self, config):
1748
  super().__init__(config)
1749
+ self.model = TurboSparseMixtralModel(config)
1750
  self.vocab_size = config.vocab_size
1751
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1752
  self.router_aux_loss_coef = config.router_aux_loss_coef
 
1974
  MIXTRAL_START_DOCSTRING,
1975
  )
1976
  # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
1977
+ class TurboSparseMixtralForSequenceClassification(TurboSparseMixtralPreTrainedModel):
1978
  def __init__(self, config):
1979
  super().__init__(config)
1980
  self.num_labels = config.num_labels
1981
+ self.model = TurboSparseMixtralModel(config)
1982
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1983
 
1984
  # Initialize weights and apply final processing
 
2090
  MIXTRAL_START_DOCSTRING,
2091
  )
2092
  # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
2093
+ class TurboSparseMixtralForTokenClassification(TurboSparseMixtralPreTrainedModel):
2094
  def __init__(self, config):
2095
  super().__init__(config)
2096
  self.num_labels = config.num_labels
2097
+ self.model = TurboSparseMixtralModel(config)
2098
  if getattr(config, "classifier_dropout", None) is not None:
2099
  classifier_dropout = config.classifier_dropout
2100
  elif getattr(config, "hidden_dropout", None) is not None: