from typing import Callable, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load
from mlx_lm.utils import generate_step
from transformers import PreTrainedTokenizer

from .base_engine import BaseEngine
from ..configs import MODEL_PATH


def generate_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Tuple[str, ...]] = None,
):
    """
    Generate a complete string from the model (non-streaming).

    ``verbose`` and ``formatter`` are accepted for interface parity with
    ``generate_yield_string`` but are unused here.
    """
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    # Normalize stop_strings to a tuple so str.endswith can take it directly.
    stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
    assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'

    tokens = []
    # U+FFFD marks a partially decoded multi-byte character; strip it so
    # incomplete tokens never corrupt the returned text.
    REPLACEMENT_CHAR = "\ufffd"

    for (token, prob), n in zip(
        generate_step(
            prompt_tokens,
            model,
            temp,
            repetition_penalty,
            repetition_context_size,
        ),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        if stop_strings is not None:
            token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
            if token_string.strip().endswith(stop_strings):
                break
    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
    return token_string
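
# Example (sketch, not executed on import): calling generate_string directly
# with a model/tokenizer pair from mlx_lm.load. The model id below is
# illustrative, not prescribed by this module.
#
#   model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
#   text = generate_string(model, tokenizer, "Q: What is MLX?\nA:",
#                          temp=0.0, max_tokens=64, stop_strings=("\n",))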


def generate_yield_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Tuple[str, ...]] = None,
):
    """
    Generate text from the model, yielding the cumulative decoded string
    after each step.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
        max_tokens (int): The maximum number of tokens (default 100).
        verbose (bool): If ``True``, print the prompt before generating
            (default ``False``).
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it. Accepted for interface parity;
            unused here.
        repetition_penalty (float, optional): The penalty factor for
            repeating tokens.
        repetition_context_size (int, optional): The number of tokens to
            consider for the repetition penalty.
        stop_strings (Tuple[str, ...], optional): Stop once the decoded
            output ends with any of these strings.
    """
    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)
    stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
    assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    tokens = []
    REPLACEMENT_CHAR = "\ufffd"
    for (token, prob), n in zip(
        generate_step(
            prompt_tokens,
            model,
            temp,
            repetition_penalty,
            repetition_context_size,
        ),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        # Re-decode the full sequence each step so partially decoded
        # multi-byte characters (marked with U+FFFD) are dropped.
        token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
        yield token_string
        # The stop check runs after the yield, so the final chunk that
        # contains the stop string is still delivered to the caller.
        if stop_strings is not None and token_string.strip().endswith(stop_strings):
            break
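
# Example (sketch): streaming consumption. Each yielded value is the
# cumulative string, so a caller that wants incremental chunks prints only
# the delta since the last yield.
#
#   printed = 0
#   for text in generate_yield_string(model, tokenizer, prompt, max_tokens=128):
#       print(text[printed:], end="", flush=True)
#       printed = len(text)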


class MlxEngine(BaseEngine):

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._model = None
        self._tokenizer = None

    @property
    def tokenizer(self) -> PreTrainedTokenizer:
        return self._tokenizer

    def load_model(self):
        model_path = MODEL_PATH
        self._model, self._tokenizer = load(model_path)
        self.model_path = model_path
        print(f'Loaded MLX model from {model_path}')

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str, ...]] = None, **kwargs):
        num_tokens = len(self.tokenizer.encode(prompt))
        response = None
        for response in generate_yield_string(
            self._model, self._tokenizer,
            prompt, temp=temperature, max_tokens=max_tokens,
            repetition_penalty=kwargs.get("repetition_penalty", None),
            stop_strings=stop_strings,
        ):
            yield response, num_tokens
        if response is not None:
            # Yield one final time with num_tokens updated to count the
            # prompt plus the completed response.
            full_text = prompt + response
            num_tokens = len(self.tokenizer.encode(full_text))
            yield response, num_tokens

    def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str, ...]] = None, **kwargs):
        """
        MLX does not support batched generation here, so the prompts are
        generated sequentially.
        """
        responses = [
            generate_string(
                self._model, self._tokenizer,
                s, temp=temperature, max_tokens=max_tokens,
                repetition_penalty=kwargs.get("repetition_penalty", None),
                stop_strings=stop_strings,
            )
            for s in prompts
        ]
        return responses
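

if __name__ == "__main__":
    # Minimal smoke test (sketch): assumes MODEL_PATH in ..configs points at a
    # valid MLX model and that BaseEngine requires no constructor arguments.
    # Because this module uses relative imports, run it as a package module,
    # e.g. `python -m <package>.<this_module>`.
    engine = MlxEngine()
    engine.load_model()
    final_text, final_tokens = "", 0
    for final_text, final_tokens in engine.generate_yield_string(
        "Hello, my name is", temperature=0.0, max_tokens=32
    ):
        pass
    print(final_text)
    print(f"tokens (prompt + completion): {final_tokens}")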