Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import os | |
import sys | |
import uuid | |
import time | |
import json | |
import ctypes | |
import fnmatch | |
import multiprocessing | |
from typing import ( | |
List, | |
Optional, | |
Union, | |
Generator, | |
Sequence, | |
Iterator, | |
Deque, | |
Callable, | |
Dict, | |
) | |
from collections import deque | |
from pathlib import Path | |
from llama_cpp.llama_types import List | |
from .llama_types import * | |
from .llama_grammar import LlamaGrammar | |
from .llama_cache import ( | |
BaseLlamaCache, | |
LlamaCache, # type: ignore | |
LlamaDiskCache, # type: ignore | |
LlamaRAMCache, # type: ignore | |
) | |
from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer | |
import llama_cpp.llama_cpp as llama_cpp | |
import llama_cpp.llama_chat_format as llama_chat_format | |
from llama_cpp.llama_speculative import LlamaDraftModel | |
import numpy as np | |
import numpy.typing as npt | |
from ._internals import ( | |
_LlamaModel, # type: ignore | |
_LlamaContext, # type: ignore | |
_LlamaBatch, # type: ignore | |
_LlamaTokenDataArray, # type: ignore | |
_LlamaSamplingParams, # type: ignore | |
_LlamaSamplingContext, # type: ignore | |
_normalize_embedding, # type: ignore | |
) | |
from ._logger import set_verbose | |
from ._utils import suppress_stdout_stderr | |
class Llama: | |
"""High-level Python wrapper for a llama.cpp model.""" | |
__backend_initialized = False | |
def __init__(
    self,
    model_path: str,
    *,
    # Model Params
    n_gpu_layers: int = 0,
    split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
    main_gpu: int = 0,
    tensor_split: Optional[List[float]] = None,
    vocab_only: bool = False,
    use_mmap: bool = True,
    use_mlock: bool = False,
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
    # Context Params
    seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
    n_ctx: int = 512,
    n_batch: int = 512,
    n_threads: Optional[int] = None,
    n_threads_batch: Optional[int] = None,
    rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
    pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
    rope_freq_base: float = 0.0,
    rope_freq_scale: float = 0.0,
    yarn_ext_factor: float = -1.0,
    yarn_attn_factor: float = 1.0,
    yarn_beta_fast: float = 32.0,
    yarn_beta_slow: float = 1.0,
    yarn_orig_ctx: int = 0,
    logits_all: bool = False,
    embedding: bool = False,
    offload_kqv: bool = True,
    flash_attn: bool = False,
    # Sampling Params
    last_n_tokens_size: int = 64,
    # LoRA Params
    lora_base: Optional[str] = None,
    lora_scale: float = 1.0,
    lora_path: Optional[str] = None,
    # Backend Params
    numa: Union[bool, int] = False,
    # Chat Format Params
    chat_format: Optional[str] = None,
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
    # Speculative Decoding
    draft_model: Optional[LlamaDraftModel] = None,
    # Tokenizer Override
    tokenizer: Optional[BaseLlamaTokenizer] = None,
    # KV cache quantization
    type_k: Optional[int] = None,
    type_v: Optional[int] = None,
    # Misc
    verbose: bool = True,
    # Extra Params
    **kwargs,  # type: ignore
):
    """Load a llama.cpp model from `model_path`.

    Examples:
        Basic usage

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ... )
        >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
        the lazy dog

        Loading a chat model

        >>> import llama_cpp
        >>> model = llama_cpp.Llama(
        ...     model_path="path/to/model",
        ...     chat_format="llama-2",
        ... )
        >>> print(model.create_chat_completion(
        ...     messages=[{
        ...         "role": "user",
        ...         "content": "what is the meaning of life?"
        ...     }]
        ... ))

    Args:
        model_path: Path to the model.
        n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
        split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
        main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
        tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
        vocab_only: Only load the vocabulary no weights.
        use_mmap: Use mmap if possible. Ignored (forced off) when `lora_path` is given.
        use_mlock: Force the system to keep the model in RAM.
        kv_overrides: Key-value overrides for the model (bool, int, float or str values).
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size (capped at the context size)
        n_threads: Number of threads to use for generation (default: half the CPUs, at least 1)
        n_threads_batch: Number of threads to use for batch processing (default: all CPUs)
        rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
        pooling_type: Pooling type, from `enum llama_pooling_type`.
        rope_freq_base: RoPE base frequency, 0 = from model
        rope_freq_scale: RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
        yarn_attn_factor: YaRN magnitude scaling factor
        yarn_beta_fast: YaRN low correction dim
        yarn_beta_slow: YaRN high correction dim
        yarn_orig_ctx: YaRN original context size
        logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. Forced to True when `draft_model` is given.
        embedding: Embedding mode only.
        offload_kqv: Offload K, Q, V to GPU.
        flash_attn: Use flash attention.
        last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
        lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
        lora_scale: Scale passed along with the LoRA adapter when it is applied.
        lora_path: Path to a LoRA file to apply to the model.
        numa: numa policy; True/False select distribute/disabled, an int selects a strategy directly.
        chat_format: String specifying the chat format to use when calling create_chat_completion.
        chat_handler: Optional chat handler to use when calling create_chat_completion.
        draft_model: Optional draft model to use for speculative decoding.
        tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
        verbose: Print verbose output to stderr.
        type_k: KV cache data type for K (default: f16)
        type_v: KV cache data type for V (default: f16)

    Raises:
        ValueError: If the model path does not exist, if `tensor_split` has more
            entries than LLAMA_MAX_DEVICES, or if a `kv_overrides` value is an
            unsupported type or a string longer than 128 bytes.
        RuntimeError: If applying the LoRA adapter fails.
    """
    self.verbose = verbose

    set_verbose(verbose)

    # The flag lives on the class, so the backend is initialised only once,
    # by the first Llama instance.
    if not Llama.__backend_initialized:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_backend_init()
        Llama.__backend_initialized = True

    # A bool selects distribute/disabled; anything else is taken as a strategy value.
    if isinstance(numa, bool):
        self.numa = (
            llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
            if numa
            else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
        )
    else:
        self.numa = numa

    if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
        with suppress_stdout_stderr(disable=verbose):
            llama_cpp.llama_numa_init(self.numa)

    self.model_path = model_path

    # Model Params
    self.model_params = llama_cpp.llama_model_default_params()
    self.model_params.n_gpu_layers = (
        0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
    )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
    self.model_params.split_mode = split_mode
    self.model_params.main_gpu = main_gpu
    self.tensor_split = tensor_split
    self._c_tensor_split = None
    if self.tensor_split is not None:
        if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
            raise ValueError(
                f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
            )
        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
        self._c_tensor_split = FloatArray(
            *tensor_split  # type: ignore
        )  # keep a reference to the array so it is not gc'd
        self.model_params.tensor_split = self._c_tensor_split
    self.model_params.vocab_only = vocab_only
    self.model_params.use_mmap = use_mmap if lora_path is None else False
    self.model_params.use_mlock = use_mlock

    # kv_overrides is the original python dict
    self.kv_overrides = kv_overrides
    if kv_overrides is not None:
        # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs,
        # with one extra trailing sentinel element.
        kvo_array_len = len(kv_overrides) + 1  # for sentinel element
        self._kv_overrides_array = (
            llama_cpp.llama_model_kv_override * kvo_array_len
        )()

        for i, (k, v) in enumerate(kv_overrides.items()):
            self._kv_overrides_array[i].key = k.encode("utf-8")
            # bool must be checked before int: bool is a subclass of int.
            if isinstance(v, bool):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                self._kv_overrides_array[i].value.bool_value = v
            elif isinstance(v, int):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                self._kv_overrides_array[i].value.int_value = v
            elif isinstance(v, float):
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                self._kv_overrides_array[i].value.float_value = v
            elif isinstance(v, str):  # type: ignore
                v_bytes = v.encode("utf-8")
                # NOTE(review): a value of exactly 128 bytes is accepted but would
                # leave no NUL terminator if str_value is a char[128] — confirm
                # against llama.h.
                if len(v_bytes) > 128:  # TODO: Make this a constant
                    raise ValueError(f"Value for {k} is too long: {v}")
                v_bytes = v_bytes.ljust(128, b"\0")
                self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                # copy min(v_bytes, 128) to str_value
                ctypes.memmove(
                    self._kv_overrides_array[i].value.str_value,
                    v_bytes,
                    min(len(v_bytes), 128),
                )
            else:
                raise ValueError(f"Unknown value type for {k}: {v}")

        self._kv_overrides_array[-1].key = (
            b"\0"  # ensure sentinel element is zeroed
        )
        self.model_params.kv_overrides = self._kv_overrides_array

    # A batch may not be larger than the context; recomputed below when
    # n_ctx == 0 is resolved to the model's training context.
    self.n_batch = min(n_ctx, n_batch)
    self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

    # Context Params
    self.context_params = llama_cpp.llama_context_default_params()
    self.context_params.seed = seed
    self.context_params.n_ctx = n_ctx
    self.context_params.n_batch = self.n_batch
    self.context_params.n_threads = self.n_threads
    self.context_params.n_threads_batch = self.n_threads_batch
    self.context_params.rope_scaling_type = (
        rope_scaling_type
        if rope_scaling_type is not None
        else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    self.context_params.pooling_type = pooling_type
    # RoPE / YaRN values are passed through unchanged; per the docstring,
    # 0 (or negative for yarn_ext_factor) means "take the value from the model".
    self.context_params.rope_freq_base = rope_freq_base
    self.context_params.rope_freq_scale = rope_freq_scale
    self.context_params.yarn_ext_factor = yarn_ext_factor
    self.context_params.yarn_attn_factor = yarn_attn_factor
    self.context_params.yarn_beta_fast = yarn_beta_fast
    self.context_params.yarn_beta_slow = yarn_beta_slow
    self.context_params.yarn_orig_ctx = yarn_orig_ctx
    self.context_params.logits_all = (
        logits_all if draft_model is None else True
    )  # Must be set to True for speculative decoding
    self.context_params.embeddings = embedding  # TODO: Rename to embeddings
    self.context_params.offload_kqv = offload_kqv
    self.context_params.flash_attn = flash_attn

    # KV cache quantization; llama.cpp defaults are kept when not given.
    if type_k is not None:
        self.context_params.type_k = type_k
    if type_v is not None:
        self.context_params.type_v = type_v

    # Sampling Params
    self.last_n_tokens_size = last_n_tokens_size

    self.cache: Optional[BaseLlamaCache] = None

    self.lora_base = lora_base
    self.lora_scale = lora_scale
    self.lora_path = lora_path

    if not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist: {model_path}")

    self._model = _LlamaModel(
        path_model=self.model_path, params=self.model_params, verbose=self.verbose
    )

    # Override tokenizer
    self.tokenizer_ = tokenizer or LlamaTokenizer(self)

    # n_ctx == 0 means "use the model's training context": resolve it now that
    # the model is loaded and shrink the batch to fit.
    if n_ctx == 0:
        n_ctx = self._model.n_ctx_train()
        self.n_batch = min(n_ctx, n_batch)
        self.context_params.n_ctx = n_ctx
        self.context_params.n_batch = self.n_batch

    self._ctx = _LlamaContext(
        model=self._model,
        params=self.context_params,
        verbose=self.verbose,
    )

    self._batch = _LlamaBatch(
        n_tokens=self.n_batch,
        embd=0,
        n_seq_max=self.context_params.n_ctx,
        verbose=self.verbose,
    )

    if self.lora_path:
        # A truthy return value from apply_lora_from_file signals failure.
        if self._model.apply_lora_from_file(
            self.lora_path,
            self.lora_scale,
            self.lora_base,
            self.n_threads,
        ):
            raise RuntimeError(
                f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
            )

    if self.verbose:
        print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

    self.chat_format = chat_format
    self.chat_handler = chat_handler

    self.draft_model = draft_model

    self._n_vocab = self.n_vocab()
    self._n_ctx = self.n_ctx()

    self._token_nl = self.token_nl()
    self._token_eos = self.token_eos()

    self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab)

    # Token ids and logits per evaluated position; filled in by eval().
    self.n_tokens = 0
    self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
    self.scores: npt.NDArray[np.single] = np.ndarray(
        (n_ctx, self._n_vocab), dtype=np.single
    )

    self._mirostat_mu = ctypes.c_float(
        2.0 * 5.0
    )  # TODO: Move this to sampling context

    try:
        self.metadata = self._model.metadata()
    except Exception as e:
        # Metadata is optional: continue without it, but report why.
        self.metadata = {}
        if self.verbose:
            print(f"Failed to load metadata: {e}", file=sys.stderr)

    if self.verbose:
        print(f"Model metadata: {self.metadata}", file=sys.stderr)

    # Without an explicit format/handler, try to derive one from the GGUF
    # chat template: first a known format name, else a Jinja2 handler.
    if (
        self.chat_format is None
        and self.chat_handler is None
        and "tokenizer.chat_template" in self.metadata
    ):
        chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
            self.metadata
        )

        if chat_format is not None:
            self.chat_format = chat_format
            if self.verbose:
                print(f"Guessed chat format: {chat_format}", file=sys.stderr)
        else:
            template = self.metadata["tokenizer.chat_template"]
            # Fall back to the model's own special tokens when the metadata
            # entry is missing or not an integer.
            try:
                eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
            except (KeyError, ValueError, TypeError):
                eos_token_id = self.token_eos()
            try:
                bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
            except (KeyError, ValueError, TypeError):
                bos_token_id = self.token_bos()

            eos_token = self._model.token_get_text(eos_token_id)
            bos_token = self._model.token_get_text(bos_token_id)

            if self.verbose:
                print(f"Using gguf chat template: {template}", file=sys.stderr)
                print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

            self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
                template=template,
                eos_token=eos_token,
                bos_token=bos_token,
                stop_token_ids=[eos_token_id],
            ).to_chat_handler()

    if self.chat_format is None and self.chat_handler is None:
        self.chat_format = "llama-2"
        if self.verbose:
            # Report the format actually selected (the local chat_format is None here).
            print(f"Using fallback chat format: {self.chat_format}", file=sys.stderr)
