import gradio as gr
from uuid import uuid4

from huggingface_hub import snapshot_download
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from chromadb.config import Settings
from llama_cpp import Llama


def get_uuid():
    # Factory for the conversation_id state; any unique-string source would do.
    return str(uuid4())
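

# SYSTEM_PROMPT is in Russian (Saiga is a Russian-language assistant); it
# translates roughly to: "You are Saiga, a Russian-language automatic
# assistant. You talk to people and help them." The token ids below are the
# hard-coded role markers of the Saiga prompt format.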
SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."

SYSTEM_TOKEN = 1788
USER_TOKEN = 1404
BOT_TOKEN = 9225
LINEBREAK_TOKEN = 13

ROLE_TOKENS = {
    "user": USER_TOKEN,
    "bot": BOT_TOKEN,
    "system": SYSTEM_TOKEN,
}
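
# Supported input formats: file extension -> (LangChain loader class, loader kwargs).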
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}
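

# Fetch the 4-bit quantized GGML weights from the Hugging Face Hub into the
# working directory.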
MODEL_NAME = "ggml-model-q4_1.bin"
snapshot_download(
    repo_id="IlyaGusev/saiga_7b_lora_llamacpp",
    local_dir=".",
    allow_patterns=MODEL_NAME,
)
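
# Load the quantized model with a 2000-token context window.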
model = Llama(
    model_path=MODEL_NAME,
    n_ctx=2000,
    n_parts=1,
)
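
# Sampling settings for llama.cpp generation and the chunk size used when
# splitting documents for indexing.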
max_new_tokens = 1500
top_k = 30
top_p = 0.9
temp = 0.1
repeat_penalty = 1.15
chunk_size = 300
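
# Multilingual sentence-transformer used to embed document chunks for retrieval.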
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
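

# Load a single document with the loader registered for its extension.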
def load_single_document(file_path: str) -> Document:
    ext = "." + file_path.rsplit(".", 1)[-1]
    assert ext in LOADER_MAPPING, f"Unsupported file extension: {ext}"
    loader_class, loader_args = LOADER_MAPPING[ext]
    loader = loader_class(file_path, **loader_args)
    return loader.load()[0]
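

# Saiga's chat format tokenizes each message as: <s> {role_token} \n {content} </s>.
# model.tokenize() prepends the BOS token by default, so the role token and
# the linebreak token are spliced in right after it.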
def get_message_tokens(model, role, content):
    message_tokens = model.tokenize(content.encode("utf-8"))
    message_tokens.insert(1, ROLE_TOKENS[role])
    message_tokens.insert(2, LINEBREAK_TOKEN)
    message_tokens.append(model.token_eos())
    return message_tokens


def get_system_tokens(model):
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    return get_message_tokens(model, **system_message)


def upload_files(files, file_paths):
    # Replace the stored paths with those of the newly uploaded files.
    file_paths = [f.name for f in files]
    return file_paths
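

# Build the retrieval index: load the uploaded files, split them into chunks,
# drop very short lines and fragments, then embed everything into an in-memory
# Chroma collection. Any previously built index is replaced.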
def build_index(file_paths, db):
    documents = [load_single_document(path) for path in file_paths]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=20)
    texts = text_splitter.split_documents(documents)

    def fix_lines(text):
        lines = text.split("\n")
        lines = [line for line in lines if len(line.strip()) > 2]
        return "\n".join(lines)

    fixed_texts = []
    for text in texts:
        text.page_content = fix_lines(text.page_content)
        if len(text.page_content) < 10:
            continue
        fixed_texts.append(text)

    db = Chroma.from_documents(
        fixed_texts,
        embeddings,
        client_settings=Settings(
            anonymized_telemetry=False
        )
    )
    return db
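

# Chat is wired in two steps, the usual Gradio streaming pattern: user()
# appends the new message to the history and clears the textbox, then bot()
# generates the answer and streams it back by yielding partial histories.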
def user(message, history, system_prompt):
    new_history = history + [[message, None]]
    return "", new_history


def bot(history, system_prompt, conversation_id, db):
    if not history:
        return

    # Start the prompt with the system message.
    tokens = get_system_tokens(model)[:]
    tokens.append(LINEBREAK_TOKEN)

    # Replay all previous turns of the conversation.
    for user_message, bot_message in history[:-1]:
        message_tokens = get_message_tokens(model=model, role="user", content=user_message)
        tokens.extend(message_tokens)
        if bot_message:
            message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
            tokens.extend(message_tokens)

    last_user_message = history[-1][0]
    if db:
        # Retrieve the two most relevant chunks and prepend them to the
        # question. The Russian template reads: "Context: {context}\n\n
        # Using the context, answer the question: {question}".
        retriever = db.as_retriever(search_kwargs={"k": 2})
        docs = retriever.get_relevant_documents(last_user_message)
        context = "\n\n".join([doc.page_content for doc in docs])
        last_user_message = f"Контекст: {context}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
    message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
    tokens.extend(message_tokens)

    # Open the bot's turn and stream tokens until EOS or the length cap.
    role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
    tokens.extend(role_tokens)
    generator = model.generate(
        tokens,
        top_k=top_k,
        top_p=top_p,
        temp=temp,
        repeat_penalty=repeat_penalty
    )

    completion_tokens = []
    partial_text = ""
    for i, token in enumerate(generator):
        completion_tokens.append(token)
        if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
            break
        partial_text = model.detokenize(completion_tokens).decode("utf-8", "ignore")
        history[-1][1] = partial_text
        yield history
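

# Gradio UI: uploading files builds the index; the chatbot then answers with
# retrieval-augmented generation over the indexed chunks.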
with gr.Blocks(
    theme=gr.themes.Soft()
) as demo:
    db = gr.State(None)
    conversation_id = gr.State(get_uuid)
    # Inline logo image for the page header.
    favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
    gr.Markdown(
        f"""<h1><center>{favicon}Saiga 7B Retrieval QA Llama.cpp</center></h1>"""
    )

    # "Системный промпт" = "System prompt".
    system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT)

    file_output = gr.File(file_count="multiple")
    file_paths = gr.State([])

    chatbot = gr.Chatbot().style(height=400)
    with gr.Row():
        with gr.Column():
            # "Отправить сообщение" = "Send a message".
            msg = gr.Textbox(
                label="Отправить сообщение",
                placeholder="Отправить сообщение",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                # Buttons: "Отправить" = "Send", "Остановить" = "Stop",
                # "Очистить" = "Clear".
                submit = gr.Button("Отправить")
                stop = gr.Button("Остановить")
                clear = gr.Button("Очистить")
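
    # Event wiring: lightweight UI updates run with queue=False, while
    # indexing and generation go through the queue.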
    upload_event = file_output.change(
        fn=upload_files,
        inputs=[file_output, file_paths],
        outputs=[file_paths],
        queue=False,
    ).then(
        fn=build_index,
        inputs=[file_paths, db],
        outputs=[db],
        queue=True,
    )

    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, system_prompt, conversation_id, db],
        outputs=chatbot,
        queue=True,
    )

    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot, system_prompt],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, system_prompt, conversation_id, db],
        outputs=chatbot,
        queue=True,
    )

    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=1)
demo.launch()