|
from typing import Callable, Generator, Iterator, List, Optional, Union |
|
import ctypes |
|
from ctypes import ( |
|
c_bool, |
|
c_char_p, |
|
c_int, |
|
c_int8, |
|
c_int32, |
|
c_uint8, |
|
c_uint32, |
|
c_size_t, |
|
c_float, |
|
c_double, |
|
c_void_p, |
|
POINTER, |
|
_Pointer, |
|
Structure, |
|
Array, |
|
) |
|
import pathlib |
|
import os |
|
import sys |
|
|
|
|
|
def _load_shared_library(lib_base_name: str):
    """Locate and load the native llama2 shared library with ctypes.

    By default the library is looked up as ``lib<lib_base_name>.so`` in the
    directory containing this module (Linux only). Setting the
    ``LLAMA2_CU_LIB`` environment variable replaces that search with an
    explicit library path, which is honoured on any platform.

    Args:
        lib_base_name: Library name without the ``lib`` prefix and ``.so``
            suffix (ignored when ``LLAMA2_CU_LIB`` is set).

    Returns:
        The loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: If the platform is unsupported and no override is set,
            or if a candidate file exists but cannot be loaded (the original
            ``OSError`` is chained as ``__cause__``).
        FileNotFoundError: If no candidate library file exists.
    """
    lib_paths: List[pathlib.Path]
    override = os.environ.get("LLAMA2_CU_LIB")
    if override is not None:
        # The explicit path must be checked before the platform check,
        # otherwise the override could never be used off Linux.
        lib_base_name = override
        lib_paths = [pathlib.Path(override).resolve()]
    elif sys.platform.startswith("linux"):
        base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
        lib_paths = [base_path / f"lib{lib_base_name}.so"]
    else:
        raise RuntimeError("Unsupported platform")

    for lib_path in lib_paths:
        if lib_path.exists():
            try:
                return ctypes.CDLL(str(lib_path))
            except OSError as e:
                # Keep the loader's diagnostic (missing symbol, bad ELF, ...)
                # reachable via __cause__.
                raise RuntimeError(
                    f"Failed to load shared library '{lib_path}': {e}"
                ) from e

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )
|
|
|
|
|
# The native library is loaded once, at import time; every wrapper function
# below calls into this single handle.
_lib_base_name = "llama2"
_lib = _load_shared_library(lib_base_name=_lib_base_name)
|
|
|
|
|
def llama2_init(
    model_path: Union[str, bytes, "os.PathLike[str]"],
    tokenizer_path: Union[str, bytes, "os.PathLike[str]"],
) -> c_void_p:
    """Create a native llama2 context from a model file and a tokenizer file.

    Both paths go through :func:`os.fsencode`. As a result, ``str``, ``bytes``
    and :class:`os.PathLike` objects (e.g. ``pathlib.Path``) are accepted, and
    text is encoded with the filesystem encoding instead of a hard-coded codec.

    Returns:
        The opaque context handle. Because ``restype`` is ``c_void_p``, ctypes
        hands it back as a Python ``int``, or ``None`` if the native call
        returned NULL.
    """
    return _lib.llama2_init(os.fsencode(model_path), os.fsencode(tokenizer_path))


_lib.llama2_init.argtypes = [c_char_p, c_char_p]
_lib.llama2_init.restype = c_void_p
|
|
|
# Prototype: void llama2_free(void *ctx)
_lib.llama2_free.argtypes, _lib.llama2_free.restype = [c_void_p], None


def llama2_free(ctx: c_void_p) -> None:
    """Hand *ctx* (a handle obtained from ``llama2_init``) to the native ``llama2_free``."""
    native_free = _lib.llama2_free
    native_free(ctx)
|
|
|
# Prototype: int llama2_generate(void *ctx, const char *prompt, int max_tokens,
#                                float temperature, float top_p, int seed)
_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
_lib.llama2_generate.restype = c_int


def llama2_generate(
    ctx: c_void_p,
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    seed: int,
) -> int:
    """Ask the native library to start generating for *prompt*.

    The prompt is sent as UTF-8 bytes. The native integer status is returned
    unchanged; ``Llama2.__call__`` treats any non-zero value as a failure.
    """
    native_args = (ctx, prompt.encode('utf-8'), max_tokens, temperature, top_p, seed)
    return _lib.llama2_generate(*native_args)
|
|
|
def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    """Fetch the next chunk of generated output from the native side.

    Returns:
        The chunk as raw ``bytes``. A chunk may end in the middle of a
        multi-byte UTF-8 character, so callers should decode incrementally.
        Returns ``None`` when the native call returns NULL; ``Llama2.__call__``
        treats that as end of stream.
    """
    # NOTE(review): with restype c_char_p, ctypes copies the C string into a
    # new bytes object and never frees the original pointer. This assumes the
    # native side owns that buffer — confirm.
    return _lib.llama2_get_last(ctx)


