not-lain committed
Commit 0ae2d10
Parent: 1288f9a

refactor: use CeruleGemmaConfig in place of GemmaConfig throughout modeling_cerule_gemma.py

Files changed (1): modeling_cerule_gemma.py (+7 -7)
--- a/modeling_cerule_gemma.py
+++ b/modeling_cerule_gemma.py
@@ -872,7 +872,7 @@ if is_torch_fx_available():
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "GemmaConfig"
+_CONFIG_FOR_DOC = "CeruleGemmaConfig"
 
 
 def _get_unpad_data(attention_mask):
@@ -1003,7 +1003,7 @@ class GemmaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     # Ignore copy
-    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -1396,7 +1396,7 @@ GEMMA_ATTENTION_CLASSES = {
 
 # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
 class GemmaDecoderLayer(nn.Module):
-    def __init__(self, config: GemmaConfig, layer_idx: int):
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -1480,7 +1480,7 @@ GEMMA_START_DOCSTRING = r"""
     and behavior.
 
     Parameters:
-        config ([`GemmaConfig`]):
+        config ([`CeruleGemmaConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
             load the weights associated with the model, only the configuration. Check out the
             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1492,7 +1492,7 @@ GEMMA_START_DOCSTRING = r"""
     GEMMA_START_DOCSTRING,
 )
 class GemmaPreTrainedModel(PreTrainedModel):
-    config_class = GemmaConfig
+    config_class = CeruleGemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
@@ -1618,7 +1618,7 @@ class GemmaModel(GemmaPreTrainedModel):
         config: GemmaConfig
     """
 
-    def __init__(self, config: GemmaConfig):
+    def __init__(self, config: CeruleGemmaConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -2155,7 +2155,7 @@ from .configuration_gemma import CeruleGemmaConfig
 class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
     config_class = CeruleGemmaConfig
 
-    def __init__(self, config: GemmaConfig):
+    def __init__(self, config: CeruleGemmaConfig):
         super(CeruleGemmaModel, self).__init__(config)
 
 
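
For context, a minimal sketch of how the updated config_class wiring would typically be exercised when loading this checkpoint as remote code. The repo id below is a placeholder, and the sketch assumes the repo ships this modeling file together with configuration_gemma.py and registers both via auto_map in config.json; none of that is part of this commit.

# Minimal sketch, not part of this commit. "your-org/cerule-gemma" is a
# placeholder repo id; substitute the actual Hub path. With trust_remote_code=True
# and an auto_map entry in config.json, AutoConfig resolves CeruleGemmaConfig,
# which this change wires in as the model's config_class.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-org/cerule-gemma"  # placeholder

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(type(config).__name__)  # expected after this change: CeruleGemmaConfig

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,
)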