damerajee committed on
Commit
bdfe503
1 Parent(s): 0c282b0

Update modeling_Llamoe.py

Browse files
Files changed (1) hide show
  1. modeling_Llamoe.py +3 -3
modeling_Llamoe.py CHANGED
@@ -755,7 +755,7 @@ LLAMA_ATTENTION_CLASSES = {
755
 
756
 
757
  class LlamoeDecoderLayer(nn.Module):
758
- def __init__(self, config: GemmoeConfig, layer_idx: int):
759
  super().__init__()
760
  self.hidden_size = config.hidden_size
761
 
@@ -860,7 +860,7 @@ LLAMA_START_DOCSTRING = r"""
860
  )
861
 
862
  class LlammoePreTrainedModel(PreTrainedModel):
863
- config_class = GemmoeConfig
864
  base_model_prefix = "model"
865
  supports_gradient_checkpointing = True
866
  _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
@@ -977,7 +977,7 @@ class LlamoeModel(GemmoePreTrainedModel):
977
  config: GemmoeConfig
978
  """
979
 
980
- def __init__(self, config: GemmoeConfig):
981
  super().__init__(config)
982
  self.padding_idx = config.pad_token_id
983
  self.vocab_size = config.vocab_size
 
755
 
756
 
757
  class LlamoeDecoderLayer(nn.Module):
758
+ def __init__(self, config: LlamoeConfig, layer_idx: int):
759
  super().__init__()
760
  self.hidden_size = config.hidden_size
761
 
 
860
  )
861
 
862
  class LlammoePreTrainedModel(PreTrainedModel):
863
+ config_class = LlamoeConfig
864
  base_model_prefix = "model"
865
  supports_gradient_checkpointing = True
866
  _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
 
977
  config: GemmoeConfig
978
  """
979
 
980
+ def __init__(self, config: LlamoeConfig):
981
  super().__init__(config)
982
  self.padding_idx = config.pad_token_id
983
  self.vocab_size = config.vocab_size