update config & modeling
Browse files- config.json +1 -1
- modeling_grok1.py +6 -4
config.json
CHANGED
@@ -28,6 +28,6 @@
|
|
28 |
"num_experts": 8,
|
29 |
"output_router_logits": false,
|
30 |
"router_aux_loss_coef": 0.001,
|
31 |
-
"torch_dtype": "
|
32 |
"transformers_version": "4.35.0"
|
33 |
}
|
|
|
28 |
"num_experts": 8,
|
29 |
"output_router_logits": false,
|
30 |
"router_aux_loss_coef": 0.001,
|
31 |
+
"torch_dtype": "bfloat16",
|
32 |
"transformers_version": "4.35.0"
|
33 |
}
|
modeling_grok1.py
CHANGED
@@ -7,14 +7,16 @@ from transformers.modeling_utils import PreTrainedModel
|
|
7 |
from transformers.utils import logging
|
8 |
|
9 |
try:
|
10 |
-
from transformers.modeling_attn_mask_utils import
|
|
|
11 |
|
12 |
HAS_MASK_UTILS = True
|
13 |
except ImportError:
|
14 |
HAS_MASK_UTILS = False
|
15 |
|
16 |
from .configuration_grok1 import Grok1Config
|
17 |
-
from .modeling_grok1_outputs import MoeCausalLMOutputWithPast,
|
|
|
18 |
|
19 |
logger = logging.get_logger(__name__)
|
20 |
|
@@ -549,7 +551,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
549 |
|
550 |
|
551 |
class Grok1Model(Grok1PretrainedModel):
|
552 |
-
def __init__(self, config: Grok1Config) -> None:
|
553 |
super().__init__(config)
|
554 |
self.padding_idx = config.pad_token_id
|
555 |
self.vocab_size = config.vocab_size
|
@@ -787,7 +789,7 @@ class Grok1Model(Grok1PretrainedModel):
|
|
787 |
class Grok1ModelForCausalLM(Grok1PretrainedModel):
|
788 |
_tied_weights_keys = ["lm_head.weight"]
|
789 |
|
790 |
-
def __init__(self, config: Grok1Config):
|
791 |
super().__init__(config)
|
792 |
self.model = Grok1Model(config)
|
793 |
self.vocab_size = config.vocab_size
|
|
|
7 |
from transformers.utils import logging
|
8 |
|
9 |
try:
|
10 |
+
from transformers.modeling_attn_mask_utils import \
|
11 |
+
_prepare_4d_causal_attention_mask
|
12 |
|
13 |
HAS_MASK_UTILS = True
|
14 |
except ImportError:
|
15 |
HAS_MASK_UTILS = False
|
16 |
|
17 |
from .configuration_grok1 import Grok1Config
|
18 |
+
from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast,
|
19 |
+
MoeModelOutputWithPast)
|
20 |
|
21 |
logger = logging.get_logger(__name__)
|
22 |
|
|
|
551 |
|
552 |
|
553 |
class Grok1Model(Grok1PretrainedModel):
|
554 |
+
def __init__(self, config: Grok1Config, **kwargs) -> None:
|
555 |
super().__init__(config)
|
556 |
self.padding_idx = config.pad_token_id
|
557 |
self.vocab_size = config.vocab_size
|
|
|
789 |
class Grok1ModelForCausalLM(Grok1PretrainedModel):
|
790 |
_tied_weights_keys = ["lm_head.weight"]
|
791 |
|
792 |
+
def __init__(self, config: Grok1Config, **kwargs):
|
793 |
super().__init__(config)
|
794 |
self.model = Grok1Model(config)
|
795 |
self.vocab_size = config.vocab_size
|