|
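# Flask service that streams completions from a llama.cpp (GGUF) build of
# Saiga2 70B. The weights are fetched from the Hugging Face Hub on startup,
# and responses are streamed back to the client token by token.
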
import logging
import threading

from flask import Flask, request, Response
from huggingface_hub import snapshot_download
from llama_cpp import Llama
|
# Default system prompt (Russian): "You are a Russian-speaking automatic
# assistant. You answer user requests as accurately as possible, in Russian."
SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."
|
# Special token ids of the Saiga2 chat format; each message is framed as
# <s>{role}\n{content}</s>.
SYSTEM_TOKEN = 1788
USER_TOKEN = 1404
BOT_TOKEN = 9225
LINEBREAK_TOKEN = 13

ROLE_TOKENS = {
    "user": USER_TOKEN,
    "bot": BOT_TOKEN,
    "system": SYSTEM_TOKEN,
}
|
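# A rendered dialogue therefore looks like:
#   <s>system\n{system prompt}</s>\n<s>user\n{query}</s><s>bot\n{reply}</s>...<s>bot\n
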
# Guards model access if the server is ever run with threading enabled;
# currently unused because app.run() below uses threaded=False.
lock = threading.Lock()

app = Flask(__name__)
app.logger.setLevel(logging.DEBUG)
|
repo_name = "IlyaGusev/saiga2_70b_gguf"
model_name = "ggml-model-q4_1.gguf"
|
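# Download the GGUF weights on startup so that model_path=model_name below
# resolves to a local file; without this the snapshot_download import is
# unused and Llama() would fail on a missing path.
snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=[model_name])


def load_model():
    # Build a fresh llama.cpp model. Both POST handlers construct their own
    # instance per request, which is slow for a 70B model but keeps each
    # request's context isolated in this single-threaded setup.
    return Llama(
        model_path=model_name,
        n_ctx=2000,
        n_parts=1,
        logits_all=True,
        verbose=True,
        n_gqa=8,  # grouped-query attention parameter used by LLaMA-2 70B builds
    )
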
def get_message_tokens(model, role, content):
    # model.tokenize() prepends BOS at index 0, so the role token and the
    # linebreak are inserted right after it: <s>{role}\n{content}</s>
    message_tokens = model.tokenize(content.encode("utf-8"))
    message_tokens.insert(1, ROLE_TOKENS[role])
    message_tokens.insert(2, LINEBREAK_TOKEN)
    message_tokens.append(model.token_eos())
    return message_tokens
|
def get_system_tokens(model):
    return get_message_tokens(model, role="system", content=SYSTEM_PROMPT)


def get_system_tokens_for_preprompt(model, preprompt):
    return get_message_tokens(model, role="system", content=preprompt)
|
|
# Cooperative cancellation flag checked inside generate_tokens.
stop_generation = False
|
def generate_tokens(model, generator):
    global stop_generation
    app.logger.info('generate_tokens started')
    for token in generator:
        if token == model.token_eos() or stop_generation:
            stop_generation = False
            app.logger.info('Abort generating')
            yield b''
            break
        # detokenize() returns raw UTF-8 bytes; a multi-byte character can be
        # split across tokens, so clients should decode the joined stream.
        yield model.detokenize([token])
|
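# GET /stop_generation: ask the streaming loop in generate_tokens to stop
# after the next token.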
@app.route('/stop_generation', methods=['GET'])
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')
|
@app.route('/', methods=['GET', 'PUT', 'DELETE', 'PATCH'])
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)
    try:
        # Log the parsed payload via a format argument; concatenating the
        # dict to a string would raise TypeError.
        app.logger.info('payload: %s', request.get_json())
    except Exception:
        app.logger.info('payload empty')
    return Response('What do you want?', content_type='text/plain')
|
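# POST /search_request: rewrite a free-form user query into a search-engine
# query, using a fixed near-deterministic sampling setup.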
@app.route('/search_request', methods=['POST'])
def generate_search_request():
    global stop_generation
    stop_generation = False

    data = request.get_json()
    app.logger.info(data)
    user_query = data.get("query", "")
    # Default preprompt (Russian): "You are a Russian-language automatic
    # assistant for writing search-engine queries in Russian. Reply to the
    # user's messages with only the text of a search query relevant to the
    # user's request. If the user's query is already good, use it as the
    # result."
    preprompt = data.get("preprompt", "Ты — русскоязычный автоматический ассистент для написания запросов для поисковых систем на русском языке. Отвечай на сообщения пользователя только текстом поискового запроса, релевантным запросу пользователя. Если запрос пользователя уже хорош, используй его в качестве результата.")
    parameters = data.get("parameters", {})

    # Sampling settings. temperature, top_p and top_k are pinned for
    # near-deterministic query rewriting; truncate, max_new_tokens and
    # return_full_text are read but not applied yet.
    temperature = 0.01
    truncate = parameters.get("truncate", 1000)
    max_new_tokens = parameters.get("max_new_tokens", 1024)
    top_p = 0.8
    repetition_penalty = parameters.get("repetition_penalty", 1.2)
    top_k = 20
    return_full_text = parameters.get("return_full_text", False)
|
    model = load_model()
|
    # Build the prompt: system preprompt, then the truncated user query, then
    # the cue tokens that make the model answer as the bot. Extend the list
    # rather than reassign it, so the preprompt is not discarded.
    tokens = get_system_tokens_for_preprompt(model, preprompt)
    tokens.append(LINEBREAK_TOKEN)
    tokens += get_message_tokens(model=model, role="user", content=user_query[:200])
    tokens += [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
|
    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temperature,
        repeat_penalty=repetition_penalty
    )

    # Stream the detokenized output as plain text while it is produced.
    return Response(generate_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
|
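# POST /: stream a chat completion for the message history in the request
# body.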
@app.route('/', methods=['POST'])
def generate_response():
    global stop_generation
    stop_generation = False

    data = request.get_json()
    app.logger.info(data)
    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Sampling settings. top_p is a probability and must lie in [0, 1].
    # truncate, max_new_tokens, return_full_text and preprompt are read but
    # not applied yet.
    temperature = 0.02
    truncate = parameters.get("truncate", 1000)
    max_new_tokens = parameters.get("max_new_tokens", 1024)
    top_p = 0.8
    repetition_penalty = parameters.get("repetition_penalty", 1.2)
    top_k = 25
    return_full_text = parameters.get("return_full_text", False)
|
    model = load_model()
|
    # Start the prompt with the system message. Do not reassign the list
    # afterwards, or the system prompt would be dropped.
    tokens = get_system_tokens(model)
    tokens.append(LINEBREAK_TOKEN)
|
    # Append each chat turn in Saiga format, mapping the client's
    # "assistant" role to the model's "bot" role token.
    for message in messages:
        if message.get("from") == "assistant":
            message_tokens = get_message_tokens(model=model, role="bot", content=message.get("content", ""))
        else:
            message_tokens = get_message_tokens(model=model, role="user", content=message.get("content", ""))
        tokens.extend(message_tokens)
|
    # Cue the model to answer as the bot: <s>bot\n
    tokens.extend([model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN])
|
    app.logger.info('Prompt:')
    app.logger.info(model.detokenize(tokens).decode("utf-8", errors="ignore"))

    app.logger.info('Generate started')
    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temperature,
        repeat_penalty=repetition_penalty
    )
    app.logger.info('Generator created')

    return Response(generate_tokens(model, generator), content_type='text/plain', status=200, direct_passthrough=True)
|
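# Example request (hypothetical payload, matching the parsing above):
#   curl -N -X POST http://localhost:7860/ -H 'Content-Type: application/json' \
#     -d '{"messages": [{"from": "user", "content": "Привет!"}]}'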
if __name__ == "__main__":
    # threaded=False serves one request at a time, so /stop_generation can
    # only take effect between requests; serving threaded would also require
    # guarding model access with `lock`.
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=False)