future-xy committed 33e1c9d (parent: 1c22d8d)
add moe-infinity back

src/backend/moe_infinity.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import os
 from transformers import AutoTokenizer
 from transformers import AutoModelForCausalLM
-
+from moe_infinity import MoE
 from typing import List, Tuple, Optional, Union
 
 from lm_eval.models.huggingface import HFLM
@@ -45,10 +45,10 @@ class MoEHFLM(HFLM):
         }
         # Update default config with any user-provided config
         final_moe_config = {**default_moe_config, **self.moe_config}
-
-        self._model = AutoModelForCausalLM.from_pretrained(
-            self.checkpoint, torch_dtype=torch.float16, device_map="auto"
-        )
+        self._model = MoE(self.checkpoint, final_moe_config)
+        # self._model = AutoModelForCausalLM.from_pretrained(
+        #     self.checkpoint, torch_dtype=torch.float16, device_map="auto"
+        # )
 
     @property
     def max_length(self):
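
For context, a minimal sketch of the construction pattern this commit switches to. It is an illustration under assumptions, not the real MoEHFLM code: the helper name build_moe_model and the default config keys ("offload_path", "device_memory_ratio") are hypothetical; only the from moe_infinity import MoE import, the config merge, and the MoE(self.checkpoint, final_moe_config) call are taken from the changed lines.

# Minimal sketch of the pattern the commit switches to; not the real MoEHFLM class.
# The config keys below are assumptions for illustration -- only the dict merge
# and the MoE(checkpoint, config) call appear in the changed lines.
from typing import Optional

from moe_infinity import MoE


def build_moe_model(checkpoint: str, user_moe_config: Optional[dict] = None):
    # Stand-in for the real default_moe_config defined just above the hunk.
    default_moe_config = {
        "offload_path": "./offload",      # assumed: directory for offloaded experts
        "device_memory_ratio": 0.75,      # assumed: fraction of GPU memory to use
    }
    # Update default config with any user-provided config (same merge as line 47)
    final_moe_config = {**default_moe_config, **(user_moe_config or {})}
    # MoE-Infinity loads the checkpoint and manages expert offloading itself,
    # replacing the plain AutoModelForCausalLM.from_pretrained(...) path.
    return MoE(checkpoint, final_moe_config)

Compared with the commented-out AutoModelForCausalLM.from_pretrained(..., device_map="auto") path, handing the merged config to MoE lets MoE-Infinity manage expert offloading between GPU and host memory rather than relying on generic device_map placement.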