damerajee committed
Commit bd78fad
1 Parent(s): 28eb841

Update modeling_Llamoe.py

Files changed (1):
  1. modeling_Llamoe.py +4 -4
modeling_Llamoe.py CHANGED
@@ -467,7 +467,7 @@ class LlamoeAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-class LlamoeFlashAttention2(LlamaAttention):
+class LlamoeFlashAttention2(LlamoeAttention):
     """
     Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -662,7 +662,7 @@ class LlamoeFlashAttention2(LlamaAttention):
         )
 
 
-class LlamoeSdpaAttention(LlamaAttention):
+class LlamoeSdpaAttention(LlamoeAttention):
     """
     Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
     `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
@@ -970,7 +970,7 @@ GEMMOE_INPUTS_DOCSTRING = r"""
     GEMMOE_START_DOCSTRING,
 )
 
-class LlamoeModel(GemmoePreTrainedModel):
+class LlamoeModel(LlammoePreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]
     Args:
@@ -1180,7 +1180,7 @@ class LlamoeModel(GemmoePreTrainedModel):
 
         return causal_mask
 
-class LlamoeForCausalLM(GemmoePreTrainedModel):
+class LlamoeForCausalLM(LlammoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
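
Taken together, the four edits point the attention variants at the model's own LlamoeAttention instead of LlamaAttention, and point LlamoeModel / LlamoeForCausalLM at LlammoePreTrainedModel instead of GemmoePreTrainedModel, presumably so every base class matches a name defined in modeling_Llamoe.py itself. Below is a minimal, self-contained sketch (not this file's actual code) of why that inheritance pattern works: the SDPA subclass adds no parameters of its own, so it reuses the base attention's projection weights and overrides only forward() to call torch.nn.functional.scaled_dot_product_attention. The Toy* class names, the dimensions, and the omission of RoPE, attention masking, and KV caching are all illustrative assumptions.

# Sketch of the "same weights, different forward()" inheritance used by the
# attention variants in this commit. Illustrative only; names and shapes
# follow the usual Llama layout, not necessarily this file's exact code.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyLlamoeAttention(nn.Module):
    """Stand-in for LlamoeAttention: owns the projection weights."""

    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def _split(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        bsz, seq, _ = x.shape
        return x.view(bsz, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q, k, v = (self._split(p(hidden_states)) for p in (self.q_proj, self.k_proj, self.v_proj))
        # "Eager" attention: explicit softmax(QK^T / sqrt(d)) V
        attn = torch.softmax(q @ k.transpose(-1, -2) / self.head_dim**0.5, dim=-1) @ v
        bsz, _, seq, _ = attn.shape
        return self.o_proj(attn.transpose(1, 2).reshape(bsz, seq, -1))


class ToyLlamoeSdpaAttention(ToyLlamoeAttention):
    """Same weights as the base class; only forward() changes, mirroring
    LlamoeSdpaAttention(LlamoeAttention) after this commit."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q, k, v = (self._split(p(hidden_states)) for p in (self.q_proj, self.k_proj, self.v_proj))
        # Dispatch to PyTorch's fused kernel; RoPE, masking, and KV cache omitted for brevity.
        attn = F.scaled_dot_product_attention(q, k, v)
        bsz, _, seq, _ = attn.shape
        return self.o_proj(attn.transpose(1, 2).reshape(bsz, seq, -1))


if __name__ == "__main__":
    x = torch.randn(2, 8, 64)
    base = ToyLlamoeAttention()
    sdpa = ToyLlamoeSdpaAttention()
    # Because the subclass adds no new parameters, the base state dict loads cleanly.
    sdpa.load_state_dict(base.state_dict())
    print(sdpa(x).shape)  # torch.Size([2, 8, 64])

The same reasoning applies to LlamoeFlashAttention2: as its docstring says, the weights stay untouched and only the forward pass changes, so subclassing LlamoeAttention (a class that actually exists in this file) is enough.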