@property
def ctx(self) -> llama_cpp.llama_context_p:
    """Raw llama.cpp context handle wrapped by ``self._ctx``.

    A property because callers read ``self.ctx`` as a value (for example it is
    passed directly to ``llama_cpp.llama_n_ctx``).
    """
    assert self._ctx.ctx is not None
    return self._ctx.ctx
@property
def model(self) -> llama_cpp.llama_model_p:
    """Raw llama.cpp model handle wrapped by ``self._model``.

    A property, matching how the sibling ``ctx`` handle accessor is read as a value.
    """
    assert self._model.model is not None
    return self._model.model
def _input_ids(self) -> npt.NDArray[np.intc]: | |
return self.input_ids[: self.n_tokens] | |
def _scores(self) -> npt.NDArray[np.single]: | |
return self.scores[: self.n_tokens, :] | |
@property
def eval_tokens(self) -> Deque[int]:
    """Evaluated token ids as a deque bounded by the context size ``self._n_ctx``.

    A property because callers read it as a value (``list(self.eval_tokens)``).
    """
    return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
@property
def eval_logits(self) -> Deque[List[float]]:
    """Stored logits rows as a deque of lists.

    The deque holds every evaluated row, bounded by ``self._n_ctx``, when
    ``context_params.logits_all`` is set. Otherwise it keeps only the last row,
    since only that row is written by ``eval``. Exposed as a property to match
    ``eval_tokens``.
    """
    return deque(
        self.scores[: self.n_tokens, :].tolist(),
        maxlen=self._n_ctx if self.context_params.logits_all else 1,
    )
def tokenize(
    self, text: bytes, add_bos: bool = True, special: bool = False
) -> List[int]:
    """Convert UTF-8 encoded bytes into token ids.

    The work is delegated to ``self.tokenizer_``. Its arguments are forwarded
    positionally as ``(text, add_bos, special)``.

    Args:
        text: The utf-8 encoded string to tokenize.
        add_bos: Forwarded to the tokenizer (presumably: prepend the BOS token).
        special: Forwarded to the tokenizer (presumably: parse special tokens).

    Raises:
        RuntimeError: If the tokenization failed (raised by the tokenizer).

    Returns:
        A list of tokens.
    """
    backend = self.tokenizer_
    return backend.tokenize(text, add_bos, special)
def detokenize(
    self, tokens: List[int], prev_tokens: Optional[List[int]] = None
) -> bytes:
    """Convert token ids back into bytes via ``self.tokenizer_``.

    Args:
        tokens: The list of tokens to detokenize.
        prev_tokens: Tokens that precede ``tokens``. When provided, the tokenizer
            performs offset mapping relative to them.

    Returns:
        The detokenized string, as bytes.
    """
    backend = self.tokenizer_
    return backend.detokenize(tokens, prev_tokens=prev_tokens)
