from threading import Thread
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import torch
import gradio as gr
import re
import asyncio
import requests
import shutil
from langchain.llms import LlamaCpp
from langchain import PromptTemplate, LLMChain
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
llm = LlamaCpp(
    model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin',
    temperature=0.5,
    top_p=0.9,
    max_tokens=80,
    verbose=True,
    n_ctx=2048,
    n_gpu_layers=-1,
    f16_kv=True
)
# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# Load the FAISS index from the local directory
docsearch = FAISS.load_local("", embeddings)
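# The index itself is built offline. A minimal sketch of how it could be produced with the
# imported loaders (the PDF filename and chunking parameters are assumptions, not from this Space):
#   pages = PyPDFLoader("policy_terms.pdf").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(pages)
#   FAISS.from_documents(chunks, embeddings).save_local("")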
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
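# The filter keeps at most the top-2 retrieved chunks, and only those whose embedding
# similarity to the query clears the 0.7 threshold, before they reach the prompt.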
# Create the compression retriever
compression_retriever = ContextualCompressionRetriever(
    # use embeddings_filter as the compressor
    base_compressor=embeddings_filter,
    # the base retriever finds text similar to the search query
    base_retriever=docsearch.as_retriever()
)
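# Per-user session state: id_list and history are parallel lists, so the index of a
# user id in id_list is also the index of that user's running conversation log.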
id_list = []
history = []
customer_data = ""
# Literal placeholders: the f-string prompt below interpolates these so that
# {context} and {question} survive for PromptTemplate to fill in at query time.
context = "{context}"
question = "{question}"
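# Chat handler: the first call from a new user id registers the id and greets the customer;
# the commands "reset" and "enrollment info" are handled directly; anything else is
# answered through the RetrievalQA chain built below.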
def gen(x, id, customer_data):
    index = 0
    matched = 0
    count = 0
    for s in id_list:
        if s == id:
            matched = 1
            break
        index += 1
    if matched == 0:
        index = len(id_list)
        id_list.append(id)
        history.append('Agent: How can I help you?\n')
        bot_str = f"The policies you are currently enrolled in are: {customer_data}.\n\nIs there anything you would like to ask?"
        return bot_str
    else:
        if x == "reset":
            history[index] = 'Agent: How can I help you?\n'
            bot_str = f"The conversation history has been reset.\n\nThe policies you are currently enrolled in are: {customer_data}.\n\nIs there anything you would like to ask?"
            return bot_str
        elif x == "enrollment info":
            bot_str = f"The policies you are currently enrolled in are: {customer_data}.\n\nIs there anything you would like to ask?"
            return bot_str
        else:
            context = "{context}"
            question = "{question}"
            customer_data_newline = customer_data.replace(",", "\n")
            prompt_template = f"""You are an insurance counselor. Below you are given policy-terms information related to the question, response guidelines, the customer's policy enrollment information, and the record of counseling with the customer. Write a response that appropriately completes the request.
{context}
### Instruction:
Referring to the following guidelines, provide the response the customer needs, acting as the counselor.
[Guidelines]
1. Always check the customer's enrollment information and only provide details about policies the customer has enrolled in.
2. If the policy is one the customer holds, answer the customer's question appropriately.
3. For questions about coverage under a policy the customer has not enrolled in, introduce the related policy and explain that the coverage is not available.
4. For policies the customer does not hold, clearly mention the name of the policy they would need to enroll in.
Looking at the customer's enrollment information and the counseling record given in the input below, provide information that helps the customer. Think through it step by step before answering. You can do this well.
### Input:
[Customer's enrollment information]
{customer_data_newline}
[Counseling record]
{history[index]}
Customer: {question}
### Response:
"""
            # Call RetrievalQA.from_chain_type to build the question-answering chain
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=compression_retriever,
                return_source_documents=False,
                verbose=True,
                chain_type_kwargs={"prompt": PromptTemplate(
                    input_variables=["context", "question"],
                    template=prompt_template,
                )},
            )
            query = f"I am currently enrolled only in {customer_data}. {x}"
            response = qa({"query": query})
            output_str = response['result'].split("###")[0].split("\u200b")[0]
            history[index] += f"Customer: {x}\nAgent: {output_str}\n"
            return output_str
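# A quick sanity check outside Gradio could look like this (the id and policy names are
# hypothetical examples, not values from this Space):
#   print(gen("What does my policy cover?", "772727", "policy A, policy B"))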
def reset_textbox():
    return gr.update(value='')
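# Gradio UI: user input and model output on the left, the user id and the customer's
# enrolled policies on the right; Submit routes everything through gen().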
with gr.Blocks() as demo:
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel: Llama-2-ko-7B-chat-gguf-q4_0"
    )
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='Input',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder='772727',
                label="User id"
            )
            customer_data = gr.Textbox(
                placeholder='comma-separated list of enrolled policies',
                label="customer_data"
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)
demo.queue().launch()