ver217 committed
Commit: 11f282f
Parent: 015a18b

update config & modeling

Files changed (2):
  1. config.json (+1 -1)
  2. modeling_grok1.py (+6 -4)
config.json CHANGED
@@ -28,6 +28,6 @@
   "num_experts": 8,
   "output_router_logits": false,
   "router_aux_loss_coef": 0.001,
-  "torch_dtype": "float16",
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.35.0"
 }
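
Note on the config change: with `torch_dtype="auto"`, `from_pretrained` reads the `torch_dtype` field from config.json, so checkpoints in this repo now default to bfloat16 in that mode. A minimal sketch (the repo id below is a placeholder, not taken from this commit):

```python
import torch
from transformers import AutoModelForCausalLM

# torch_dtype="auto" defers to the "torch_dtype" entry in config.json,
# which this commit flips from "float16" to "bfloat16".
model = AutoModelForCausalLM.from_pretrained(
    "some-org/grok-1",       # placeholder repo id, not from this commit
    torch_dtype="auto",
    trust_remote_code=True,  # the repo ships custom modeling_grok1.py code
)
assert model.dtype == torch.bfloat16
```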
modeling_grok1.py CHANGED
@@ -7,14 +7,16 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 
 try:
-    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+    from transformers.modeling_attn_mask_utils import \
+        _prepare_4d_causal_attention_mask
 
     HAS_MASK_UTILS = True
 except ImportError:
     HAS_MASK_UTILS = False
 
 from .configuration_grok1 import Grok1Config
-from .modeling_grok1_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast,
+                                     MoeModelOutputWithPast)
 
 logger = logging.get_logger(__name__)
 
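
The guarded import above targets `_prepare_4d_causal_attention_mask`, which only exists in recent transformers releases; the `HAS_MASK_UTILS` flag presumably switches the model to a local mask-building fallback on older versions. A small standalone sketch of the guarded helper in use (tensor shapes are assumptions for illustration, not from this diff):

```python
import torch

try:
    from transformers.modeling_attn_mask_utils import \
        _prepare_4d_causal_attention_mask
    HAS_MASK_UTILS = True
except ImportError:
    HAS_MASK_UTILS = False

if HAS_MASK_UTILS:
    batch, seq_len, hidden = 1, 4, 8
    inputs_embeds = torch.zeros(batch, seq_len, hidden)
    # With attention_mask=None this builds a plain additive causal mask
    # of shape (batch, 1, seq_len, seq_len) for the attention layers.
    mask = _prepare_4d_causal_attention_mask(
        None, (batch, seq_len), inputs_embeds, past_key_values_length=0
    )
    print(mask.shape)  # torch.Size([1, 1, 4, 4])
```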
 
@@ -549,7 +551,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
 
 class Grok1Model(Grok1PretrainedModel):
-    def __init__(self, config: Grok1Config) -> None:
+    def __init__(self, config: Grok1Config, **kwargs) -> None:
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -787,7 +789,7 @@ class Grok1Model(Grok1PretrainedModel):
 class Grok1ModelForCausalLM(Grok1PretrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config: Grok1Config):
+    def __init__(self, config: Grok1Config, **kwargs):
         super().__init__(config)
         self.model = Grok1Model(config)
         self.vocab_size = config.vocab_size
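
The one behavioral change in modeling_grok1.py is the `**kwargs` added to both constructors. `PreTrainedModel.from_pretrained` passes keyword arguments the config does not consume on to the model's `__init__`, so a strict `__init__(self, config)` signature can fail with a `TypeError` on such extras; accepting and ignoring them avoids that. A self-contained sketch of the failure mode (the demo classes are illustrative, not from this repo):

```python
from transformers import PretrainedConfig, PreTrainedModel

class DemoConfig(PretrainedConfig):
    model_type = "demo"

class StrictModel(PreTrainedModel):        # old-style signature
    config_class = DemoConfig

    def __init__(self, config):
        super().__init__(config)

class LenientModel(PreTrainedModel):       # signature after this commit
    config_class = DemoConfig

    def __init__(self, config, **kwargs):  # loader extras are tolerated
        super().__init__(config)

cfg = DemoConfig()
LenientModel(cfg, some_loader_kwarg=True)  # fine: extra kwarg is ignored
try:
    StrictModel(cfg, some_loader_kwarg=True)
except TypeError as err:
    print(err)  # ... unexpected keyword argument 'some_loader_kwarg'
```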