_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p
|
|
|
def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    """Tokenize *text* with the native tokenizer and return the token ids.

    The native function writes ids into a caller-allocated ``int`` buffer and
    stores how many it wrote through the ``n_tokens`` out-parameter.

    Args:
        ctx: Context handle from ``llama2_init``.
        text: Text to tokenize; sent to the native side as UTF-8.
        add_bos: Forwarded as an int8 flag to the native call.
        add_eos: Forwarded as an int8 flag to the native call.

    Returns:
        The first ``n_tokens`` ids from the output buffer.
    """
    encoded = text.encode('utf-8')
    # The native side sees UTF-8 bytes, so size the buffer by the encoded
    # length, not len(text): one non-ASCII character is several bytes and may
    # produce several tokens. Sizing by code points could overrun the buffer.
    # NOTE(review): the original +3 headroom is kept; this assumes the native
    # tokenizer never emits more than len(encoded) + 3 ids — confirm.
    capacity = len(encoded) + 3
    tokens = (c_int * capacity)()
    n_tokens = c_int(0)
    _lib.llama2_tokenize(ctx, encoded, add_bos, add_eos, tokens, ctypes.byref(n_tokens))
    return tokens[:n_tokens.value]


_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
_lib.llama2_tokenize.restype = None
|
|
|
class Llama2:
    """Convenience wrapper that owns one native llama2 context.

    The context is created in ``__init__`` from a model file and a tokenizer
    file. Release it explicitly with :meth:`close`, or use the instance as a
    context manager (``with Llama2(...) as llm: ...``).
    """

    def __init__(
        self,
        model_path: str,
        tokenizer_path: str = 'tokenizer.bin',
        n_ctx: int = 0,
        n_batch: int = 0) -> None:
        """Create the native context.

        Args:
            model_path: Model file passed to ``llama2_init``.
            tokenizer_path: Tokenizer file passed to ``llama2_init``.
            n_ctx: Stored as ``self.n_ctx``; not forwarded to the native side here.
            n_batch: Stored as ``self.n_batch``; not forwarded to the native side here.

        Raises:
            RuntimeError: If ``llama2_init`` returns NULL (``None``).
        """
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        # Bound first so close() is a harmless no-op if initialisation fails.
        self.llama2_ctx: Optional[int] = None
        ctx = llama2_init(model_path, tokenizer_path)
        # ctypes maps a NULL c_void_p result to None. Keeping it would pass
        # NULL into every later native call.
        # NOTE(review): assumes native init reports failure via NULL — confirm.
        if ctx is None:
            raise RuntimeError(
                f"Failed to initialize llama2 context from model '{model_path}'"
            )
        self.llama2_ctx = ctx

    def close(self) -> None:
        """Free the native context; calling it more than once is a no-op.

        A generator returned by ``__call__`` that is still being iterated
        stops at its next step instead of touching the freed context.
        """
        ctx = getattr(self, "llama2_ctx", None)
        self.llama2_ctx = None
        if ctx is not None:
            llama2_free(ctx)

    def __enter__(self) -> "Llama2":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def tokenize(
        self, text: str, add_bos: bool = True, add_eos: bool = False
    ) -> List[int]:
        """Tokenize *text* with this instance's native context."""
        return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)

    def __call__(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        logprobs: Optional[int] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Iterator[str]:
        """Generate a completion for *prompt*, yielding text pieces as they arrive.

        This is a generator: nothing, including the native launch, runs until
        the first item is requested. Only ``max_tokens``, ``temperature``,
        ``top_p`` and ``seed`` (default 42) are forwarded to the native
        library. ``min_p``, ``typical_p``, ``logprobs``, the penalties,
        ``top_k`` and ``stream`` are accepted for interface compatibility but
        are currently ignored.

        Raises:
            RuntimeError: If the context has been closed, or if
                ``llama2_generate`` returns a non-zero status.
        """
        if self.llama2_ctx is None:
            raise RuntimeError("Llama2 context has been closed")
        if seed is None:
            seed = 42
        ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
        if ret != 0:
            raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")

        def next_chunk() -> Optional[bytes]:
            # Re-read the handle each time so close() mid-stream ends the
            # stream rather than passing a freed/NULL context to C.
            ctx = self.llama2_ctx
            return None if ctx is None else llama2_get_last(ctx)

        # iter(callable, sentinel): poll until the native side returns None.
        yield from self._decode_chunks(iter(next_chunk, None))

    @staticmethod
    def _decode_chunks(chunks: Iterator[bytes]) -> Iterator[str]:
        """Incrementally decode UTF-8 byte chunks into non-empty text pieces.

        A multi-byte character split across chunks is held back until it is
        complete. Invalid byte sequences become U+FFFD instead of blocking all
        later output, and incomplete trailing bytes are flushed (as U+FFFD)
        once the input is exhausted.
        """
        import codecs

        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for chunk in chunks:
            text = decoder.decode(chunk)
            if text:
                yield text
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail
|
|
|
|
|
|