from typing import Callable, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load
from mlx_lm.utils import generate_step
from transformers import PreTrainedTokenizer

from .base_engine import BaseEngine
from ..configs import (
    MODEL_PATH,
)

# Partially decoded multi-byte characters render as U+FFFD; strip them
# before stop-string checks and from the returned text.
REPLACEMENT_CHAR = "\ufffd"


def generate_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Tuple[str, ...]] = None,
) -> str:
    """Generate a complete string from the model (non-streaming).

    ``verbose`` and ``formatter`` are accepted only for signature parity
    with ``generate_yield_string``.
    """
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    # str.endswith below requires a tuple, so coerce lists and other iterables.
    if stop_strings is not None and not isinstance(stop_strings, tuple):
        stop_strings = tuple(stop_strings)

    tokens = []
    for (token, prob), n in zip(
        generate_step(
            prompt_tokens,
            model,
            temp,
            repetition_penalty,
            repetition_context_size,
        ),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        if stop_strings is not None:
            token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
            if token_string.strip().endswith(stop_strings):
                break

    return tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
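
# --- Illustrative usage (a minimal sketch, not part of the engine API). ---
# Assumes MODEL_PATH points at a local MLX model directory or a Hugging Face
# repo id that mlx_lm.load can resolve; the prompt and stop string are
# placeholders for demonstration only.
def _example_generate_string() -> str:
    model, tokenizer = load(MODEL_PATH)
    return generate_string(
        model,
        tokenizer,
        prompt="Q: What is the capital of France?\nA:",
        temp=0.0,
        max_tokens=64,
        # Decoding halts once the decoded text ends with "Q:", which keeps
        # the model from inventing a follow-up question. Note the matched
        # stop string itself remains in the returned text.
        stop_strings=("Q:",),
    )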
""" if verbose: print("=" * 10) print("Prompt:", prompt) stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings) assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}' prompt_tokens = mx.array(tokenizer.encode(prompt)) tic = time.perf_counter() tokens = [] skip = 0 REPLACEMENT_CHAR = "\ufffd" for (token, prob), n in zip( generate_step( prompt_tokens, model, temp, repetition_penalty, repetition_context_size, ), range(max_tokens), ): if token == tokenizer.eos_token_id: break # if n == 0: # prompt_time = time.perf_counter() - tic # tic = time.perf_counter() tokens.append(token.item()) # if verbose: # s = tokenizer.decode(tokens) # if formatter: # formatter(s[skip:], prob.item()) # skip = len(s) # elif REPLACEMENT_CHAR not in s: # print(s[skip:], end="", flush=True) # skip = len(s) token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "") yield token_string if stop_strings is not None and token_string.strip().endswith(stop_strings): break # token_count = len(tokens) # token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "") # if verbose: # print(token_string[skip:], flush=True) # gen_time = time.perf_counter() - tic # print("=" * 10) # if token_count == 0: # print("No tokens generated for this prompt") # return # prompt_tps = prompt_tokens.size / prompt_time # gen_tps = (token_count - 1) / gen_time # print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") # print(f"Generation: {gen_tps:.3f} tokens-per-sec") # return token_string class MlxEngine(BaseEngine): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self._model = None self._tokenizer = None @property def tokenizer(self) -> PreTrainedTokenizer: return self._tokenizer def load_model(self, ): model_path = MODEL_PATH self._model, self._tokenizer = load(model_path) self.model_path = model_path print(f'Load MLX model from {model_path}') def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs): num_tokens = len(self.tokenizer.encode(prompt)) response = None for response in generate_yield_string( self._model, self._tokenizer, prompt, temp=temperature, max_tokens=max_tokens, repetition_penalty=kwargs.get("repetition_penalty", None), stop_strings=stop_strings, ): yield response, num_tokens if response is not None: full_text = prompt + response num_tokens = len(self.tokenizer.encode(full_text)) yield response, num_tokens def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs): """ ! MLX does not support """ responses = [ generate_string( self._model, self._tokenizer, s, temp=temperature, max_tokens=max_tokens, repetition_penalty=kwargs.get("repetition_penalty", None), stop_strings=stop_strings, ) for s in prompts ] return responses