import torch
import gradio as gr
from langchain.llms import LlamaCpp
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
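# torch is used only to report the runtime environment below; llama.cpp
# manages its own threads and GPU offloading.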
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
llm = LlamaCpp(
model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin',
temperature=0.5,
top_p=0.9,
max_tokens=80,
verbose=True,
n_ctx=2048,
n_gpu_layers=-1,
f16_kv=True
)
# Load the embedding model (multilingual-e5-large); it must match the model
# used to build the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# Load the prebuilt FAISS index (index.faiss / index.pkl) from the current directory.
docsearch = FAISS.load_local("", embeddings)
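# The index has to be built ahead of time with the same embedding model. A
# minimal sketch of how that might look, assuming the source material is a
# PDF of policy terms at the hypothetical path 'policy.pdf' (chunk sizes are
# illustrative):
#
#   from langchain.document_loaders import PyPDFLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   docs = PyPDFLoader("policy.pdf").load()
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=500, chunk_overlap=50).split_documents(docs)
#   FAISS.from_documents(chunks, embeddings).save_local("")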
# Filter retrieved chunks by embedding similarity to the query
# (threshold 0.7, top 2 results).
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
# Build the compression retriever.
compression_retriever = ContextualCompressionRetriever(
    # the embeddings filter trims what the base retriever returns
    base_compressor=embeddings_filter,
    # the base retriever fetches text similar to the search query
    base_retriever=docsearch.as_retriever()
)
# Per-user session state held in parallel module-level lists: id_list[i] is a
# user id and history[i] is that user's running transcript.
id_list = []
history = []
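# gen() is the single Gradio callback, dispatching on the user's input x.
# Reserved inputs: "μ΄ˆκΈ°ν™”" resets the caller's chat history and "κ°€μž…μ •λ³΄"
# re-prints the enrollment summary; anything else goes through a RetrievalQA chain.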
def gen(x, id, customer_data):
    # Locate this user's slot in the session lists, or create one.
    index = 0
    matched = 0
    for s in id_list:
        if s == id:
            matched = 1
            break
        index += 1
    if matched == 0:
        index = len(id_list)
        id_list.append(id)
history.append('상담원:무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?\n')
bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
return bot_str
else:
        if x == "μ΄ˆκΈ°ν™”":  # reserved keyword: reset this user's history
history[index] = '상담원:무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?\n'
bot_str = f"λŒ€ν™”κΈ°λ‘μ΄ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.\n\nν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
return bot_str
        elif x == "κ°€μž…μ •λ³΄":  # reserved keyword: re-show the enrollment info
bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
return bot_str
        else:
            # Local shadows keep the literal {context} / {question} placeholders
            # intact in the f-string below, so PromptTemplate can fill them later.
            context = "{context}"
            question = "{question}"
            customer_data_newline = customer_data.replace(",", "\n")
            # The Korean prompt below instructs the model to act as an insurance
            # agent: answer only about policies the customer holds, and for
            # uncovered claims, name and recommend the relevant policy.
prompt_template = f"""당신은 λ³΄ν—˜ μƒλ‹΄μ›μž…λ‹ˆλ‹€. μ•„λž˜μ— 질문과 κ΄€λ ¨λœ μ•½κ΄€ 정보, 응닡 지침과 고객의 λ³΄ν—˜ κ°€μž… 정보, 고객과의 상담기둝이 μ£Όμ–΄μ§‘λ‹ˆλ‹€. μš”μ²­μ„ 적절히 μ™„λ£Œν•˜λŠ” 응닡을 μž‘μ„±ν•˜μ„Έμš”.
{context}
### λͺ…λ Ήμ–΄:
λ‹€μŒ 지침을 μ°Έκ³ ν•˜μ—¬ μƒλ‹΄μ›μœΌλ‘œμ„œ κ³ κ°μ—κ²Œ ν•„μš”ν•œ 응닡을 μ œκ³΅ν•˜μ„Έμš”.
[지침]
1.고객의 κ°€μž… 정보λ₯Ό κΌ­ ν™•μΈν•˜μ—¬ 고객이 κ°€μž…ν•œ λ³΄ν—˜μ— λŒ€ν•œ λ‚΄μš©λ§Œ μ œκ³΅ν•˜μ„Έμš”.
2.고객이 κ°€μž…ν•œ λ³΄ν—˜μ΄λΌλ©΄ 고객의 μ§ˆλ¬Έμ— λŒ€ν•΄ 적절히 λ‹΅λ³€ν•˜μ„Έμš”.
3.고객이 κ°€μž…ν•˜μ§€ μ•Šμ€ λ³΄ν—˜μ˜ 보상에 κ΄€ν•œ μ§ˆλ¬Έμ€ κ΄€λ ¨ λ³΄ν—˜μ„ μ†Œκ°œν•˜λ©° 보상이 λΆˆκ°€λŠ₯ν•˜λ‹€λŠ” 점을 μ•ˆλ‚΄ν•˜μ„Έμš”.
4.고객이 κ°€μž…ν•˜μ§€ μ•Šμ€ λ³΄ν—˜μ€ κ°€μž…μ΄ ν•„μš”ν•˜λ‹€κ³  λ³΄ν—˜λͺ…을 ν™•μ‹€ν•˜κ²Œ μ–ΈκΈ‰ν•˜μ„Έμš”.
λ‹€μŒ μž…λ ₯에 μ£Όμ–΄μ§€λŠ” 고객의 λ³΄ν—˜ κ°€μž… 정보와 상담 기둝을 보고 κ³ κ°μ—κ²Œ λ„μ›€λ˜λŠ” 정보λ₯Ό μ œκ³΅ν•˜μ„Έμš”. μ°¨κ·Όμ°¨κ·Ό μƒκ°ν•˜μ—¬ λ‹΅λ³€ν•˜μ„Έμš”. 당신은 잘 ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
### μž…λ ₯:
[고객의 κ°€μž… 정보]
{customer_data_newline}
[상담 기둝]
{history[index]}
고객:{question}
### 응닡:
"""
            # Build the question-answering chain with RetrievalQA.from_chain_type.
            # It has to be recreated per call because the chat history is baked
            # into the prompt template string above.
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=compression_retriever,
return_source_documents=False,
verbose=True,
chain_type_kwargs={"prompt": PromptTemplate(
input_variables=["context","question"],
template=prompt_template,
)},
)
            # Prefix the query with the enrollment info so retrieval sees it too
            # ("I am currently enrolled only in {customer_data}. {x}").
            query = f"λ‚˜λŠ” ν˜„μž¬ {customer_data}만 κ°€μž…ν•œ 상황이야. {x}"
            response = qa({"query": query})
            # Trim the generation at the next "###" section marker and at any
            # zero-width space, then log the turn into this user's history.
            output_str = response['result'].split("###")[0].split("\u200b")[0]
            history[index] += f"고객:{x}\n상담원:{output_str}\n"
return output_str
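# A sample turn, using the placeholder values from the UI below:
#   gen("κ°€μž…μ •λ³΄", "772727", "(무)μˆ˜μˆ λΉ„λ³΄ν—˜")
#   -> "ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ (무)μˆ˜μˆ λΉ„λ³΄ν—˜μž…λ‹ˆλ‹€. ..."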
def reset_textbox():
    # Clears the input textbox; defined but not currently bound to any event.
    return gr.update(value='')
with gr.Blocks() as demo:
gr.Markdown(
"duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
)
with gr.Row():
with gr.Column(scale=4):
user_text = gr.Textbox(
placeholder='μž…λ ₯',
label="User input"
)
model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
button_submit = gr.Button(value="Submit")
with gr.Column(scale=1):
id_text = gr.Textbox(
placeholder='772727',
label="User id"
)
customer_data = gr.Textbox(
placeholder='(무)1λ…„λΆ€ν„°μ €μΆ•λ³΄ν—˜, (무)μˆ˜μˆ λΉ„λ³΄ν—˜',
label="customer_data"
)
button_submit.click(gen, [user_text, id_text, customer_data], model_output)
demo.queue().launch()