def set_cache(self, cache: Optional[BaseLlamaCache]):
    """Install (or, with ``None``, remove) the prompt/state cache.

    Args:
        cache: The cache object to store on ``self.cache``, or None to disable caching.
    """
    self.cache = cache
def set_seed(self, seed: int):
    """Reseed the context's random number generator.

    Args:
        seed: The random seed, passed to ``llama_cpp.llama_set_rng_seed``.
    """
    raw_ctx = self._ctx.ctx
    assert raw_ctx is not None
    llama_cpp.llama_set_rng_seed(raw_ctx, seed)
def reset(self):
    """Mark the model as having evaluated no tokens.

    Only the token counter is cleared; the buffers and KV cache are left as-is
    and are overwritten or trimmed by the next ``eval``.
    """
    self.n_tokens = 0
def eval(self, tokens: Sequence[int]):
    """Feed ``tokens`` through the model after the currently evaluated ones.

    KV-cache entries at positions ``>= self.n_tokens`` are dropped first. The
    tokens are then decoded in chunks of at most ``self.n_batch``. For every
    chunk, the token ids are stored in ``self.input_ids``. The logits are copied
    into ``self.scores``: every row of the chunk when
    ``context_params.logits_all`` is set, otherwise only the chunk's last row.

    Args:
        tokens: The list of tokens to evaluate.
    """
    assert self._ctx.ctx is not None
    assert self._batch.batch is not None
    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
    total = len(tokens)
    for start in range(0, total, self.n_batch):
        chunk = tokens[start : min(total, start + self.n_batch)]
        first_pos = self.n_tokens
        chunk_len = len(chunk)
        end_pos = first_pos + chunk_len
        want_all = self.context_params.logits_all

        self._batch.set_batch(batch=chunk, n_past=first_pos, logits_all=want_all)
        self._ctx.decode(self._batch)

        # Record the evaluated token ids.
        self.input_ids[first_pos:end_pos] = chunk

        # Copy the flat logits buffer into the matching trailing rows of scores.
        row_count = chunk_len if want_all else 1
        flat_logits = self._ctx.get_logits()[: row_count * self._n_vocab]
        self.scores[end_pos - row_count : end_pos, :].reshape(-1)[:] = flat_logits

        self.n_tokens += chunk_len
def sample(
    self,
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_eta: float = 0.1,
    mirostat_tau: float = 5.0,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    idx: Optional[int] = None,
):
    """Sample one token id from the stored logits.

    A fresh sampling context is built on every call from the given parameters.
    It is seeded with the evaluated tokens (``self.eval_tokens``) as history.

    Args:
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        min_p: The min-p sampling parameter.
        typical_p: The typical-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter (over the last
            ``self.last_n_tokens_size`` tokens).
        frequency_penalty: Frequency penalty.
        presence_penalty: Presence penalty.
        tfs_z: Tail-free sampling parameter.
        mirostat_mode: Mirostat mode forwarded to the sampling params.
        mirostat_eta: Mirostat learning rate.
        mirostat_tau: Mirostat target value.
        penalize_nl: Whether the newline token is subject to penalties.
        logits_processor: Optional callable given ``(token_history, logits)``. Its
            result overwrites the selected logits row in place, which also
            updates ``self.scores``.
        grammar: Optional grammar. When given, the sampled token is accepted
            with grammar application enabled.
        idx: Row of the stored logits to sample from. None means the last
            evaluated position. With an explicit idx, the processor sees only
            the tokens up to and including ``idx``.

    Returns:
        The sampled token.
    """
    assert self._ctx is not None
    assert self.n_tokens > 0

    row = -1 if idx is None else idx
    logits: npt.NDArray[np.single] = self._scores[row, :]

    if logits_processor is not None:
        history = self._input_ids if idx is None else self._input_ids[: idx + 1]
        logits[:] = logits_processor(history, logits)

    params = _LlamaSamplingParams(
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        tfs_z=tfs_z,
        typical_p=typical_p,
        temp=temp,
        penalty_last_n=self.last_n_tokens_size,
        penalty_repeat=repeat_penalty,
        penalty_freq=frequency_penalty,
        penalty_present=presence_penalty,
        mirostat=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        penalize_nl=penalize_nl,
    )
    sampler = _LlamaSamplingContext(params=params, grammar=grammar)
    sampler.prev = list(self.eval_tokens)

    token_id = sampler.sample(ctx_main=self._ctx, logits_array=logits)
    sampler.accept(
        ctx_main=self._ctx,
        id=token_id,
        apply_grammar=grammar is not None,
    )
    return token_id
def generate(
    self,
    tokens: Sequence[int],
    top_k: int = 40,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    temp: float = 0.80,
    repeat_penalty: float = 1.1,
    reset: bool = True,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    penalize_nl: bool = True,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    grammar: Optional[LlamaGrammar] = None,
) -> Generator[int, Optional[Sequence[int]], None]:
    """Create a generator of tokens from a prompt.

    The generator evaluates the prompt, then repeatedly samples a token, yields
    it, and evaluates it. It never finishes on its own. It stops only when
    ``stopping_criteria`` returns true or the caller stops iterating.

    If the generator is driven with ``.send(seq)`` instead of ``next()``, the
    tokens in ``seq`` are evaluated right after the yielded token.

    When ``reset`` is True and tokens were evaluated earlier, the longest common
    prefix between the stored tokens and ``tokens[:-1]`` is reused instead of
    re-evaluating it. Excluding the last prompt token presumably guarantees that
    at least one token is evaluated to produce fresh logits.

    With ``self.draft_model`` set, the draft tokens it proposes are evaluated
    together with the sampled token. If a later sample disagrees with an
    evaluated draft token, ``n_tokens`` and the KV cache are rolled back to that
    position.

    Examples:
        >>> llama = Llama("models/ggml-7b.bin")
        >>> tokens = llama.tokenize(b"Hello, world!")
        >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
        ...     print(llama.detokenize([token]))

    Args:
        tokens: The prompt tokens.
        top_k: The top-k sampling parameter.
        top_p: The top-p sampling parameter.
        temp: The temperature parameter.
        repeat_penalty: The repeat penalty parameter.
        reset: Whether to reset the model state (subject to prefix reuse above).
        logits_processor: Forwarded to ``sample``.
        stopping_criteria: Called with ``(input_ids, last_logits_row)`` after each
            sample. A true result ends the generator before yielding that token.
        grammar: Optional grammar; it is reset at the start and passed to ``sample``.
        The remaining sampling parameters are forwarded unchanged to ``sample``.

    Yields:
        The generated tokens.
    """
    # Reset mirostat sampling
    self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)

    # Check for kv cache prefix match: count leading positions where the stored
    # tokens equal the new prompt (excluding its last token).
    if reset and self.n_tokens > 0:
        longest_prefix = 0
        for a, b in zip(self._input_ids, tokens[:-1]):
            if a == b:
                longest_prefix += 1
            else:
                break
        if longest_prefix > 0:
            if self.verbose:
                print("Llama.generate: prefix-match hit", file=sys.stderr)
            # Keep the matched prefix; only the remainder needs evaluating.
            reset = False
            tokens = tokens[longest_prefix:]
            self.n_tokens = longest_prefix

    # Reset the model state
    if reset:
        self.reset()

    # Reset the grammar
    if grammar is not None:
        grammar.reset()

    # Position whose logits are sampled first: the last prompt token.
    sample_idx = self.n_tokens + len(tokens) - 1
    tokens = list(tokens)

    # Eval and sample
    while True:
        # Evaluate pending tokens (prompt, then sampled token + any sent/draft tokens).
        self.eval(tokens)
        # Sample at every evaluated position not yet sampled.
        while sample_idx < self.n_tokens:
            token = self.sample(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
                grammar=grammar,
                penalize_nl=penalize_nl,
                idx=sample_idx,
            )

            sample_idx += 1
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids, self._scores[-1, :]
            ):
                return
            # The caller may send extra tokens to evaluate after this one.
            tokens_or_none = yield token
            tokens.clear()
            tokens.append(token)
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

            # A token was already evaluated at sample_idx (e.g. a draft token)
            # but the sample disagrees: discard it and everything after it.
            if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                self.n_tokens = sample_idx
                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
                break

        # Speculative decoding: append draft tokens proposed for the current
        # sequence, truncated so the total stays within the context size.
        if self.draft_model is not None:
            self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
            draft_tokens = self.draft_model(
                self.input_ids[: self.n_tokens + len(tokens)]
            )
            tokens.extend(
                draft_tokens.astype(int)[
                    : self._n_ctx - self.n_tokens - len(tokens)
                ]
            )
