import torch
import os
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM
from moe_infinity import MoE
from typing import List, Optional, Tuple

from lm_eval.api.registry import register_model

from src.backend.hflm_with_measurement import HFLMWithMeasurement


@register_model("moe-infinity")
class MoEHFLM(HFLMWithMeasurement):
    def __init__(
        self,
        pretrained: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        moe_config: Optional[dict] = None,
        offload_path: str = os.path.expanduser("~"),
        device_memory_ratio: float = 0.75,
        use_chat_template: bool = True,
        *args,
        **kwargs,
    ):
        # Initialize parent class without calling _create_model in the parent's __init__
        self.checkpoint = pretrained
        self.moe_config = moe_config if moe_config is not None else {}
        self.offload_path = offload_path
        self.device_memory_ratio = device_memory_ratio
        self.use_chat_template = use_chat_template
        if "device" in kwargs:
            kwargs.pop("device")
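        # Remove any stale offload directory left over from a previous run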
        if os.path.exists(os.path.join(self.offload_path, "moe-infinity-offloads")):
            shutil.rmtree(os.path.join(self.offload_path, "moe-infinity-offloads"))
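        # Pin the wrapper to cuda:0; the generate decorator in _create_model moves inputs there as well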
        kwargs["device_map"] = "cuda:0"
        super().__init__(
            *args, **kwargs, pretrained=pretrained
        )  # Assuming HFLM accepts a 'pretrained' arg and handles it
        # self._create_model()

    def __del__(self):
        # Guard against a partially constructed instance (e.g. __init__ raised early),
        # in which case _model or offload_path may not exist yet.
        if getattr(self, "_model", None) is not None:
            self._model.engine.clean_up()  # clean up hooks
            self._model.engine.archer_engine.clean_up_resources()  # clean up engine resources
        if hasattr(self, "offload_path"):
            offload_dir = os.path.join(self.offload_path, "moe-infinity-offloads")
            if os.path.exists(offload_dir):
                shutil.rmtree(offload_dir)  # clean up the offloaded model files

    def _create_model(self, *args, **kwargs):
        """
        Initializes the MoE model from MoE-infinity with the provided configuration.
        """
        # Ensure default configurations are set if not provided
        default_moe_config = {
            "offload_path": os.path.join(self.offload_path, "moe-infinity-offloads"),
            "device_memory_ratio": self.device_memory_ratio,  # Default value, adjust as necessary
        }
        # Update default config with any user-provided config
        final_moe_config = {**default_moe_config, **self.moe_config}

        # Dirty fix, to be removed once MoE-Infinity moves inputs to the correct device itself
        def MoEGenDecorator(func):
            def wrapper(*args, **kwargs):
                # Ensure all tensors in the input are on the same device as the model
                args = [arg.to("cuda:0") if isinstance(arg, torch.Tensor) else arg for arg in args]
                kwargs = {k: v.to("cuda:0") if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
                return func(*args, **kwargs)
            return wrapper

        self._model = MoE(self.checkpoint, final_moe_config)
        self._model.generate = MoEGenDecorator(self._model.generate)
        # self._model = AutoModelForCausalLM.from_pretrained(
        #     self.checkpoint, torch_dtype=torch.float16, device_map="auto"
        # )

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
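        # self.model is the MoE-Infinity wrapper, so the HF config is reached via self.model.model.config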
        for attr in seqlen_config_attrs:
            if hasattr(self.model.model.config, attr):
                return getattr(self.model.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
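            # transformers uses int(1e30) as a sentinel meaning "no maximum length configured"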
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        if self.use_chat_template:
            try:
                updated_strings = []
                for input_string in strings:
                    messages = [
                        {"role": "user", "content": input_string},
                    ]
                    updated_string = self.tokenizer.apply_chat_template(messages, tokenize=False)
                    updated_strings.append(updated_string)
                strings = updated_strings[:]
            except Exception as e:
                print(f"Failed to apply the chat template to the inputs: {e}")
        # Encode a batch of strings; converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side

        add_special_tokens = False

        encoding = self.tokenizer(
            strings,
            truncation=truncation,
            padding="longest",
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        )
        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]
        self.tokenizer.padding_side = old_padding_side

        return encoding["input_ids"], encoding["attention_mask"]
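

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not exercised by the harness itself).
    # It assumes lm-eval >= 0.4, a CUDA device, and enough free disk space under the
    # default offload path; running this module directly keeps the "moe-infinity"
    # registration above active. The task name is just an example; adjust
    # model_args to your environment.
    from lm_eval import evaluator

    results = evaluator.simple_evaluate(
        model="moe-infinity",  # resolved through the @register_model decorator above
        model_args="pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1",
        tasks=["arc_easy"],
    )
    print(results["results"])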