# saiga-api-cuda-v2-b13 / llm_backend.py
# Last change: "Update llm_backend.py" by muryshev (commit 77d95ba), 6.92 kB.
from llama_cpp import Llama
import gc
import threading
import logging
import sys
# Module-level logger used by LlmBackend; named as a child of the 'llm_api' logger.
log = logging.getLogger('llm_api.backend')
class LlmBackend:
    """Process-wide singleton wrapper around a ``llama_cpp.Llama`` model.

    Loads and unloads the model, builds Saiga-formatted prompts from
    role-tagged messages, and streams generated tokens. Access to the model is
    serialized through a class-level re-entrant lock.
    """

    # Default system prompt prepended by create_chat_generator_for_saiga() when
    # use_system_prompt is true; replace it with set_system_prompt().
    SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и отвечаешь на запросы пользователя, используя русский язык."
    # Token ids inserted after BOS to mark each message's role. They are tied to
    # the model vocabulary — NOTE(review): confirm they match the loaded model.
    SYSTEM_TOKEN = 1788
    USER_TOKEN = 1404
    BOT_TOKEN = 9225
    # Presumably the newline token of the vocabulary — TODO confirm.
    LINEBREAK_TOKEN = 13

    # Role name (the "from" field of a message) -> role token id.
    ROLE_TOKENS = {
        "user": USER_TOKEN,
        "bot": BOT_TOKEN,
        "system": SYSTEM_TOKEN
    }

    _instance = None
    _model = None
    _model_params = None
    # Re-entrant on purpose: methods that hold the lock call
    # ensure_model_is_loaded(), which may call load_model(), which acquires the
    # lock again. A plain Lock would deadlock the calling thread on that path.
    _lock = threading.RLock()

    def __new__(cls):
        # Singleton: every LlmBackend() call returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def is_model_loaded(self):
        """Return True if a model is currently loaded."""
        return self._model is not None

    def load_model(self, model_path, context_size=2000, enable_gpu=True, gpu_layer_number=35, chat_format='llama-2'):
        """Load (or reload) the model and return the new ``Llama`` instance.

        The arguments are remembered so ensure_model_is_loaded() can reload the
        same model after unload_model(). Any previously loaded model is
        unloaded first.

        Args:
            model_path: Path of the model file passed to ``Llama``.
            context_size: Context window size (``n_ctx``).
            enable_gpu: When true, pass ``gpu_layer_number`` as ``n_gpu_layers``;
                when false, ``n_gpu_layers`` is not passed and the ``Llama``
                default applies.
            gpu_layer_number: Number of layers to offload to the GPU.
            chat_format: Chat template name used by create_chat_completion().

        Returns:
            The newly created ``Llama`` instance.
        """
        log.info('load_model - started')
        with self._lock:
            self._model_params = {
                'model_path': model_path,
                'context_size': context_size,
                'enable_gpu': enable_gpu,
                'gpu_layer_number': gpu_layer_number,
                'chat_format': chat_format,
            }
            if self._model is not None:
                self.unload_model()
            llama_kwargs = {
                'model_path': model_path,
                'chat_format': chat_format,
                'n_ctx': context_size,
                'n_parts': 1,
                'logits_all': True,
                'verbose': True,
            }
            if enable_gpu:
                llama_kwargs['n_gpu_layers'] = gpu_layer_number
            model = Llama(**llama_kwargs)
            self._model = model
        log.info('load_model - finished')
        return model

    def set_system_prompt(self, prompt):
        """Replace the system prompt used for subsequent Saiga prompts."""
        with self._lock:
            self.SYSTEM_PROMPT = prompt

    def unload_model(self):
        """Drop the reference to the loaded model, if any, so it can be freed."""
        log.info('unload_model - started')
        with self._lock:
            if self._model is not None:
                # Rebind explicitly rather than ``del self._model``, which only
                # removed the instance attribute and relied on the class-level
                # ``None`` showing through.
                self._model = None
                # Collect any reference cycles that might still keep the model alive.
                gc.collect()
        log.info('unload_model - finished')

    def ensure_model_is_loaded(self):
        """Reload the model with the last load_model() arguments if it is unloaded.

        If load_model() was never called there is nothing to reload: a warning
        is logged and the model stays unloaded.
        """
        log.info('ensure_model_is_loaded - started')
        if not self.is_model_loaded():
            log.info('ensure_model_is_loaded - model reloading')
            if self._model_params is not None:
                self.load_model(**self._model_params)
            else:
                log.warning('ensure_model_is_loaded - No model config found. Reloading can not be done.')
        log.info('ensure_model_is_loaded - finished')

    def generate_tokens(self, generator):
        """Stream the detokenized bytes of each token produced by ``generator``.

        Yields one ``bytes`` piece per token until the EOS token arrives, then
        yields ``b''`` as an end-of-chunk marker and stops. If an exception is
        raised, it is logged and ``b''`` is yielded. Pieces stay as bytes,
        presumably because one UTF-8 character can span several tokens.

        The lock is held for as long as the consumer keeps iterating.
        """
        log.info('generate_tokens - started')
        with self._lock:
            self.ensure_model_is_loaded()
            try:
                eos_token = self._model.token_eos()
                for token in generator:
                    if token == eos_token:
                        log.info('generate_tokens - finished')
                        yield b''  # End of chunk
                        break
                    yield self._model.detokenize([token])
            except Exception:
                log.exception('generate_tokens - error')
                yield b''  # End of chunk

    def create_chat_completion(self, messages, stream=True):
        """Run ``Llama.create_chat_completion`` on ``messages``.

        Returns whatever the library returns (a stream of chunks when
        ``stream`` is true), or ``None`` if loading the model or the call
        raised. NOTE(review): with stream=True the result is consumed after
        the lock is released, so generation itself is not serialized here.
        """
        log.info('create_chat_completion called')
        with self._lock:
            log.info('create_chat_completion started')
            try:
                self.ensure_model_is_loaded()
                return self._model.create_chat_completion(messages=messages, stream=stream)
            except Exception:
                log.exception('create_chat_completion - error')
                return None

    def get_message_tokens(self, role, content):
        """Tokenize one message into the Saiga layout.

        Result: ``BOS, <role token>, <linebreak token>, <content tokens...>, EOS``.
        Assumes ``tokenize`` puts BOS at index 0 (llama_cpp's default
        ``add_bos=True``), since the role tokens are inserted after it.

        Raises:
            KeyError: if ``role`` is not a key of ROLE_TOKENS.
        """
        log.info('get_message_tokens - started')
        self.ensure_model_is_loaded()
        message_tokens = self._model.tokenize(content.encode("utf-8"))
        message_tokens.insert(1, self.ROLE_TOKENS[role])
        message_tokens.insert(2, self.LINEBREAK_TOKEN)
        message_tokens.append(self._model.token_eos())
        log.info('get_message_tokens - finished')
        return message_tokens

    def get_system_tokens(self):
        """Return the tokens of the system-prompt message."""
        return self.get_message_tokens(role="system", content=self.SYSTEM_PROMPT)

    def create_chat_generator_for_saiga(self, messages, parameters, use_system_prompt=True):
        """Build a Saiga prompt from ``messages`` and return a token generator.

        Args:
            messages: Iterable of dicts with a ``"from"`` role (a key of
                ROLE_TOKENS) and an optional ``"content"`` string.
            parameters: Mapping with ``top_k``, ``top_p``, ``temperature`` and
                ``repetition_penalty`` sampling settings.
            use_system_prompt: Prepend the SYSTEM_PROMPT message when true.

        Returns:
            The generator from ``Llama.generate``; pass it to generate_tokens()
            to obtain the output bytes.
        """
        log.info('create_chat_generator_for_saiga - started')
        with self._lock:
            self.ensure_model_is_loaded()
            tokens = self.get_system_tokens() if use_system_prompt else []
            for message in messages:
                tokens.extend(self.get_message_tokens(role=message.get("from"), content=message.get("content", "")))
            # Open the bot's turn (BOS, bot role, linebreak) for the model to continue.
            tokens.extend([self._model.token_bos(), self.BOT_TOKEN, self.LINEBREAK_TOKEN])
            generator = self._model.generate(
                tokens,
                top_k=parameters['top_k'],
                top_p=parameters['top_p'],
                temp=parameters['temperature'],
                repeat_penalty=parameters['repetition_penalty']
            )
            log.info('create_chat_generator_for_saiga - finished')
            return generator