def create_embedding(
    self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:
    """Embed one or more strings and wrap the result in an OpenAI-style response.

    Args:
        input: A string or list of strings to embed.
        model: Name reported in the response; defaults to ``self.model_path``.

    Returns:
        A dict with ``object="list"``, one ``data`` entry per input (carrying
        the embedding and its index), the model name, and token usage taken
        from ``self.embed(..., return_count=True)``.
    """
    assert self._model.model is not None
    model_name: str = self.model_path if model is None else model

    # embed() always receives a list, so data holds one entry per input.
    texts = input if isinstance(input, list) else [input]
    embeds, total_tokens = self.embed(texts, return_count=True)  # type: ignore

    data: List[Embedding] = []
    for position, vector in enumerate(embeds):
        data.append({"object": "embedding", "embedding": vector, "index": position})

    usage = {"prompt_tokens": total_tokens, "total_tokens": total_tokens}
    return {"object": "list", "data": data, "model": model_name, "usage": usage}
def embed(
    self,
    input: Union[str, List[str]],
    normalize: bool = False,
    truncate: bool = True,
    return_count: bool = False,
):
    """Embed a string or a list of strings.

    Inputs are tokenized and packed into batches of at most ``self.n_batch``
    tokens, with one sequence id per input within a batch. The KV cache is
    cleared before each batch is decoded, and again at the end.

    Args:
        input: A string or list of strings; each is UTF-8 encoded and tokenized.
        normalize: If True, pass each embedding through ``_normalize_embedding``.
        truncate: If True, cut each input to ``self.n_batch`` tokens. Otherwise
            an over-long input raises ``ValueError``.
        return_count: If True, also return the total number of tokens processed.

    Returns:
        A single embedding for a ``str`` input, or a list of embeddings (one per
        input) for a list input. With ``LLAMA_POOLING_TYPE_NONE`` each embedding
        is a list of per-token vectors; otherwise it is one pooled vector per
        input. When ``return_count`` is True, ``(output, total_tokens)`` is
        returned instead.

    Raises:
        RuntimeError: If the model was not created with ``embedding=True``.
        ValueError: If ``truncate`` is False and an input exceeds ``n_batch`` tokens.
    """
    assert self._ctx.ctx is not None
    n_embd = self.n_embd()
    n_batch = self.n_batch

    # Without pooling, per-token embeddings are wanted, so every token must
    # produce an output.
    pooling_type = self.pooling_type()
    logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE

    if not self.context_params.embeddings:
        raise RuntimeError(
            "Llama model must be created with embedding=True to call this method"
        )

    if self.verbose:
        llama_cpp.llama_reset_timings(self._ctx.ctx)

    inputs = [input] if isinstance(input, str) else input

    # reset batch
    self._batch.reset()

    # decode and fetch embeddings
    data: Union[List[List[float]], List[List[List[float]]]] = []

    def decode_batch(seq_sizes: List[int]):
        """Decode the pending batch and append one embedding per sequence to ``data``."""
        assert self._ctx.ctx is not None
        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
        self._ctx.decode(self._batch)
        self._batch.reset()

        if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
            # Per-token outputs are n_embd floats each, laid out in batch order,
            # so the sequence starting at token offset `pos` begins at float
            # offset pos * n_embd (the old code added `pos` un-scaled).
            ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
            pos: int = 0  # token offset of the current sequence in the batch
            for size in seq_sizes:
                embedding: List[List[float]] = [
                    ptr[(pos + j) * n_embd : (pos + j + 1) * n_embd]
                    for j in range(size)
                ]
                if normalize:
                    embedding = [_normalize_embedding(e) for e in embedding]
                data.append(embedding)
                pos += size
        else:
            for seq_id in range(len(seq_sizes)):
                ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, seq_id)
                pooled: List[float] = ptr[:n_embd]
                if normalize:
                    pooled = _normalize_embedding(pooled)
                data.append(pooled)

    # init state
    total_tokens = 0
    s_batch: List[int] = []  # token count of each sequence in the pending batch
    t_batch = 0  # total tokens in the pending batch
    p_batch = 0  # sequence id for the next input in the pending batch

    # accumulate batches and encode
    for text in inputs:
        tokens = self.tokenize(text.encode("utf-8"))
        if truncate:
            tokens = tokens[:n_batch]

        n_tokens = len(tokens)
        total_tokens += n_tokens

        # check for overrun
        if n_tokens > n_batch:
            raise ValueError(
                f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
            )

        # Flush the pending batch when this input would not fit into it.
        if t_batch + n_tokens > n_batch:
            decode_batch(s_batch)
            s_batch = []
            t_batch = 0
            p_batch = 0

        # add to batch
        self._batch.add_sequence(tokens, p_batch, logits_all)

        # update batch stats
        s_batch.append(n_tokens)
        t_batch += n_tokens
        p_batch += 1

    # Handle the last batch; skip it when nothing was queued (empty input list).
    if s_batch:
        decode_batch(s_batch)

    if self.verbose:
        llama_cpp.llama_print_timings(self._ctx.ctx)

    output = data[0] if isinstance(input, str) else data

    llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
    self.reset()

    if return_count:
        return output, total_tokens
    return output
