damerajee committed
Commit 9dacccb · verified · 1 Parent(s): fdf3b4d

Update modeling_Hixtral.py

Files changed (1):
  modeling_Hixtral.py +2 -2
modeling_Hixtral.py CHANGED

@@ -788,7 +788,7 @@ class HixtralDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = LLAMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+        self.self_attn = HIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

         self.block_sparse_moe = HixtralSparseMoeBlock(config)
         self.input_layernorm = HixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1020,7 +1020,7 @@ class HixtralModel(HixtralPreTrainedModel):
         self.embed_tokens = value

     # Ignore copy
-    @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(HIXTRAL_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
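
Both hunks replace leftover LLAMOE_* identifiers, apparently inherited from the model this file was adapted from, with the HIXTRAL_* names that modeling_Hixtral.py actually defines. The first stale reference sits inside HixtralDecoderLayer.__init__, so it raised a NameError as soon as a decoder layer was instantiated; the second sits in a decorator, which Python evaluates at class-definition time, so it broke as early as module import. Below is a minimal sketch of the attention-dispatch pattern the first hunk relies on, assuming HIXTRAL_ATTENTION_CLASSES follows the standard transformers convention of mapping config._attn_implementation to an attention class; the class names and bodies here are illustrative stand-ins, not the file's real implementations.

import torch.nn as nn

# Illustrative stand-ins for the attention variants defined earlier in
# modeling_Hixtral.py (assumed names; the usual transformers layout has
# an eager, a flash-attention-2, and an SDPA variant of the same module).
class HixtralAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

class HixtralFlashAttention2(HixtralAttention):
    pass

class HixtralSdpaAttention(HixtralAttention):
    pass

# Module-level dispatch table. Before this commit the decoder layer looked
# up LLAMOE_ATTENTION_CLASSES, a name that does not exist in this module,
# so building a HixtralDecoderLayer failed with a NameError.
HIXTRAL_ATTENTION_CLASSES = {
    "eager": HixtralAttention,
    "flash_attention_2": HixtralFlashAttention2,
    "sdpa": HixtralSdpaAttention,
}

class DemoConfig:  # hypothetical stand-in for the real HixtralConfig
    _attn_implementation = "sdpa"

config = DemoConfig()
self_attn = HIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=0)
print(type(self_attn).__name__)  # -> HixtralSdpaAttention

The second hunk is the same kind of fix one level up: add_start_docstrings_to_model_forward, a helper from transformers.utils, prepends the shared inputs docstring to forward.__doc__, so the constant passed to it must exist at module level when the class body is executed.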