from llama_cpp import Llama

import gc
import threading
import logging

log = logging.getLogger('llm_api.backend')


class LlmBackend:

    SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."

    SYSTEM_TOKEN = 1788
    USER_TOKEN = 1404
    BOT_TOKEN = 9225
    LINEBREAK_TOKEN = 13

    ROLE_TOKENS = {
        "user": USER_TOKEN,
        "bot": BOT_TOKEN,
        "system": SYSTEM_TOKEN
    }
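
    # These token IDs are assumed to match the Saiga model's tokenizer: each
    # message is encoded as <s>{role_token}\n{content}</s>, and generation is
    # primed with <s>{BOT_TOKEN}\n so the model continues as the bot (see
    # get_message_tokens and create_chat_generator_for_saiga below).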

    _instance = None
    _model = None
    _model_params = None
    # A reentrant lock: methods that already hold it (e.g. generate_tokens,
    # create_chat_generator_for_saiga) call ensure_model_is_loaded(), which can
    # re-enter load_model() and take the lock again on the same thread.
    _lock = threading.RLock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(LlmBackend, cls).__new__(cls)
        return cls._instance

    def is_model_loaded(self):
        return self._model is not None

    def load_model(self, model_path, context_size=2000, enable_gpu=True, gpu_layer_number=35, chat_format='llama-2'):
        log.info('load_model - started')
        # Remember the parameters so the model can be reloaded transparently later.
        self._model_params = {
            'model_path': model_path,
            'context_size': context_size,
            'enable_gpu': enable_gpu,
            'gpu_layer_number': gpu_layer_number,
            'chat_format': chat_format
        }

        if self._model is not None:
            self.unload_model()

        with self._lock:
            self._model = Llama(
                model_path=model_path,
                chat_format=chat_format,
                n_ctx=context_size,
                n_parts=1,
                logits_all=True,
                verbose=True,
                # Offload layers to the GPU only when GPU usage is enabled.
                n_gpu_layers=gpu_layer_number if enable_gpu else 0
            )
            log.info('load_model - finished')
            return self._model

    def set_system_prompt(self, prompt):
        with self._lock:
            self.SYSTEM_PROMPT = prompt

    def unload_model(self):
        log.info('unload_model - started')
        with self._lock:
            if self._model is not None:
                # Drop the reference and nudge the garbage collector so the
                # model memory is actually released.
                self._model = None
                gc.collect()
        log.info('unload_model - finished')

    def ensure_model_is_loaded(self):
        log.info('ensure_model_is_loaded - started')
        if not self.is_model_loaded():
            log.info('ensure_model_is_loaded - model reloading')
            if self._model_params is not None:
                self.load_model(**self._model_params)
            else:
                log.info('ensure_model_is_loaded - no model config found, reloading cannot be done')
        log.info('ensure_model_is_loaded - finished')
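
    # create_chat_completion delegates to llama-cpp-python's built-in chat
    # templating for the configured chat_format, so `messages` follows the
    # OpenAI-style [{"role": ..., "content": ...}] convention; with stream=True
    # it returns an iterator of completion chunks instead of a single response.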
    def create_chat_completion(self, messages, stream=True):
        log.info('create_chat_completion called')
        with self._lock:
            log.info('create_chat_completion started')
            try:
                # Reload the model from the cached parameters if it was unloaded.
                self.ensure_model_is_loaded()
                return self._model.create_chat_completion(messages=messages, stream=stream)
            except Exception as e:
                log.error('create_chat_completion - error')
                log.error(e)
                return None

    def get_message_tokens(self, role, content):
        log.info('get_message_tokens - started')
        self.ensure_model_is_loaded()
        message_tokens = self._model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, self.ROLE_TOKENS[role])
        message_tokens.insert(2, self.LINEBREAK_TOKEN)
        message_tokens.append(self._model.token_eos())
        log.info('get_message_tokens - finished')
        return message_tokens
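
    # Layout produced by get_message_tokens, assuming llama-cpp-python's
    # tokenize() prepends BOS by default:
    #   [BOS, role_token, LINEBREAK_TOKEN, <content tokens...>, EOS]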

    def get_system_tokens(self):
        return self.get_message_tokens(role="system", content=self.SYSTEM_PROMPT)

    def create_chat_generator_for_saiga(self, messages, parameters, use_system_prompt=True):
        log.info('create_chat_generator_for_saiga - started')
        with self._lock:
            self.ensure_model_is_loaded()
            tokens = self.get_system_tokens() if use_system_prompt else []
            for message in messages:
                message_tokens = self.get_message_tokens(role=message.get("from"), content=message.get("content", ""))
                tokens.extend(message_tokens)

            # Prime the reply: open a new message as the bot so the model continues it.
            tokens.extend([self._model.token_bos(), self.BOT_TOKEN, self.LINEBREAK_TOKEN])
            generator = self._model.generate(
                tokens,
                top_k=parameters['top_k'],
                top_p=parameters['top_p'],
                temp=parameters['temperature'],
                repeat_penalty=parameters['repetition_penalty']
            )
            log.info('create_chat_generator_for_saiga - finished')
            return generator
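
    # The `parameters` dict passed to create_chat_generator_for_saiga is expected
    # to carry the sampling settings used above, e.g. (illustrative values only):
    #   {'top_k': 30, 'top_p': 0.9, 'temperature': 0.2, 'repetition_penalty': 1.1}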

    def generate_tokens(self, generator):
        log.info('generate_tokens - started')
        with self._lock:
            self.ensure_model_is_loaded()
            try:
                for token in generator:
                    # An end-of-sequence token finishes the reply; emit an empty
                    # chunk so the consumer can detect completion.
                    if token == self._model.token_eos():
                        yield b''
                        log.info('generate_tokens - finished')
                        break

                    # detokenize() returns the raw UTF-8 bytes of a single token.
                    token_str = self._model.detokenize([token])
                    yield token_str
            except Exception as e:
                log.error('generate_tokens - error')
                log.error(e)
                yield b''
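

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way the singleton backend can be driven
# end to end. The model path and sampling values below are hypothetical
# placeholders, not settings shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    backend = LlmBackend()
    backend.load_model(
        model_path='models/model-q4_K.gguf',  # placeholder path to a GGUF model
        context_size=2000,
        enable_gpu=False
    )

    messages = [{"from": "user", "content": "Привет! Кто ты?"}]
    parameters = {
        'top_k': 30,
        'top_p': 0.9,
        'temperature': 0.2,
        'repetition_penalty': 1.1
    }

    generator = backend.create_chat_generator_for_saiga(messages, parameters)
    for chunk in backend.generate_tokens(generator):
        if not chunk:  # b'' signals the end of generation
            break
        # Tokens arrive as raw bytes; partial UTF-8 sequences are skipped here
        # for brevity rather than buffered across chunks.
        print(chunk.decode('utf-8', errors='ignore'), end='', flush=True)
    print()

    backend.unload_model()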