def _create_completion( | |
self, | |
prompt: Union[str, List[int]], | |
suffix: Optional[str] = None, | |
max_tokens: Optional[int] = 16, | |
temperature: float = 0.8, | |
top_p: float = 0.95, | |
min_p: float = 0.05, | |
typical_p: float = 1.0, | |
logprobs: Optional[int] = None, | |
echo: bool = False, | |
stop: Optional[Union[str, List[str]]] = [], | |
frequency_penalty: float = 0.0, | |
presence_penalty: float = 0.0, | |
repeat_penalty: float = 1.1, | |
top_k: int = 40, | |
stream: bool = False, | |
seed: Optional[int] = None, | |
tfs_z: float = 1.0, | |
mirostat_mode: int = 0, | |
mirostat_tau: float = 5.0, | |
mirostat_eta: float = 0.1, | |
model: Optional[str] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
grammar: Optional[LlamaGrammar] = None, | |
logit_bias: Optional[Dict[str, float]] = None, | |
) -> Union[ | |
Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] | |
]: | |
assert self._ctx is not None | |
assert suffix is None or suffix.__class__ is str | |
completion_id: str = f"cmpl-{str(uuid.uuid4())}" | |
created: int = int(time.time()) | |
# If prompt is empty, initialize completion with BOS token to avoid | |
# detokenization including a space at the beginning of the completion | |
completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()] | |
# Add blank space to start of prompt to match OG llama tokenizer | |
prompt_tokens: List[int] = ( | |
( | |
self.tokenize(prompt.encode("utf-8"), special=True) | |
if prompt != "" | |
else [self.token_bos()] | |
) | |
if isinstance(prompt, str) | |
else prompt | |
) | |
text: bytes = b"" | |
returned_tokens: int = 0 | |
stop = ( | |
stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] | |
) | |
model_name: str = model if model is not None else self.model_path | |
# NOTE: This likely doesn't work correctly for the first token in the prompt | |
# because of the extra space added to the start of the prompt_tokens | |
if logit_bias is not None: | |
logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()} | |
def logit_bias_processor( | |
input_ids: npt.NDArray[np.intc], | |
scores: npt.NDArray[np.single], | |
) -> npt.NDArray[np.single]: | |
new_scores = np.copy( | |
scores | |
) # Does it make sense to copy the whole array or can we just overwrite the original one? | |
for input_id, score in logit_bias_map.items(): | |
new_scores[input_id] = score + scores[input_id] | |
return new_scores | |
_logit_bias_processor = LogitsProcessorList([logit_bias_processor]) | |
if logits_processor is None: | |
logits_processor = _logit_bias_processor | |
else: | |
logits_processor = logits_processor.extend(_logit_bias_processor) | |
if self.verbose: | |
self._ctx.reset_timings() | |
if len(prompt_tokens) >= self._n_ctx: | |
raise ValueError( | |
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" | |
) | |
if max_tokens is None or max_tokens <= 0: | |
# Unlimited, depending on n_ctx. | |
max_tokens = self._n_ctx - len(prompt_tokens) | |
# Truncate max_tokens if requested tokens would exceed the context window | |
max_tokens = ( | |
max_tokens | |
if max_tokens + len(prompt_tokens) < self._n_ctx | |
else (self._n_ctx - len(prompt_tokens)) | |
) | |
if stop != []: | |
stop_sequences = [s.encode("utf-8") for s in stop] | |
else: | |
stop_sequences = [] | |
if logprobs is not None and self.context_params.logits_all is False: | |
raise ValueError( | |
"logprobs is not supported for models created with logits_all=False" | |
) | |
if self.cache: | |
try: | |
cache_item = self.cache[prompt_tokens] | |
cache_prefix_len = Llama.longest_token_prefix( | |
cache_item.input_ids.tolist(), prompt_tokens | |
) | |
eval_prefix_len = Llama.longest_token_prefix( | |
self._input_ids.tolist(), prompt_tokens | |
) | |
if cache_prefix_len > eval_prefix_len: | |
self.load_state(cache_item) | |
if self.verbose: | |
print("Llama._create_completion: cache hit", file=sys.stderr) | |
except KeyError: | |
if self.verbose: | |
print("Llama._create_completion: cache miss", file=sys.stderr) | |
if seed is not None: | |
self._ctx.set_rng_seed(seed) | |
finish_reason = "length" | |
multibyte_fix = 0 | |
for token in self.generate( | |
prompt_tokens, | |
top_k=top_k, | |
top_p=top_p, | |
min_p=min_p, | |
typical_p=typical_p, | |
temp=temperature, | |
tfs_z=tfs_z, | |
mirostat_mode=mirostat_mode, | |
mirostat_tau=mirostat_tau, | |
mirostat_eta=mirostat_eta, | |
frequency_penalty=frequency_penalty, | |
presence_penalty=presence_penalty, | |
repeat_penalty=repeat_penalty, | |
stopping_criteria=stopping_criteria, | |
logits_processor=logits_processor, | |
grammar=grammar, | |
): | |
assert self._model.model is not None | |
if llama_cpp.llama_token_is_eog(self._model.model, token): | |
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) | |
finish_reason = "stop" | |
break | |
completion_tokens.append(token) | |
all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) | |
# Contains multi-byte UTF8 | |
for k, char in enumerate(all_text[-3:]): | |
k = 3 - k | |
for num, pattern in [(2, 192), (3, 224), (4, 240)]: | |
# Bitwise AND check | |
if num > k and pattern & char == pattern: | |
multibyte_fix = num - k | |
# Stop incomplete bytes from passing | |
if multibyte_fix > 0: | |
multibyte_fix -= 1 | |
continue | |
any_stop = [s for s in stop_sequences if s in all_text] | |
if len(any_stop) > 0: | |
first_stop = any_stop[0] | |
text = all_text[: all_text.index(first_stop)] | |
finish_reason = "stop" | |
break | |
if stream: | |
remaining_tokens = completion_tokens[returned_tokens:] | |
remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]) | |
remaining_length = len(remaining_text) | |
# We want to avoid yielding any characters from | |
# the generated text if they are part of a stop | |
# sequence. | |
first_stop_position = 0 | |
for s in stop_sequences: | |
for i in range(min(len(s), remaining_length), 0, -1): | |
if remaining_text.endswith(s[:i]): | |
if i > first_stop_position: | |
first_stop_position = i | |
break | |
token_end_position = 0 | |
if logprobs is not None: | |
# not sure how to handle this branch when dealing | |
# with CJK output, so keep it unchanged | |
for token in remaining_tokens: | |
if token == self.token_bos(): | |
continue | |
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])) | |
# Check if stop sequence is in the token | |
if token_end_position > ( | |
remaining_length - first_stop_position | |
): | |
break | |
token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode( | |
"utf-8", errors="ignore" | |
) | |
text_offset = len(prompt) + len( | |
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode( | |
"utf-8", errors="ignore" | |
) | |
) | |
token_offset = len(prompt_tokens) + returned_tokens | |
logits = self._scores[token_offset - 1, :] | |
current_logprobs = Llama.logits_to_logprobs(logits).tolist() | |
sorted_logprobs = list( | |
sorted( | |
zip(current_logprobs, range(len(current_logprobs))), | |
reverse=True, | |
) | |
) | |
top_logprob = { | |
self.detokenize([i]).decode( | |
"utf-8", errors="ignore" | |
): logprob | |
for logprob, i in sorted_logprobs[:logprobs] | |
} | |
top_logprob.update({token_str: current_logprobs[int(token)]}) | |
logprobs_or_none = { | |
"tokens": [ | |
self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode( | |
"utf-8", errors="ignore" | |
) | |
], | |
"text_offset": [text_offset], | |
"token_logprobs": [current_logprobs[int(token)]], | |
"top_logprobs": [top_logprob], | |
} | |
returned_tokens += 1 | |
yield { | |
"id": completion_id, | |
"object": "text_completion", | |
"created": created, | |
"model": model_name, | |
"choices": [ | |
{ | |
"text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode( | |
"utf-8", errors="ignore" | |
), | |
"index": 0, | |
"logprobs": logprobs_or_none, | |
"finish_reason": None, | |
} | |
], | |
} | |
else: | |
while len(remaining_tokens) > 0: | |
decode_success = False | |
for i in range(1, len(remaining_tokens) + 1): | |
try: | |
bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]) | |
ts = bs.decode("utf-8") | |
decode_success = True | |
break | |
except UnicodeError: | |
pass | |
else: | |
break | |
if not decode_success: | |
# all remaining tokens cannot be decoded to a UTF-8 character | |
break | |
token_end_position += len(bs) | |
if token_end_position > ( | |
remaining_length - first_stop_position | |
): | |
break | |
remaining_tokens = remaining_tokens[i:] | |
returned_tokens += i | |
yield { | |
"id": completion_id, | |
"object": "text_completion", | |
"created": created, | |
"model": model_name, | |
"choices": [ | |
{ | |
"text": ts, | |
"index": 0, | |
"logprobs": None, | |
"finish_reason": None, | |
} | |
], | |
} | |
if len(completion_tokens) >= max_tokens: | |
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) | |
finish_reason = "length" | |
break | |
if stopping_criteria is not None and stopping_criteria( | |
self._input_ids, self._scores[-1, :] | |
): | |
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) | |
finish_reason = "stop" | |
if self.verbose: | |
self._ctx.print_timings() | |
if stream: | |
remaining_tokens = completion_tokens[returned_tokens:] | |
all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]) | |
any_stop = [s for s in stop_sequences if s in all_text] | |
if len(any_stop) > 0: | |
end = min(all_text.index(stop) for stop in any_stop) | |
else: | |
end = len(all_text) | |
token_end_position = 0 | |
for token in remaining_tokens: | |
token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])) | |
logprobs_or_none: Optional[CompletionLogprobs] = None | |
if logprobs is not None: | |
if token == self.token_bos(): | |
continue | |
token_str = self.detokenize([token]).decode( | |
"utf-8", errors="ignore" | |
) | |
text_offset = len(prompt) + len( | |
self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]) | |
) | |
token_offset = len(prompt_tokens) + returned_tokens - 1 | |
logits = self._scores[token_offset, :] | |
current_logprobs = Llama.logits_to_logprobs(logits).tolist() | |
sorted_logprobs = list( | |
sorted( | |
zip(current_logprobs, range(len(current_logprobs))), | |
reverse=True, | |
) | |
) | |
top_logprob = { | |
self.detokenize([i]).decode("utf-8", errors="ignore"): logprob | |
for logprob, i in sorted_logprobs[:logprobs] | |
} | |
top_logprob.update({token_str: current_logprobs[int(token)]}) | |
logprobs_or_none = { | |
"tokens": [ | |
self.detokenize([token]).decode("utf-8", errors="ignore") | |
], | |
"text_offset": [text_offset], | |
"token_logprobs": [current_logprobs[int(token)]], | |
"top_logprobs": [top_logprob], | |
} | |
if token_end_position >= end: | |
last_text = self.detokenize([token]) | |
if token_end_position == end - 1: | |
break | |
returned_tokens += 1 | |
yield { | |
"id": completion_id, | |
"object": "text_completion", | |
"created": created, | |
"model": model_name, | |
"choices": [ | |
{ | |
"text": last_text[ | |
: len(last_text) - (token_end_position - end) | |
].decode("utf-8", errors="ignore"), | |
"index": 0, | |
"logprobs": logprobs_or_none, | |
"finish_reason": None, | |
} | |
], | |
} | |
break | |
returned_tokens += 1 | |
yield { | |
"id": completion_id, | |
"object": "text_completion", | |
"created": created, | |
"model": model_name, | |
"choices": [ | |
{ | |
"text": self.detokenize([token]).decode( | |
"utf-8", errors="ignore" | |
), | |
"index": 0, | |
"logprobs": logprobs_or_none, | |
"finish_reason": None, | |
} | |
], | |
} | |
yield { | |
"id": completion_id, | |
"object": "text_completion", | |
"created": created, | |
"model": model_name, | |
"choices": [ | |
{ | |
"text": "", | |
"index": 0, | |
"logprobs": None, | |
"finish_reason": finish_reason, | |
} | |
], | |
} | |
if self.cache: | |
if self.verbose: | |
print("Llama._create_completion: cache save", file=sys.stderr) | |
self.cache[prompt_tokens + completion_tokens] = self.save_state() | |
print("Llama._create_completion: cache saved", file=sys.stderr) | |
return | |
if self.cache: | |
if self.verbose: | |
print("Llama._create_completion: cache save", file=sys.stderr) | |
self.cache[prompt_tokens + completion_tokens] = self.save_state() | |
text_str = text.decode("utf-8", errors="ignore") | |
if echo: | |
text_str = prompt + text_str | |
if suffix is not None: | |
text_str = text_str + suffix | |
logprobs_or_none: Optional[CompletionLogprobs] = None | |
if logprobs is not None: | |
text_offset = 0 if echo else len(prompt) | |
token_offset = 0 if echo else len(prompt_tokens[1:]) | |
text_offsets: List[int] = [] | |
token_logprobs: List[Optional[float]] = [] | |
tokens: List[str] = [] | |
top_logprobs: List[Optional[Dict[str, float]]] = [] | |
if echo: | |
# Remove leading BOS token | |
all_tokens = prompt_tokens[1:] + completion_tokens | |
else: | |
all_tokens = completion_tokens | |
all_token_strs = [ | |
self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore") | |
for i, token in enumerate(all_tokens) | |
] | |
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:] | |
# TODO: may be able to change this loop to use np.take_along_dim | |
for idx, (token, token_str, logprobs_token) in enumerate( | |
zip(all_tokens, all_token_strs, all_logprobs) | |
): | |
if token == self.token_bos(): | |
continue | |
text_offsets.append( | |
text_offset | |
+ len( | |
self.detokenize(all_tokens[:idx]).decode( | |
"utf-8", errors="ignore" | |
) | |
) | |
) | |
tokens.append(token_str) | |
sorted_logprobs = list( | |
sorted( | |
zip(logprobs_token, range(len(logprobs_token))), reverse=True | |
) | |
) | |
token_logprobs.append(logprobs_token[int(token)]) | |
top_logprob: Optional[Dict[str, float]] = { | |
self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob | |
for logprob, i in sorted_logprobs[:logprobs] | |
} | |
top_logprob.update({token_str: logprobs_token[int(token)]}) | |
top_logprobs.append(top_logprob) | |
# Weird idosincracy of the OpenAI API where | |
# token_logprobs and top_logprobs are null for | |
# the first token. | |
if echo and len(all_tokens) > 0: | |
token_logprobs[0] = None | |
top_logprobs[0] = None | |
logprobs_or_none = { | |
"tokens": tokens, | |
"text_offset": text_offsets, | |
"token_logprobs": token_logprobs, | |
"top_logprobs": top_logprobs, | |
} | |
yield { | |
"id": completion_id, | |
"object": "text_completion", | |
"created": created, | |
"model": model_name, | |
"choices": [ | |
{ | |
"text": text_str, | |
"index": 0, | |
"logprobs": logprobs_or_none, | |
"finish_reason": finish_reason, | |
} | |
], | |
"usage": { | |
"prompt_tokens": len(prompt_tokens), | |
"completion_tokens": len(completion_tokens), | |
"total_tokens": len(prompt_tokens) + len(completion_tokens), | |
}, | |
} | |
def create_completion(
    self,
    prompt: Union[str, List[int]],
    suffix: Optional[str] = None,
    max_tokens: Optional[int] = 16,
    temperature: float = 0.8,
    top_p: float = 0.95,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    logprobs: Optional[int] = None,
    echo: bool = False,
    stop: Optional[Union[str, List[str]]] = None,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    top_k: int = 40,
    stream: bool = False,
    seed: Optional[int] = None,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
    """Generate text from a prompt.

    Args:
        prompt: The prompt to generate text from, either as a string or as a
            list of token ids.
        suffix: A suffix to append to the generated text. If None, no suffix is appended.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None,
            the maximum number of tokens to generate is unlimited and depends on n_ctx.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in
            academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described
            in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling
            implementation described in the paper https://arxiv.org/abs/2202.00666.
        logprobs: The number of logprobs to return. If None, no logprobs are returned.
        echo: Whether to echo the prompt.
        stop: A string or list of strings that stop generation when encountered.
            None (the default) means no stop sequences.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper
            "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        stream: Whether to stream the results.
        seed: The seed to use for sampling.
        tfs_z: The tail-free sampling parameter. Tail Free Sampling described in
            https://www.trentonbricken.com/Tail-Free-Sampling/.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the
            generated text. A higher value corresponds to more surprising or less predictable
            text, while a lower value corresponds to less surprising or more predictable text.
        mirostat_eta: The learning rate used to update `mu` based on the error between the
            target and observed surprisal of the sampled word. A larger learning rate will cause
            `mu` to be updated more quickly, while a smaller learning rate will result in
            slower updates.
        model: The name to use for the model in the completion object.
        stopping_criteria: A list of stopping criteria to use.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use for constrained sampling.
        logit_bias: A logit bias to use.

    Raises:
        ValueError: If the requested tokens exceed the context window.
        RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.

    Returns:
        When ``stream`` is False, the first (and only) response object yielded by
        ``_create_completion``; otherwise the iterator of stream chunks itself.
    """
    # ``stop`` defaults to None rather than a shared mutable ``[]``;
    # ``_create_completion`` normalises None to an empty list of stop sequences.
    completion_or_chunks = self._create_completion(
        prompt=prompt,
        suffix=suffix,
        # ``_create_completion`` treats non-positive values as "unlimited".
        max_tokens=-1 if max_tokens is None else max_tokens,
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        echo=echo,
        stop=stop,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=stream,
        seed=seed,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        stopping_criteria=stopping_criteria,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
    if stream:
        chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
        return chunks
    # Non-streaming mode: the generator yields exactly one full response.
    completion: Completion = next(completion_or_chunks)  # type: ignore
    return completion
def __call__( | |
self, | |
prompt: str, | |
suffix: Optional[str] = None, | |
max_tokens: Optional[int] = 16, | |
temperature: float = 0.8, | |
top_p: float = 0.95, | |
min_p: float = 0.05, | |
typical_p: float = 1.0, | |
logprobs: Optional[int] = None, | |
echo: bool = False, | |
stop: Optional[Union[str, List[str]]] = [], | |
frequency_penalty: float = 0.0, | |
presence_penalty: float = 0.0, | |
repeat_penalty: float = 1.1, | |
top_k: int = 40, | |
stream: bool = False, | |
seed: Optional[int] = None, | |
tfs_z: float = 1.0, | |
mirostat_mode: int = 0, | |
mirostat_tau: float = 5.0, | |
mirostat_eta: float = 0.1, | |
model: Optional[str] = None, | |
stopping_criteria: Optional[StoppingCriteriaList] = None, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
grammar: Optional[LlamaGrammar] = None, | |
logit_bias: Optional[Dict[str, float]] = None, | |
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: | |
"""Generate text from a prompt. | |
Args: | |
prompt: The prompt to generate text from. | |
suffix: A suffix to append to the generated text. If None, no suffix is appended. | |
max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None, the maximum number of tokens to generate is unlimited and depends on n_ctx. | |
temperature: The temperature to use for sampling. | |
top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | |
min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 | |
typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. | |
logprobs: The number of logprobs to return. If None, no logprobs are returned. | |
echo: Whether to echo the prompt. | |
stop: A list of strings to stop generation when encountered. | |
frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt. | |
presence_penalty: The penalty to apply to tokens based on their presence in the prompt. | |
repeat_penalty: The penalty to apply to repeated tokens. | |
top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 | |
stream: Whether to stream the results. | |
seed: The seed to use for sampling. | |
tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. | |
mirostat_mode: The mirostat sampling mode. | |
mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. | |
mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. | |
model: The name to use for the model in the completion object. | |
stopping_criteria: A list of stopping criteria to use. | |
logits_processor: A list of logits processors to use. | |
grammar: A grammar to use for constrained sampling. | |
logit_bias: A logit bias to use. | |
Raises: | |
ValueError: If the requested tokens exceed the context window. | |
RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt. | |
Returns: | |
Response object containing the generated text. | |
""" | |
return self.create_completion( | |
prompt=prompt, | |
suffix=suffix, | |
max_tokens=max_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
min_p=min_p, | |
typical_p=typical_p, | |
logprobs=logprobs, | |
echo=echo, | |
stop=stop, | |
frequency_penalty=frequency_penalty, | |
presence_penalty=presence_penalty, | |
repeat_penalty=repeat_penalty, | |
top_k=top_k, | |
stream=stream, | |
seed=seed, | |
tfs_z=tfs_z, | |
mirostat_mode=mirostat_mode, | |
mirostat_tau=mirostat_tau, | |
mirostat_eta=mirostat_eta, | |
model=model, | |
stopping_criteria=stopping_criteria, | |
logits_processor=logits_processor, | |
grammar=grammar, | |
logit_bias=logit_bias, | |
) | |
def create_chat_completion(
    self,
    messages: List[ChatCompletionRequestMessage],
    functions: Optional[List[ChatCompletionFunction]] = None,
    function_call: Optional[ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[ChatCompletionTool]] = None,
    tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    seed: Optional[int] = None,
    response_format: Optional[ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
) -> Union[
    CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
    """Generate a chat completion from a list of messages.

    The work is delegated to ``self.chat_handler`` when one is set; otherwise the
    handler registered for ``self.chat_format`` is looked up via
    ``llama_chat_format.get_chat_completion_handler``. Every argument below is
    forwarded to that handler unchanged, together with ``llama=self``.

    Args:
        messages: A list of messages to generate a response for.
        functions: A list of functions to use for the chat completion.
        function_call: A function call to use for the chat completion.
        tools: A list of tools to use for the chat completion.
        tool_choice: A tool choice to use for the chat completion.
        temperature: The temperature to use for sampling.
        top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in
            academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        top_k: The top-k value to use for sampling. Top-K sampling described in academic paper
            "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
        min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described
            in https://github.com/ggerganov/llama.cpp/pull/3841
        typical_p: The typical-p value to use for sampling. Locally Typical Sampling
            implementation described in the paper https://arxiv.org/abs/2202.00666.
        stream: Whether to stream the results.
        stop: A string or list of strings that stop generation when encountered.
        seed: The seed to use for sampling.
        response_format: The response format to use for the chat completion. Use
            { "type": "json_object" } to constrain output to only valid JSON.
        max_tokens: The maximum number of tokens to generate. If max_tokens <= 0 or None,
            the maximum number of tokens to generate is unlimited and depends on n_ctx.
        presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
        frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
        repeat_penalty: The penalty to apply to repeated tokens.
        tfs_z: The tail-free sampling parameter.
        mirostat_mode: The mirostat sampling mode.
        mirostat_tau: The mirostat sampling tau parameter.
        mirostat_eta: The mirostat sampling eta parameter.
        model: The name to use for the model in the completion object.
        logits_processor: A list of logits processors to use.
        grammar: A grammar to use.
        logit_bias: A logit bias to use.
        logprobs: Whether to request log probabilities of the output tokens
            (forwarded to the chat handler).
        top_logprobs: The number of most likely tokens to return log probabilities
            for at each position (forwarded to the chat handler).

    Returns:
        Generated chat completion or a stream of chat completion chunks.
    """
    handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
        self.chat_format
    )
    return handler(
        llama=self,
        messages=messages,
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        stream=stream,
        stop=stop,
        seed=seed,
        response_format=response_format,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        grammar=grammar,
        logit_bias=logit_bias,
    )
def create_chat_completion_openai_v1(
    self,
    *args: Any,
    **kwargs: Any,
):
    """Generate a chat completion with return type based on the OpenAI v1 API.

    OpenAI python package is required to use this method.
    You can install it with `pip install openai`.

    Args:
        *args: Positional arguments to pass to create_chat_completion.
        **kwargs: Keyword arguments to pass to create_chat_completion.

    Returns:
        A ``ChatCompletion`` built from the result when ``stream`` is False, or a
        generator of ``ChatCompletionChunk`` objects when ``stream`` is True.

    Raises:
        ImportError: If the ``openai`` package cannot be imported. ImportErrors
            raised from inside ``create_chat_completion`` propagate unchanged
            instead of being misreported as a missing ``openai`` package.
    """
    # Only the optional import is guarded, so unrelated ImportErrors coming from
    # chat handlers are not swallowed and relabelled.
    try:
        from openai.types.chat import ChatCompletion, ChatCompletionChunk
    except ImportError as err:
        raise ImportError(
            "To use create_chat_completion_openai_v1, you must install the openai package. "
            "You can install it with `pip install openai`."
        ) from err

    stream = kwargs.get("stream", False)  # type: ignore
    assert isinstance(stream, bool)
    if stream:
        return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
    return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
def __getstate__(self): | |
return dict( | |
model_path=self.model_path, | |
# Model Params | |
n_gpu_layers=self.model_params.n_gpu_layers, | |
split_mode=self.model_params.split_mode, | |
main_gpu=self.model_params.main_gpu, | |
tensor_split=self.tensor_split, | |
vocab_only=self.model_params.vocab_only, | |
use_mmap=self.model_params.use_mmap, | |
use_mlock=self.model_params.use_mlock, | |
kv_overrides=self.kv_overrides, | |
# Context Params | |
seed=self.context_params.seed, | |
n_ctx=self.context_params.n_ctx, | |
n_batch=self.n_batch, | |
n_threads=self.context_params.n_threads, | |
n_threads_batch=self.context_params.n_threads_batch, | |
rope_scaling_type=self.context_params.rope_scaling_type, | |
pooling_type=self.context_params.pooling_type, | |
rope_freq_base=self.context_params.rope_freq_base, | |
rope_freq_scale=self.context_params.rope_freq_scale, | |
yarn_ext_factor=self.context_params.yarn_ext_factor, | |
yarn_attn_factor=self.context_params.yarn_attn_factor, | |
yarn_beta_fast=self.context_params.yarn_beta_fast, | |
yarn_beta_slow=self.context_params.yarn_beta_slow, | |
yarn_orig_ctx=self.context_params.yarn_orig_ctx, | |
logits_all=self.context_params.logits_all, | |
embedding=self.context_params.embeddings, | |
offload_kqv=self.context_params.offload_kqv, | |
flash_attn=self.context_params.flash_attn, | |
# Sampling Params | |
last_n_tokens_size=self.last_n_tokens_size, | |
# LoRA Params | |
lora_base=self.lora_base, | |
lora_scale=self.lora_scale, | |
lora_path=self.lora_path, | |
# Backend Params | |
numa=self.numa, | |
# Chat Format Params | |
chat_format=self.chat_format, | |
chat_handler=self.chat_handler, | |
# Speculative Decidng | |
draft_model=self.draft_model, | |
# KV cache quantization | |
type_k=self.context_params.type_k, | |
type_v=self.context_params.type_v, | |
# Misc | |
verbose=self.verbose, | |
) | |
def __setstate__(self, state): | |
self.__init__(**state) | |
def save_state(self) -> LlamaState:
    """Snapshot the current evaluation state so it can be restored later.

    Copies the native llama.cpp context state into an immutable ``bytes``
    object and bundles it with copies of ``self._scores`` and
    ``self.input_ids`` plus the current ``n_tokens``.

    Returns:
        A ``LlamaState`` with the copied scores, input ids, token count, the
        serialised native state and the number of bytes it contains.

    Raises:
        RuntimeError: If ``llama_copy_state_data`` reports more bytes than the
            size returned by ``llama_get_state_size``.
    """
    assert self._ctx.ctx is not None
    if self.verbose:
        print("Llama.save_state: saving llama state", file=sys.stderr)
    state_size = llama_cpp.llama_get_state_size(self._ctx.ctx)
    if self.verbose:
        print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr)
    llama_state = (ctypes.c_uint8 * int(state_size))()
    if self.verbose:
        print("Llama.save_state: allocated state", file=sys.stderr)
    n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state)
    if self.verbose:
        print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr)
    if int(n_bytes) > int(state_size):
        raise RuntimeError("Failed to copy llama state data")
    # Slice the written prefix through a memoryview (zero-copy) and copy it
    # once into ``bytes``; this avoids an intermediate ctypes array that
    # would hold a second full copy of a potentially very large buffer.
    llama_state_bytes = bytes(memoryview(llama_state)[: int(n_bytes)])
    if self.verbose:
        print(
            f"Llama.save_state: saving {n_bytes} bytes of llama state",
            file=sys.stderr,
        )
    return LlamaState(
        scores=self._scores.copy(),
        input_ids=self.input_ids.copy(),
        n_tokens=self.n_tokens,
        llama_state=llama_state_bytes,
        llama_state_size=n_bytes,
    )
def load_state(self, state: LlamaState) -> None:
    """Restore evaluation state previously captured by ``save_state``.

    Args:
        state: Snapshot whose scores, input ids, token count and native
            llama.cpp state are written back into this instance.

    Raises:
        RuntimeError: If ``llama_set_state_data`` returns a value other than
            ``state.llama_state_size``.
    """
    assert self._ctx.ctx is not None
    # Only filling in up to `n_tokens` and then zero-ing out the rest.
    # Slice assignment already copies the values into ``self.scores``, so no
    # temporary copy of ``state.scores`` is made first.
    self.scores[: state.n_tokens, :] = state.scores
    self.scores[state.n_tokens :, :] = 0.0
    self.input_ids = state.input_ids.copy()
    self.n_tokens = state.n_tokens
    state_size = state.llama_state_size
    LLamaStateArrayType = ctypes.c_uint8 * state_size
    llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
    if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size:
        raise RuntimeError("Failed to set llama state data")
def n_ctx(self) -> int:
    """Return the context window size, as reported by ``self._ctx.n_ctx()``."""
    context = self._ctx
    return context.n_ctx()
def n_embd(self) -> int:
    """Return the embedding size, as reported by ``self._model.n_embd()``."""
    wrapped_model = self._model
    return wrapped_model.n_embd()
def n_vocab(self) -> int:
    """Return the vocabulary size, as reported by ``self._model.n_vocab()``."""
    wrapped_model = self._model
    return wrapped_model.n_vocab()
def tokenizer(self) -> LlamaTokenizer:
    """Build a fresh ``LlamaTokenizer`` bound to this instance."""
    wrapped = LlamaTokenizer(self)
    return wrapped
def token_eos(self) -> int:
    """Id of the model's end-of-sequence token."""
    model = self._model
    return model.token_eos()
def token_bos(self) -> int:
    """Id of the model's beginning-of-sequence token."""
    model = self._model
    return model.token_bos()
def token_nl(self) -> int:
    """Id of the model's newline token."""
    model = self._model
    return model.token_nl()
def pooling_type(self) -> str:
    """Pooling type reported by the underlying llama context.

    NOTE(review): annotated as ``str`` here; confirm what
    ``_LlamaContext.pooling_type`` actually returns (it may be an int enum).
    """
    context = self._ctx
    return context.pooling_type()
def logits_to_logprobs(
    logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
    """Convert raw logits to log-probabilities using a stable log-softmax.

    Follows ``scipy.special.log_softmax``:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html

    Args:
        logits: Array (or nested list) of logits.
        axis: Axis along which to normalise; defaults to the last axis.

    Returns:
        ``logits - log(sum(exp(logits), axis))`` as a float32 array with the
        same shape as ``logits``.
    """
    # Shift each slice by its maximum so the exponentials stay bounded. A
    # non-finite maximum (inf/nan) would poison the shift, so use 0 instead.
    shift = np.amax(logits, axis=axis, keepdims=True)
    if shift.ndim == 0:
        if not np.isfinite(shift):
            shift = 0
    else:
        shift[~np.isfinite(shift)] = 0

    shifted = np.subtract(logits, shift, dtype=np.single)
    weights = np.exp(shifted)
    # A slice made entirely of -inf sums to 0; silence the log(0) warning.
    with np.errstate(divide="ignore"):
        log_norm = np.log(np.sum(weights, axis=axis, keepdims=True))
    return shifted - log_norm
def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
    """Count how many leading tokens ``a`` and ``b`` have in common.

    Counting stops at the first mismatch or when the shorter input runs out.
    """
    shared = 0
    for left, right in zip(a, b):
        if left != right:
            break
        shared += 1
    return shared
def from_pretrained(
    cls,
    repo_id: str,
    filename: Optional[str],
    local_dir: Optional[Union[str, os.PathLike[str]]] = None,
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    **kwargs: Any,
) -> "Llama":
    """Create a Llama model from a pretrained model name or path.

    This method requires the huggingface-hub package.
    You can install it with `pip install huggingface-hub`.

    Candidate files are the entries returned by ``HfFileSystem.ls(repo_id)``;
    directory entries are skipped so a glob can only select a file.

    Args:
        repo_id: The model repo id.
        filename: A filename or glob pattern to match the model file in the repo.
        local_dir: The local directory to save the model to.
        local_dir_use_symlinks: Whether to use symlinks when downloading the model.
        cache_dir: Cache directory forwarded to ``hf_hub_download``.
        **kwargs: Additional keyword arguments to pass to the Llama constructor.

    Returns:
        A Llama model.

    Raises:
        ImportError: If huggingface-hub is not installed.
        ValueError: If no file, or more than one file, matches ``filename``.
    """
    try:
        from huggingface_hub import hf_hub_download, HfFileSystem
        from huggingface_hub.utils import validate_repo_id
    except ImportError:
        raise ImportError(
            "Llama.from_pretrained requires the huggingface-hub package. "
            "You can install it with `pip install huggingface-hub`."
        )
    # Hub paths always use "/" separators. PurePosixPath keeps them that way
    # on every OS; pathlib.Path would produce "\\" on Windows, which then
    # ends up in the `subfolder` (and thus the URL) given to hf_hub_download.
    from pathlib import PurePosixPath

    validate_repo_id(repo_id)

    hffs = HfFileSystem()

    files = [
        entry["name"] if isinstance(entry, dict) else entry
        for entry in hffs.ls(repo_id)
        # Skip folders: only files can be downloaded with hf_hub_download.
        if not (isinstance(entry, dict) and entry.get("type") == "directory")
    ]

    # Paths relative to the repo root, e.g. "model.Q4_K_M.gguf".
    file_list: List[str] = [
        str(PurePosixPath(file).relative_to(repo_id)) for file in files
    ]

    matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

    if len(matching_files) == 0:
        raise ValueError(
            f"No file found in {repo_id} that match {filename}\n\n"
            f"Available Files:\n{json.dumps(file_list)}"
        )

    if len(matching_files) > 1:
        raise ValueError(
            f"Multiple files found in {repo_id} matching {filename}\n\n"
            f"Matching Files:\n{json.dumps(matching_files)}"
        )

    (matching_file,) = matching_files

    matched = PurePosixPath(matching_file)
    parent = str(matched.parent)
    # A file at the repo root has parent "."; pass None (the default) instead.
    subfolder = None if parent == "." else parent
    filename = matched.name

    # hf_hub_download returns the local path of the downloaded file (inside
    # local_dir when given, otherwise in the cache), subfolder included, so a
    # single call both fetches and locates the model. Joining local_dir with
    # the bare filename would drop the subfolder.
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
        cache_dir=cache_dir,
    )

    return cls(
        model_path=model_path,
        **kwargs,
    )
class LlamaState:
    """Snapshot of a ``Llama`` instance's evaluation state.

    Instances are built by ``Llama.save_state`` and consumed by
    ``Llama.load_state``.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Number of tokens evaluated when the snapshot was taken.
        self.n_tokens = n_tokens
        # Python-side token ids and per-token scores.
        self.input_ids = input_ids
        self.scores = scores
        # Raw llama.cpp context bytes and their length in bytes.
        self.llama_state_size = llama_state_size
        self.llama_state = llama_state
# A logits processor takes (input_ids, scores) and returns new scores.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """Ordered chain of logits processors.

    Calling the list threads ``scores`` through every processor in order,
    giving each one the output of the previous processor.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        current = scores
        for step in self:
            current = step(input_ids, current)
        # An empty chain hands back the input scores unchanged.
        return current
# A stopping criterion maps (input_ids, logits) to a bool.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """Collection of stopping criteria combined with logical OR."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Every criterion is called on each invocation (no short-circuit);
        # the result is True when at least one of them returned truthy.
        verdicts = [criterion(input_ids, logits) for criterion in self]
        return any(verdicts)