from llama_cpp import Llama
from llama_cpp.llama_chat_format import Jinja2ChatFormatter

from src.const import SYSTEM_PROMPT


class LlamaCPPChatEngine:
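    """Streaming chat engine over a local GGUF model via llama-cpp-python,
    using the chat template shipped in the model's own metadata."""
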
    def __init__(self, model_path):
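        # n_ctx=0 tells llama.cpp to take the context window from the model
        # file itself (the length the model was trained with).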
        self._model = Llama(
            model_path=model_path,
            n_ctx=0,
            verbose=False,
        )
        self.n_ctx = self._model.context_params.n_ctx
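        # Resolve the EOS token's text from the GGUF metadata; token_get_text
        # lives on the low-level model handle, hence the private access.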
        self._eos_token = self._model._model.token_get_text(
            int(self._model.metadata['tokenizer.ggml.eos_token_id'])
        )
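        # Build a prompt formatter from the Jinja2 chat template embedded in
        # the GGUF metadata, so prompts match what the model was tuned on.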
        self._formatter = Jinja2ChatFormatter(
            template=self._model.metadata['tokenizer.chat_template'],
            bos_token=self._model._model.token_get_text(
                int(self._model.metadata['tokenizer.ggml.bos_token_id'])
            ),
            eos_token=self._eos_token,
            # stop_token_ids expects a list of ints, not the raw metadata string.
            stop_token_ids=[int(self._model.metadata['tokenizer.ggml.eos_token_id'])],
        )
        self._tokenizer = self._model.tokenizer()

    def chat(self, messages, user_message, context):
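        # If retrieval context is supplied, prepend it to the user's question
        # so the model answers with those passages in front of it.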
        if context:
            user_message_extended = "\n".join(context + [f"Question: {user_message}"])
        else:
            user_message_extended = user_message
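        # Wrap the running history with the system prompt and the new user turn.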
        messages = (
            [{"role": "system", "content": SYSTEM_PROMPT}]
            + messages
            + [{"role": "user", "content": user_message_extended}]
        )
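        # Render the prompt with the model's chat template, then count its
        # tokens so the completion budget can fill the remaining context.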
        prompt = self._formatter(messages=messages).prompt
        tokens = self._tokenizer.encode(prompt, add_bos=False)
        n_tokens = len(tokens)
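        # Stream greedily (temperature=0) until EOS or the context is full.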
        response_generator = self._model.create_completion(
            tokens,
            stop=self._eos_token,
            max_tokens=self.n_ctx - n_tokens,
            stream=True,
            temperature=0,
        )
        return response_generator, n_tokens
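
# A minimal usage sketch, not part of the original module: it assumes a GGUF
# chat model at a hypothetical local path and streams the reply to stdout.
# With stream=True, create_completion yields chunks shaped like
# {"choices": [{"text": ...}]}.
if __name__ == "__main__":
    engine = LlamaCPPChatEngine("models/model.gguf")  # hypothetical path
    stream, n_prompt_tokens = engine.chat(
        messages=[],
        user_message="What does n_ctx=0 do in llama.cpp?",
        context=[],
    )
    print(f"prompt tokens: {n_prompt_tokens}")
    for chunk in stream:
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()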