Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import os | |
import ctypes | |
from typing import ( | |
List, | |
Optional, | |
Sequence, | |
) | |
from dataclasses import dataclass, field | |
import numpy as np | |
import numpy.typing as npt | |
from .llama_types import * | |
from .llama_grammar import LlamaGrammar | |
from ._utils import suppress_stdout_stderr | |
import llama_cpp.llama_cpp as llama_cpp | |
# Python wrappers over llama.h structs | |
class _LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.

    NOTE: For stability it's recommended you use the Llama class instead."""

    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
    _llama_free_model = None

    def __init__(
        self,
        *,
        path_model: str,
        params: llama_cpp.llama_model_params,
        verbose: bool = True,
    ):
        """Load the model stored at ``path_model`` using ``params``.

        Raises:
            ValueError: if ``path_model`` does not exist or llama.cpp fails to
                load it.
        """
        self.path_model = path_model
        self.params = params
        self.verbose = verbose
        # Bound on the instance so __del__ can free the model without a
        # module-level lookup.
        self._llama_free_model = llama_cpp._lib.llama_free_model  # type: ignore
        # Assigned before anything below can raise, so __del__ always finds it.
        self.model = None

        if not os.path.exists(path_model):
            raise ValueError(f"Model path does not exist: {path_model}")

        with suppress_stdout_stderr(disable=verbose):
            self.model = llama_cpp.llama_load_model_from_file(
                self.path_model.encode("utf-8"), self.params
            )

        if self.model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")

    def __del__(self):
        """Free the native model, if one was loaded."""
        if self.model is not None and self._llama_free_model is not None:
            self._llama_free_model(self.model)
            self.model = None

    def vocab_type(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_vocab_type(self.model)

    def n_vocab(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_n_vocab(self.model)

    def n_ctx_train(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_n_ctx_train(self.model)

    def n_embd(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_n_embd(self.model)

    def rope_freq_scale_train(self) -> float:
        assert self.model is not None
        return llama_cpp.llama_rope_freq_scale_train(self.model)

    def desc(self) -> str:
        """Return the model description string (read into a 1024-byte buffer)."""
        assert self.model is not None
        buf = ctypes.create_string_buffer(1024)
        llama_cpp.llama_model_desc(self.model, buf, 1024)
        return buf.value.decode("utf-8")

    def size(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_model_size(self.model)

    def n_params(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_model_n_params(self.model)

    def get_tensor(self, name: str) -> ctypes.c_void_p:
        assert self.model is not None
        return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))

    def apply_lora_from_file(
        self,
        lora_path: str,
        scale: float,
        path_base_model: Optional[str],
        n_threads: int,
    ):
        """Apply a LoRA adapter; returns llama.cpp's status code unchanged."""
        assert self.model is not None
        return llama_cpp.llama_model_apply_lora_from_file(
            self.model,
            lora_path.encode("utf-8"),
            scale,
            path_base_model.encode("utf-8")
            if path_base_model is not None
            else ctypes.c_char_p(0),
            n_threads,
        )

    # Vocab

    def token_get_text(self, token: int) -> str:
        # TODO: Fix
        assert self.model is not None
        return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")

    def token_get_score(self, token: int) -> float:
        assert self.model is not None
        return llama_cpp.llama_token_get_score(self.model, token)

    def token_get_type(self, token: int) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_get_type(self.model, token)

    # Special tokens

    def token_bos(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_bos(self.model)

    def token_eos(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_eos(self.model)

    def token_nl(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_nl(self.model)

    def token_prefix(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_prefix(self.model)

    def token_middle(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_middle(self.model)

    def token_suffix(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_suffix(self.model)

    def token_eot(self) -> int:
        assert self.model is not None
        return llama_cpp.llama_token_eot(self.model)

    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
        """Tokenize UTF-8 ``text`` into a list of token ids.

        A buffer of ``n_ctx_train()`` tokens is tried first. A negative return
        is treated as the required size, and the call is retried with an
        exactly-sized buffer.

        Raises:
            RuntimeError: if the retry still fails.
        """
        assert self.model is not None
        n_ctx = self.n_ctx_train()
        tokens = (llama_cpp.llama_token * n_ctx)()
        n_tokens = llama_cpp.llama_tokenize(
            self.model, text, len(text), tokens, n_ctx, add_bos, special
        )
        if n_tokens < 0:
            n_tokens = abs(n_tokens)
            tokens = (llama_cpp.llama_token * n_tokens)()
            n_tokens = llama_cpp.llama_tokenize(
                self.model, text, len(text), tokens, n_tokens, add_bos, special
            )
            if n_tokens < 0:
                raise RuntimeError(
                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
                )
        return list(tokens[:n_tokens])

    def token_to_piece(self, token: int, special: bool = False) -> bytes:
        """Return exactly the bytes of the text piece for ``token``.

        A negative return from ``llama_token_to_piece`` is treated as the
        required buffer size, and the call is retried with that size.

        Raises:
            RuntimeError: if the retry still reports failure.
        """
        assert self.model is not None
        size = 32
        buf = ctypes.create_string_buffer(size)
        n = llama_cpp.llama_token_to_piece(self.model, token, buf, size, special)
        if n < 0:
            size = -n
            buf = ctypes.create_string_buffer(size)
            n = llama_cpp.llama_token_to_piece(self.model, token, buf, size, special)
            if n < 0:
                raise RuntimeError(f"Failed to get piece: token={token}")
        # Only the first n bytes were written; the rest is NUL padding.
        return buf.raw[:n]

    def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
        """Concatenate the byte pieces of ``tokens``.

        When the first token is BOS and the output starts with a space, that
        single leading space is dropped.
        """
        assert self.model is not None
        pieces: List[bytes] = []
        size = 32
        buffer = (ctypes.c_char * size)()
        for token in tokens:
            n = llama_cpp.llama_token_to_piece(
                self.model, llama_cpp.llama_token(token), buffer, size, special
            )
            if n < 0:
                # Buffer too small: grow to the required size and retry.
                size = -n
                buffer = (ctypes.c_char * size)()
                n = llama_cpp.llama_token_to_piece(
                    self.model, llama_cpp.llama_token(token), buffer, size, special
                )
            assert 0 <= n <= size
            pieces.append(bytes(buffer[:n]))
        output = b"".join(pieces)
        # NOTE: Llama1 models automatically added a space at the start of the prompt
        # this removes a leading space if the first token is a beginning of sentence token
        if tokens and tokens[0] == self.token_bos() and output[:1] == b" ":
            return output[1:]
        return output

    # Extra

    def metadata(self) -> dict[str, str]:
        """Return all GGUF metadata key/value pairs as strings.

        The llama.cpp getters return the string length without the NUL
        terminator. A return of at least the buffer size therefore means the
        text was truncated, and the call is repeated with a larger buffer.
        """
        assert self.model is not None
        metadata: dict[str, str] = {}
        buffer_size = 1024
        # create_string_buffer(int) is already zero-filled.
        buffer = ctypes.create_string_buffer(buffer_size)
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            nbytes = llama_cpp.llama_model_meta_key_by_index(
                self.model, i, buffer, buffer_size
            )
            if nbytes >= buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                llama_cpp.llama_model_meta_key_by_index(
                    self.model, i, buffer, buffer_size
                )
            key = buffer.value.decode("utf-8")
            nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                self.model, i, buffer, buffer_size
            )
            if nbytes >= buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                llama_cpp.llama_model_meta_val_str_by_index(
                    self.model, i, buffer, buffer_size
                )
            value = buffer.value.decode("utf-8")
            metadata[key] = value
        return metadata

    @staticmethod
    def default_params():
        """Get the default llama_model_params."""
        return llama_cpp.llama_model_default_params()
class _LlamaContext:
    """Intermediate Python wrapper for a llama.cpp llama_context.

    NOTE: For stability it's recommended you use the Llama class instead."""

    # Class-level default so ``self._llama_free`` always resolves in __del__.
    _llama_free = None

    def __init__(
        self,
        *,
        model: _LlamaModel,
        params: llama_cpp.llama_context_params,
        verbose: bool = True,
    ):
        """Create a llama context for ``model`` with ``params``.

        Raises:
            ValueError: if llama.cpp fails to create the context.
        """
        self.model = model
        self.params = params
        self.verbose = verbose
        self._llama_free = llama_cpp._lib.llama_free  # type: ignore
        self.ctx = None

        assert self.model.model is not None

        self.ctx = llama_cpp.llama_new_context_with_model(
            self.model.model, self.params
        )

        if self.ctx is None:
            raise ValueError("Failed to create llama_context")

    def __del__(self):
        """Free the native context, if one was created."""
        if self.ctx is not None and self._llama_free is not None:
            self._llama_free(self.ctx)
            self.ctx = None

    def n_ctx(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_n_ctx(self.ctx)

    def pooling_type(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_pooling_type(self.ctx)

    def kv_cache_clear(self):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_clear(self.ctx)

    def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)

    def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)

    def kv_cache_seq_keep(self, seq_id: int):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)

    def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
        """Shift positions of ``seq_id`` in [p0, p1) by ``shift`` (llama_kv_cache_seq_add)."""
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)

    def get_state_size(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_get_state_size(self.ctx)

    # TODO: copy_state_data

    # TODO: set_state_data

    # TODO: llama_load_session_file

    # TODO: llama_save_session_file

    def decode(self, batch: "_LlamaBatch"):
        """Run ``llama_decode`` on ``batch``.

        Raises:
            RuntimeError: if ``llama_decode`` returns a non-zero code.
        """
        assert self.ctx is not None
        assert batch.batch is not None
        return_code = llama_cpp.llama_decode(
            self.ctx,
            batch.batch,
        )
        if return_code != 0:
            raise RuntimeError(f"llama_decode returned {return_code}")

    def set_n_threads(self, n_threads: int, n_threads_batch: int):
        assert self.ctx is not None
        llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)

    def get_logits(self):
        assert self.ctx is not None
        return llama_cpp.llama_get_logits(self.ctx)

    def get_logits_ith(self, i: int):
        assert self.ctx is not None
        return llama_cpp.llama_get_logits_ith(self.ctx, i)

    def get_embeddings(self):
        assert self.ctx is not None
        return llama_cpp.llama_get_embeddings(self.ctx)

    # Sampling functions
    # Each sampler below passes ``candidates.candidates`` to llama.cpp by
    # reference; the underlying candidate buffer is the one llama.cpp works on.

    def set_rng_seed(self, seed: int):
        assert self.ctx is not None
        llama_cpp.llama_set_rng_seed(self.ctx, seed)

    def sample_repetition_penalties(
        self,
        candidates: "_LlamaTokenDataArray",
        last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_repetition_penalties(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            last_tokens_data,
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        )

    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
        assert self.ctx is not None
        llama_cpp.llama_sample_softmax(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_top_k(
            self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
        )

    def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_top_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        assert self.ctx is not None
        llama_cpp.llama_sample_min_p(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_tail_free(
        self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_tail_free(
            self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep
        )

    def sample_typical(
        self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
    ):
        assert self.ctx is not None
        llama_cpp.llama_sample_typical(
            self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
        )

    def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
        assert self.ctx is not None
        llama_cpp.llama_sample_temp(
            self.ctx, llama_cpp.byref(candidates.candidates), temp
        )

    def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_sample_grammar(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            grammar.grammar,
        )

    def sample_token_mirostat(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        m: int,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_mirostat(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            tau,
            eta,
            m,
            mu,
        )

    def sample_token_mirostat_v2(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_mirostat_v2(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
            tau,
            eta,
            mu,
        )

    def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token_greedy(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
        assert self.ctx is not None
        return llama_cpp.llama_sample_token(
            self.ctx,
            llama_cpp.byref(candidates.candidates),
        )

    # Grammar

    def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
        assert self.ctx is not None
        assert grammar.grammar is not None
        llama_cpp.llama_grammar_accept_token(self.ctx, grammar.grammar, token)

    def reset_timings(self):
        assert self.ctx is not None
        llama_cpp.llama_reset_timings(self.ctx)

    def print_timings(self):
        assert self.ctx is not None
        llama_cpp.llama_print_timings(self.ctx)

    # Utility functions

    @staticmethod
    def default_params():
        """Get the default llama_context_params."""
        return llama_cpp.llama_context_default_params()
class _LlamaBatch:
    """Owns a native ``llama_batch`` created with ``llama_batch_init``.

    ``n_tokens`` passed to the constructor is the allocated capacity. The
    method ``n_tokens()`` reports how many slots are currently filled.
    """

    # Class-level default so ``self._llama_batch_free`` always resolves in __del__.
    _llama_batch_free = None

    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
        self._n_tokens = n_tokens
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose
        self._llama_batch_free = llama_cpp._lib.llama_batch_free  # type: ignore
        self.batch = None
        self.batch = llama_cpp.llama_batch_init(
            self._n_tokens, self.embd, self.n_seq_max
        )

    def __del__(self):
        """Free the native batch, if one was allocated."""
        if self.batch is not None and self._llama_batch_free is not None:
            self._llama_batch_free(self.batch)
            self.batch = None

    def n_tokens(self) -> int:
        assert self.batch is not None
        return self.batch.n_tokens

    def reset(self):
        """Mark the batch empty (slot contents are left as-is)."""
        assert self.batch is not None
        self.batch.n_tokens = 0

    def _check_capacity(self, n_total: int):
        """Raise ValueError if ``n_total`` tokens would exceed the allocated slots."""
        if n_total > self._n_tokens:
            raise ValueError(
                f"batch overflow: {n_total} tokens exceed capacity {self._n_tokens}"
            )

    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
        """Replace the batch contents with ``batch`` as sequence 0.

        Positions start at ``n_past``. Logits are requested for every token
        when ``logits_all`` is true, and always for the last token.

        Raises:
            ValueError: if ``batch`` is longer than the allocated capacity.
        """
        assert self.batch is not None
        n_tokens = len(batch)
        self._check_capacity(n_tokens)
        self.batch.n_tokens = n_tokens
        for i in range(n_tokens):
            self.batch.token[i] = batch[i]
            self.batch.pos[i] = n_past + i
            self.batch.seq_id[i][0] = 0
            self.batch.n_seq_id[i] = 1
            self.batch.logits[i] = logits_all
        # Guard: with an empty batch, index -1 would write outside the buffer.
        if n_tokens > 0:
            self.batch.logits[n_tokens - 1] = True

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        """Append ``batch`` as sequence ``seq_id`` after the current tokens.

        Positions restart at 0 for the new sequence. Logits are requested for
        every appended token when ``logits_all`` is true, and always for the
        last appended token.

        Raises:
            ValueError: if the appended tokens would exceed the allocated capacity.
        """
        assert self.batch is not None
        n_tokens = len(batch)
        n_tokens0 = self.batch.n_tokens
        self._check_capacity(n_tokens0 + n_tokens)
        self.batch.n_tokens += n_tokens
        for i in range(n_tokens):
            j = n_tokens0 + i
            self.batch.token[j] = batch[i]
            self.batch.pos[j] = i
            self.batch.seq_id[j][0] = seq_id
            self.batch.n_seq_id[j] = 1
            self.batch.logits[j] = logits_all
        if n_tokens > 0:
            # The last token of *this* sequence lives at offset n_tokens0.
            self.batch.logits[n_tokens0 + n_tokens - 1] = True
class _LlamaTokenDataArray:
    """NumPy-backed ``llama_token_data_array`` with one candidate per vocab entry.

    ``candidates_data`` is a 1-D structured array of length ``n_vocab``. Its
    buffer is handed to llama.cpp through ``candidates.data``, so llama.cpp
    samplers operate on the same memory NumPy sees.
    """

    def __init__(self, *, n_vocab: int):
        self.n_vocab = n_vocab
        # Record layout (id, logit, p) with C alignment; the buffer is cast to
        # llama_token_data_p below. A flat 1-D array is required: a 2-D shape
        # would make candidates_data["id"][i] select a row, not a record.
        self.candidates_data = np.zeros(
            self.n_vocab,
            dtype=np.dtype(
                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
            ),
        )
        self.candidates = llama_cpp.llama_token_data_array(
            data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
            size=self.n_vocab,
            sorted=False,
        )
        # Reused templates for resetting ids and probabilities in copy_logits().
        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
        self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)

    def copy_logits(self, logits: npt.NDArray[np.single]):
        """Reset the candidates to (token id, ``logits[id]``, 0.0) in id order.

        Also restores the full, unsorted size on the native struct.
        """
        self.candidates_data["id"][:] = self.default_candidates_data_id
        self.candidates_data["logit"][:] = logits
        self.candidates_data["p"][:] = self.default_candidates_data_p
        self.candidates.data = self.candidates_data.ctypes.data_as(
            llama_cpp.llama_token_data_p
        )
        self.candidates.sorted = ctypes.c_bool(False)
        self.candidates.size = ctypes.c_size_t(self.n_vocab)
# Python wrappers over common/common | |
def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]:
    """Tokenize ``text`` with ``model``'s vocabulary.

    The text is encoded to UTF-8 once. Its *byte* length is what
    ``llama_tokenize`` receives, so non-ASCII input is no longer truncated.
    A negative return is treated as the required buffer size, and the call is
    retried with an exactly-sized buffer.

    Raises:
        RuntimeError: if the retry does not produce the expected token count.
    """
    assert model.model is not None
    text_bytes = text.encode("utf-8")
    n_bytes = len(text_bytes)
    # Initial buffer guess; an undersized guess is corrected by the retry below.
    n_max = n_bytes + 1 if add_bos else n_bytes
    result = (llama_cpp.llama_token * n_max)()
    n_tokens = llama_cpp.llama_tokenize(
        model.model,
        text_bytes,
        n_bytes,
        result,
        n_max,
        add_bos,
        special,
    )
    if n_tokens < 0:
        n_required = -n_tokens
        result = (llama_cpp.llama_token * n_required)()
        check = llama_cpp.llama_tokenize(
            model.model,
            text_bytes,
            n_bytes,
            result,
            n_required,
            add_bos,
            special,
        )
        if check != n_required:
            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
        return list(result)
    return list(result[:n_tokens])
def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str:
    """Return the text piece of a single ``token``, decoded as UTF-8.

    An 8-byte buffer is tried first. A negative return is treated as the
    required size, and the call is repeated with an exactly-sized buffer.
    Decoding is strict, so a piece holding only part of a multi-byte UTF-8
    sequence raises UnicodeDecodeError.

    Raises:
        RuntimeError: if the sized retry does not write the expected length.
    """
    assert model.model is not None
    buf = (ctypes.c_char * 8)(0)
    written = llama_cpp.llama_token_to_piece(model.model, token, buf, len(buf), special)
    if written >= 0:
        return bytes(buf[:written]).decode("utf-8")
    needed = -written
    buf = (ctypes.c_char * needed)(0)
    if llama_cpp.llama_token_to_piece(model.model, token, buf, len(buf), special) != needed:
        raise RuntimeError(f"Failed to get piece: token={token}")
    return bytes(buf).decode("utf-8")
def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str:
    """Detokenize ``tokens`` for an SPM-style vocabulary.

    A single leading space is stripped from the first content token. That is
    the token after a leading BOS, or the first token when there is no BOS.
    Empty pieces (e.g. for control tokens) are handled without error.
    """
    bos_id = model.token_bos()
    # Index of the first content token: skip a leading BOS if present.
    first = 1 if tokens and tokens[0] == bos_id else 0
    pieces: List[str] = []
    for i, token in enumerate(tokens):
        piece = _token_to_piece(model, token)
        # startswith() instead of piece[0] so an empty piece does not raise.
        if i == first and piece.startswith(" "):
            piece = piece[1:]
        pieces.append(piece)
    return "".join(pieces)
def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str:
    """Detokenize ``tokens`` for a BPE vocabulary by joining their pieces unchanged."""
    return "".join(_token_to_piece(model, tok) for tok in tokens)
def _should_add_bos(model: _LlamaModel) -> bool:
    """Decide whether a BOS token should be prepended for ``model``.

    Uses ``llama_add_bos_token`` when it returns a definite answer. When it
    returns -1, the answer falls back to whether the vocab type is SPM.
    """
    assert model.model is not None
    flag = llama_cpp.llama_add_bos_token(model.model)
    if flag == -1:
        # -1 appears to mean "not specified by the model"; use the vocab default.
        return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
    return flag != 0
# Embedding functions | |
def _normalize_embedding(embedding):
    """Scale ``embedding`` to unit L2 norm.

    Returns a new list of scaled components. A zero-norm input is returned
    unchanged (the very same object).
    """
    magnitude = float(np.linalg.norm(embedding))
    if magnitude != 0.0:
        return [component / magnitude for component in embedding]
    return embedding
# Python wrappers over common/sampling structs | |
@dataclass
class _LlamaSamplingParams:
    """Sampling hyper-parameters, a Python wrapper over common/sampling's params.

    Must be a dataclass: ``logit_bias`` uses ``field(default_factory=dict)``,
    which only yields a fresh dict per instance under ``@dataclass``.
    """

    n_prev: int = 64
    n_probs: int = 0
    top_k: int = 40
    top_p: float = 0.95
    min_p: float = 0.05
    tfs_z: float = 1.00
    typical_p: float = 1.00
    temp: float = 0.80
    penalty_last_n: int = 64
    penalty_repeat: float = 1.10
    penalty_freq: float = 0.00
    penalty_present: float = 0.00
    mirostat: int = 0
    mirostat_tau: float = 5.00
    mirostat_eta: float = 0.10
    penalize_nl: bool = True

    grammar: str = ""

    cfg_negative_prompt: str = ""
    cfg_scale: float = 1.00

    # Additive per-token-id logit adjustments applied before sampling.
    logit_bias: dict[int, float] = field(default_factory=dict)
@dataclass
class _LlamaSamplingContext:
    """Per-sequence sampling state, a Python wrapper over common/sampling's context.

    Holds the sampling parameters, the mirostat ``mu`` value, an optional
    grammar and the history of accepted tokens. Must be a dataclass: the
    ``field(default_factory=...)`` defaults and the keyword construction in
    ``cp()`` depend on it.
    """

    params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams)
    # Passed to the mirostat samplers by pointer (presumably updated in place).
    mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
    grammar: Optional[LlamaGrammar] = None
    # NOTE: Missing parsed_grammar
    # Token ids accepted so far (appended by accept()).
    prev: list[int] = field(default_factory=list)
    cur: list[llama_cpp.llama_token_data] = field(default_factory=list)

    def reset(self):
        """Clear the token history and reset the grammar, if any."""
        self.prev = []
        self.cur = []
        if self.grammar is not None:
            self.grammar.reset()

    def cp(self):
        """Return a copy with its own token history and its own mirostat ``mu``.

        ``params`` and ``grammar`` are shared with this context, not copied.
        """
        return _LlamaSamplingContext(
            params=self.params,
            # A fresh c_float so sampling on the copy cannot move this mu.
            mirostat_mu=ctypes.c_float(self.mirostat_mu.value),
            grammar=self.grammar,
            prev=self.prev.copy(),
            cur=self.cur.copy(),
        )

    def last(self) -> Optional[int]:
        """Return the most recently accepted token id, or None if there is none."""
        if len(self.prev) > 0:
            return self.prev[-1]
        else:
            return None

    def prev_str(self, ctx_main: _LlamaContext, n: int) -> str:
        """Detokenize the last ``n`` accepted tokens (empty string when ``n <= 0``)."""
        if n <= 0:
            # prev[-0:] would otherwise select the whole history.
            return ""
        return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")

    def sample(
        self,
        ctx_main: _LlamaContext,
        idx: int = 0,
        logits_array: Optional[npt.NDArray[np.single]] = None,
    ):
        """Sample the next token id.

        Logits come from ``logits_array`` when given, or otherwise from
        ``ctx_main``'s logits at batch index ``idx``. The pipeline is: logit
        bias, repetition penalties (optionally sparing the newline token),
        grammar, then one of softmax-argmax (temp < 0), greedy (temp == 0),
        mirostat v1/v2, or the top-k/tail-free/typical/top-p/min-p/temperature
        chain. A caller-supplied ``logits_array`` is never modified.
        """
        n_vocab = ctx_main.model.n_vocab()
        token_id: int = 0

        if logits_array is None:
            logits = ctx_main.get_logits_ith(idx)
            logits_array = np.array(
                ctypes.cast(logits, ctypes.POINTER(ctypes.c_float * n_vocab)).contents,
                dtype=np.single,
            )
        elif self.params.logit_bias:
            # Copy so applying the bias below does not mutate the caller's array.
            logits_array = np.array(logits_array, dtype=np.single)

        # apply logit_bias
        for token, bias in self.params.logit_bias.items():
            logits_array[token] += bias

        token_data_array = _LlamaTokenDataArray(
            n_vocab=n_vocab
        )  # TODO: Only create this once
        token_data_array.copy_logits(logits_array)

        # apply penalties
        if len(self.prev) > 0:
            nl_token = ctx_main.model.token_nl()
            nl_logit = logits_array[nl_token]
            last_tokens = self.prev[-self.params.penalty_last_n:]
            last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
            if last_tokens_size > 0:
                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
                ctx_main.sample_repetition_penalties(
                    token_data_array,
                    last_tokens_p,
                    last_tokens_size,
                    self.params.penalty_repeat,
                    self.params.penalty_freq,
                    self.params.penalty_present,
                )
            if not self.params.penalize_nl:
                # Undo any penalty on the newline token; candidates are still
                # in id order here (copy_logits filled them by id).
                token_data_array.candidates_data["logit"][nl_token] = nl_logit

        if self.grammar is not None:
            ctx_main.sample_grammar(token_data_array, self.grammar)

        if self.params.temp < 0:
            ctx_main.sample_softmax(token_data_array)
            # Record 0 is the top candidate after softmax sorting; int() turns
            # the NumPy scalar into a plain Python int.
            token_id = int(token_data_array.candidates_data["id"][0])
        elif self.params.temp == 0:
            token_id = ctx_main.sample_token_greedy(token_data_array)
        else:
            if self.params.mirostat == 1:
                mirostat_m = 100
                ctx_main.sample_temp(token_data_array, self.params.temp)
                token_id = ctx_main.sample_token_mirostat(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    mirostat_m,
                    ctypes.pointer(self.mirostat_mu),
                )
            elif self.params.mirostat == 2:
                ctx_main.sample_temp(token_data_array, self.params.temp)
                token_id = ctx_main.sample_token_mirostat_v2(
                    token_data_array,
                    self.params.mirostat_tau,
                    self.params.mirostat_eta,
                    ctypes.pointer(self.mirostat_mu),
                )
            else:
                min_keep = max(1, self.params.n_probs)
                ctx_main.sample_top_k(
                    token_data_array, self.params.top_k, min_keep=min_keep
                )
                ctx_main.sample_tail_free(
                    token_data_array, self.params.tfs_z, min_keep=min_keep
                )
                ctx_main.sample_typical(
                    token_data_array, self.params.typical_p, min_keep=min_keep
                )
                ctx_main.sample_top_p(
                    token_data_array, self.params.top_p, min_keep=min_keep
                )
                ctx_main.sample_min_p(
                    token_data_array, self.params.min_p, min_keep=min_keep
                )
                ctx_main.sample_temp(token_data_array, self.params.temp)
                token_id = ctx_main.sample_token(token_data_array)
        return token_id

    def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
        """Record token ``id``, advancing the grammar first when ``apply_grammar`` is set."""
        if apply_grammar and self.grammar is not None:
            ctx_main.grammar_accept_token(self.grammar, id)
        self.prev.append(id)