Update modeling_Hixtral.py
modeling_Hixtral.py  CHANGED  (+2 -2)
@@ -788,7 +788,7 @@ class HixtralDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn =
+        self.self_attn = HIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
         self.block_sparse_moe = HixtralSparseMoeBlock(config)
         self.input_layernorm = HixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
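For context, the new line 791 picks the decoder layer's attention module through a lookup table keyed by config._attn_implementation. The sketch below is a minimal, self-contained illustration of that dispatch pattern only; the class names, the demo config, and the mapping contents are placeholders, since the real HIXTRAL_ATTENTION_CLASSES dict and Hixtral attention classes are defined elsewhere in modeling_Hixtral.py and are not part of this diff.

import torch.nn as nn

# Placeholder attention classes, standing in for the real Hixtral ones.
class EagerAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

class SdpaAttention(EagerAttention):
    pass

# Maps the configured implementation string to a concrete attention class,
# mirroring the role of HIXTRAL_ATTENTION_CLASSES in the diff above.
ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}

class DemoConfig:  # placeholder for the model config object
    _attn_implementation = "sdpa"

config = DemoConfig()
# Mirrors: self.self_attn = HIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=0)
print(type(self_attn).__name__)  # -> SdpaAttention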
@@ -1020,7 +1020,7 @@ class HixtralModel(HixtralPreTrainedModel):
         self.embed_tokens = value
 
     # Ignore copy
-    @add_start_docstrings_to_model_forward(
+    @add_start_docstrings_to_model_forward(HIXTRAL_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
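The second change completes the decorator call on line 1023 so that HixtralModel.forward is documented with HIXTRAL_INPUTS_DOCSTRING. As a rough illustration of what such a decorator does, here is a simplified stand-in, not the actual transformers.utils implementation, and the docstring text is an abbreviated placeholder: it prepends the shared inputs documentation to the decorated method's own docstring.

# Simplified stand-in for add_start_docstrings_to_model_forward; the real
# transformers.utils helper does additional formatting, this only shows the idea.
HIXTRAL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (torch.LongTensor): token ids of the input sequence.
"""  # abbreviated placeholder text, not the real docstring

def add_start_docstrings_to_model_forward(*docstrings):
    def decorator(fn):
        # Prepend the shared inputs documentation to the method's own docstring.
        fn.__doc__ = "".join(docstrings) + (fn.__doc__ or "")
        return fn
    return decorator

class HixtralModelSketch:  # placeholder class, not the real HixtralModel
    @add_start_docstrings_to_model_forward(HIXTRAL_INPUTS_DOCSTRING)
    def forward(self, input_ids=None):
        """Runs the decoder stack over input_ids."""
        return input_ids

print(HixtralModelSketch.forward.__doc__)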