from functools import partial

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, Dict, Iterator, List, Optional
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from exllama.tokenizer import ExLlamaTokenizer
from exllama.generator import ExLlamaGenerator
from exllama.lora import ExLlamaLora
import os, glob
from enum import Enum

from pydantic.v1 import root_validator

# The Unicode replacement character (U+FFFD) that appears while a multi-byte
# character is still being decoded; such chunks are withheld from the stream callback.
BROKEN_UNICODE = b'\\ufffd'.decode('unicode_escape')

class H2OExLlamaTokenizer(ExLlamaTokenizer):
    def __call__(self, text, *args, **kwargs):
        # Mimic the Hugging Face tokenizer call interface by returning a dict with input_ids.
        return dict(input_ids=self.encode(text))


class H2OExLlamaGenerator(ExLlamaGenerator):
    def is_exlama(self):
        return True


class Exllama(LLM):
    client: Any  #: :meta private:
    model_path: str = None
    """The path to the GPTQ model folder."""
    model: Any = None
    sanitize_bot_response: bool = False
    prompter: Any = None
    context: Any = ''
    iinput: Any = ''

    exllama_cache: ExLlamaCache = None  #: :meta private:
    config: ExLlamaConfig = None  #: :meta private:
    generator: ExLlamaGenerator = None  #: :meta private:
    tokenizer: ExLlamaTokenizer = None  #: :meta private:

    ## LangChain parameters
    logfunc = print
    stop_sequences: Optional[List[str]] = []  # Sequences that immediately stop the generator.
    streaming: Optional[bool] = True  # Whether to stream the results, token by token.

    ## Generator parameters
    disallowed_tokens: Optional[List[int]] = None  # List of tokens to disallow during generation.
    temperature: Optional[float] = None  # Temperature for sampling diversity.
    top_k: Optional[int] = None  # Consider the most probable top_k samples, 0 to disable top_k sampling.
    top_p: Optional[float] = None  # Consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling.
    min_p: Optional[float] = None  # Do not consider tokens with probability less than this.
    typical: Optional[float] = None  # Locally typical sampling threshold, 0.0 to disable typical sampling.
    token_repetition_penalty_max: Optional[float] = None  # Repetition penalty for most recent tokens.
    token_repetition_penalty_sustain: Optional[int] = None  # Number of most recent tokens to apply the penalty to, -1 to apply to the whole context.
    token_repetition_penalty_decay: Optional[int] = None  # Gradually decrease penalty over this many tokens.
    beams: Optional[int] = None  # Number of beams for beam search.
    beam_length: Optional[int] = None  # Length of beams for beam search.

    ## Config overrides
    max_seq_len: Optional[int] = 2048  # Reduce to save memory. Can also be increased, ideally while also using compress_pos_emb and a compatible model/LoRA.
    compress_pos_emb: Optional[float] = 1.0  # Amount of compression to apply to the positional embedding.
    set_auto_map: Optional[str] = None  # Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. "20,7,7".
    gpu_peer_fix: Optional[bool] = None  # Prevent direct copies of data between GPUs.
    alpha_value: Optional[float] = 1.0  # RoPE context-extension alpha.

    ## Tuning (low-level exllama options, passed through to ExLlamaConfig)
    matmul_recons_thd: Optional[int] = None
    fused_mlp_thd: Optional[int] = None
    sdp_thd: Optional[int] = None
    fused_attn: Optional[bool] = None
    matmul_fused_remap: Optional[bool] = None
    rmsnorm_no_half2: Optional[bool] = None
    rope_no_half2: Optional[bool] = None
    matmul_no_half2: Optional[bool] = None
    silu_no_half2: Optional[bool] = None
    concurrent_streams: Optional[bool] = None

    ## LoRA parameters
    lora_path: Optional[str] = None  # Path to your LoRA folder.

    @staticmethod
    def get_model_path_at(path):
        patterns = ["*.safetensors", "*.bin", "*.pt"]
        model_paths = []
        for pattern in patterns:
            full_pattern = os.path.join(path, pattern)
            model_paths = glob.glob(full_pattern)
            if model_paths:  # If there are any files matching the current pattern
                break  # Exit the loop as soon as we find a matching file
        if model_paths:  # If there are any files matching any of the patterns
            return model_paths[0]
        else:
            return None  # Return None if no matching files were found
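
    # Illustration (comment only, not executed): for a folder containing both
    # "model.safetensors" and "model.bin", this returns the .safetensors path,
    # since the patterns are tried in order of preference and the search stops
    # at the first pattern that matches anything.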

    @staticmethod
    def configure_object(params, values, logfunc):
        obj_params = {k: values.get(k) for k in params}

        def apply_to(obj):
            for key, value in obj_params.items():
                if value:
                    if hasattr(obj, key):
                        setattr(obj, key, value)
                        logfunc(f"{key} {value}")
                    else:
                        raise AttributeError(f"{key} does not exist in {obj}")

        return apply_to
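
    # Illustration (comment only, not executed; names are hypothetical):
    #   apply_gen = Exllama.configure_object(["temperature", "top_k"],
    #                                        {"temperature": 0.7, "top_k": 40}, print)
    #   apply_gen(generator.settings)
    # copies both values onto generator.settings and logs each assignment.
    # Note the `if value:` check also skips falsy values such as 0, 0.0 and False.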

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        model_param_names = [
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "typical",
            "token_repetition_penalty_max",
            "token_repetition_penalty_sustain",
            "token_repetition_penalty_decay",
            "beams",
            "beam_length",
        ]

        config_param_names = [
            "max_seq_len",
            "compress_pos_emb",
            "gpu_peer_fix",
            "alpha_value"
        ]

        tuning_parameters = [
            "matmul_recons_thd",
            "fused_mlp_thd",
            "sdp_thd",
            "matmul_fused_remap",
            "rmsnorm_no_half2",
            "rope_no_half2",
            "matmul_no_half2",
            "silu_no_half2",
            "concurrent_streams",
            "fused_attn",
        ]

        ## Keep the provided logging function when verbose, otherwise replace it with a no-op
        verbose = values['verbose']
        if not verbose:
            values['logfunc'] = lambda *args, **kwargs: None
        logfunc = values['logfunc']

        if values['model'] is None:
            model_path = values["model_path"]
            lora_path = values["lora_path"]

            tokenizer_path = os.path.join(model_path, "tokenizer.model")
            model_config_path = os.path.join(model_path, "config.json")
            model_path = Exllama.get_model_path_at(model_path)

            config = ExLlamaConfig(model_config_path)
            tokenizer = ExLlamaTokenizer(tokenizer_path)
            config.model_path = model_path

            configure_config = Exllama.configure_object(config_param_names, values, logfunc)
            configure_config(config)
            configure_tuning = Exllama.configure_object(tuning_parameters, values, logfunc)
            configure_tuning(config)

            ## Special case: set_auto_map is a method on the config, not a plain attribute
            if values['set_auto_map']:
                config.set_auto_map(values['set_auto_map'])
                logfunc(f"set_auto_map {values['set_auto_map']}")

            model = ExLlama(config)
            exllama_cache = ExLlamaCache(model)
            generator = ExLlamaGenerator(model, tokenizer, exllama_cache)

            ## Load the LoRA and attach it to the generator
            if lora_path is not None:
                lora_config_path = os.path.join(lora_path, "adapter_config.json")
                lora_path = Exllama.get_model_path_at(lora_path)
                lora = ExLlamaLora(model, lora_config_path, lora_path)
                generator.lora = lora
                logfunc(f"Loaded LORA @ {lora_path}")
        else:
            generator = values['model']
            exllama_cache = generator.cache
            model = generator.model
            config = model.config
            tokenizer = generator.tokenizer

        # Apply generation-time parameters whether the model was passed in or loaded above
        configure_model = Exllama.configure_object(model_param_names, values, logfunc)
        values["stop_sequences"] = [x.strip().lower() for x in values["stop_sequences"]]
        configure_model(generator.settings)

        setattr(generator.settings, "stop_sequences", values["stop_sequences"])
        logfunc(f"stop_sequences {values['stop_sequences']}")

        disallowed = values.get("disallowed_tokens")
        if disallowed:
            generator.disallow_tokens(disallowed)
            print(f"Disallowed Tokens: {generator.disallowed_tokens}")

        values["client"] = model
        values["generator"] = generator
        values["config"] = config
        values["tokenizer"] = tokenizer
        values["exllama_cache"] = exllama_cache

        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Exllama"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        return self.generator.tokenizer.num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        return self.generator.tokenizer.encode(text)
        # avoid base method that is not aware of how to properly tokenize (uses GPT2)
        # return _get_token_ids_default_method(text)

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        assert self.tokenizer is not None
        from h2oai_pipeline import H2OTextGenerationPipeline
        prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)

        # NOTE: exllama does not apply any prompt template itself, so build the full prompt here
        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
        prompt = self.prompter.generate_prompt(data_point)

        text = ''
        for text1 in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
            text = text1
        return text

    class MatchStatus(Enum):
        EXACT_MATCH = 1
        PARTIAL_MATCH = 0
        NO_MATCH = 2

    def match_status(self, sequence: str, banned_sequences: List[str]):
        sequence = sequence.strip().lower()
        for banned_seq in banned_sequences:
            if banned_seq == sequence:
                return self.MatchStatus.EXACT_MATCH
            elif banned_seq.startswith(sequence):
                return self.MatchStatus.PARTIAL_MATCH
        return self.MatchStatus.NO_MATCH
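
    # Illustration (comment only, not executed): with banned_sequences=["### human:"],
    #   match_status("### Human:", ...) -> EXACT_MATCH    (rewind the generator and stop)
    #   match_status("### hum", ...)    -> PARTIAL_MATCH  (keep buffering, do not yield)
    #   match_status("hello", ...)      -> NO_MATCH       (safe to yield)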

    def stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> Iterator[str]:
        config = self.config
        generator = self.generator
        beam_search = (self.beams and self.beams >= 1 and self.beam_length and self.beam_length >= 1)

        ids = generator.tokenizer.encode(prompt)
        generator.gen_begin_reuse(ids)

        if beam_search:
            generator.begin_beam_search()
            token_getter = generator.beam_search
        else:
            generator.end_beam_search()
            token_getter = generator.gen_single_token

        last_newline_pos = 0
        seq_length = len(generator.tokenizer.decode(generator.sequence_actual[0]))
        response_start = seq_length
        cursor_head = response_start

        text_callback = None
        if run_manager:
            text_callback = partial(
                run_manager.on_llm_new_token, verbose=self.verbose
            )
        # Always emit only newly generated text so various langchain consumers work as expected.
        # The earlier behavior of sending the prompt to the callback first is disabled
        # (the parent handler of the streamer expected to see the prompt first, else output=""
        # and it is lost if prompt=None in the prompter):
        # if text_callback:
        #     text_callback(prompt)
        text = ""
        while (generator.gen_num_tokens() <= (
                self.max_seq_len - 4)):  # Slight extra padding space as we seem to occasionally get a few more than 1-2 tokens
            # Fetch a token
            token = token_getter()

            # If it's the ending token replace it and end the generation.
            if token.item() == generator.tokenizer.eos_token_id:
                generator.replace_last_token(generator.tokenizer.newline_token_id)
                if beam_search:
                    generator.end_beam_search()
                return

            # Decode the string from the last newline; we can't just decode the last token due to how sentencepiece decodes.
            stuff = generator.tokenizer.decode(generator.sequence_actual[0][last_newline_pos:])
            cursor_tail = len(stuff)
            has_unicode_combined = cursor_tail < cursor_head
            text_chunk = stuff[cursor_head:cursor_tail]
            if has_unicode_combined:
                # replace the broken unicode character with the combined one
                text = text[:-2]
                text_chunk = stuff[cursor_tail - 1:cursor_tail]

            cursor_head = cursor_tail

            # Append the generated chunk to our stream buffer
            text += text_chunk
            text = self.prompter.get_response(prompt + text, prompt=prompt,
                                              sanitize_bot_response=self.sanitize_bot_response)

            if token.item() == generator.tokenizer.newline_token_id:
                last_newline_pos = len(generator.sequence_actual[0])
                cursor_head = 0
                cursor_tail = 0

            # Check if the stream buffer is one of the stop sequences
            status = self.match_status(text, self.stop_sequences)

            if status == self.MatchStatus.EXACT_MATCH:
                # Encountered a stop, rewind our generator to before we hit the match and end generation.
                rewind_length = generator.tokenizer.encode(text).shape[-1]
                generator.gen_rewind(rewind_length)
                #gen = generator.tokenizer.decode(generator.sequence_actual[0][response_start:])
                if beam_search:
                    generator.end_beam_search()
                return
            elif status == self.MatchStatus.PARTIAL_MATCH:
                # Partially matched a stop, continue buffering but don't yield.
                continue
            elif status == self.MatchStatus.NO_MATCH:
                if text_callback and not (text_chunk == BROKEN_UNICODE):
                    text_callback(text_chunk)
                yield text  # Not a stop, yield the match buffer.

        return
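

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The model directory and sampling
# values below are hypothetical placeholders; a real GPTQ model folder (with
# tokenizer.model and config.json) is required, and _call() additionally needs
# a prompter object from the surrounding project.
#
#   llm = Exllama(
#       model_path="/path/to/gptq-model-folder",  # hypothetical path
#       temperature=0.7,
#       top_p=0.9,
#       stop_sequences=["### Human:"],
#       streaming=True,
#       verbose=True,
#   )
#   print(llm.get_num_tokens("How many tokens is this?"))
# ---------------------------------------------------------------------------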