from threading import Thread
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import torch
import gradio as gr
import re
import asyncio
import requests
import shutil
from langchain.llms import LlamaCpp
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

llm = LlamaCpp(
    model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin',
    temperature=0.5,
    top_p=0.9,
    max_tokens=80,
    verbose=True,
    n_ctx=2048,
    n_gpu_layers=-1,
    f16_kv=True
)
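
# The GGUF model file above is assumed to already sit next to this script. A minimal
# sketch of how it could be fetched with the already-imported hf_hub_download (the
# repo_id below is a hypothetical placeholder, not confirmed by this file):
#
# model_path = hf_hub_download(
#     repo_id="...",  # repository hosting the quantized model
#     filename="Llama-2-ko-7B-chat-gguf-q4_0.bin",
# )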

# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")

# Load the FAISS index from local storage
docsearch = FAISS.load_local("", embeddings)
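# The empty path above is kept from the original; FAISS.load_local expects the directory
# of a previously saved index (e.g. "faiss_db"). A minimal sketch of how such an index
# could be built with the imports already present (the PDF filename and chunk sizes are
# illustrative assumptions, not taken from this file):
#
# pages = PyPDFLoader("policy.pdf").load()
# chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(pages)
# FAISS.from_documents(chunks, embeddings).save_local("faiss_db")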

# Keep only documents similar enough to the query (top-2 above the threshold)
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
# Create the compression retriever
compression_retriever = ContextualCompressionRetriever(
    # use embeddings_filter as the compressor
    base_compressor=embeddings_filter,
    # the base retriever finds text similar to the search query
    base_retriever=docsearch.as_retriever()
)
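
# Hypothetical standalone usage of the compression retriever (not part of the original
# flow, which only passes it to RetrievalQA below):
#
# docs = compression_retriever.get_relevant_documents("query text")
# for d in docs:
#     print(d.page_content)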


# Per-user session state: id_list holds user ids, history the matching dialogue logs
id_list = []
history = []
customer_data = ""
# Literal "{context}" / "{question}" strings so the placeholders survive the f-string
# that builds the prompt template below
context = "{context}"
question = "{question}"

def gen(x, id, customer_data):
    # Find this user's slot in id_list (index of a previous conversation, if any)
    index = 0
    matched = 0
    for s in id_list:
        if s == id:
            matched = 1
            break
        index += 1

    if matched == 0:
        # New user: start a fresh conversation history
        index = len(id_list)
        id_list.append(id)
        history.append('상담원:무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?\n')  # "Agent: How can I help you?"

        bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
        return bot_str
    else:
        if x == "μ΄ˆκΈ°ν™”":  # "reset": clear this user's dialogue history
            history[index] = '상담원:무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?\n'
            bot_str = f"λŒ€ν™”κΈ°λ‘μ΄ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.\n\nν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
            return bot_str
        elif x == "κ°€μž…μ •λ³΄":  # "enrollment info": repeat the customer's current policies
            bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
            return bot_str
        else:
            # Keep "{context}" and "{question}" literal so the PromptTemplate can fill them
            context = "{context}"
            question = "{question}"
            customer_data_newline = customer_data.replace(",", "\n")

            prompt_template = f"""당신은 λ³΄ν—˜ μƒλ‹΄μ›μž…λ‹ˆλ‹€. μ•„λž˜μ— 질문과 κ΄€λ ¨λœ μ•½κ΄€ 정보, 응닡 지침과 고객의 λ³΄ν—˜ κ°€μž… 정보, 고객과의 상담기둝이 μ£Όμ–΄μ§‘λ‹ˆλ‹€. μš”μ²­μ„ 적절히 μ™„λ£Œν•˜λŠ” 응닡을 μž‘μ„±ν•˜μ„Έμš”.

{context}

### λͺ…λ Ήμ–΄:
λ‹€μŒ 지침을 μ°Έκ³ ν•˜μ—¬ μƒλ‹΄μ›μœΌλ‘œμ„œ κ³ κ°μ—κ²Œ ν•„μš”ν•œ 응닡을 μ œκ³΅ν•˜μ„Έμš”.
[지침]
1.고객의 κ°€μž… 정보λ₯Ό κΌ­ ν™•μΈν•˜μ—¬ 고객이 κ°€μž…ν•œ λ³΄ν—˜μ— λŒ€ν•œ λ‚΄μš©λ§Œ μ œκ³΅ν•˜μ„Έμš”.
2.고객이 κ°€μž…ν•œ λ³΄ν—˜μ΄λΌλ©΄ 고객의 μ§ˆλ¬Έμ— λŒ€ν•΄ 적절히 λ‹΅λ³€ν•˜μ„Έμš”.
3.고객이 κ°€μž…ν•˜μ§€ μ•Šμ€ λ³΄ν—˜μ˜ 보상에 κ΄€ν•œ μ§ˆλ¬Έμ€ κ΄€λ ¨ λ³΄ν—˜μ„ μ†Œκ°œν•˜λ©° 보상이 λΆˆκ°€λŠ₯ν•˜λ‹€λŠ” 점을 μ•ˆλ‚΄ν•˜μ„Έμš”.
4.고객이 κ°€μž…ν•˜μ§€ μ•Šμ€ λ³΄ν—˜μ€ κ°€μž…μ΄ ν•„μš”ν•˜λ‹€κ³  λ³΄ν—˜λͺ…을 ν™•μ‹€ν•˜κ²Œ μ–ΈκΈ‰ν•˜μ„Έμš”.
λ‹€μŒ μž…λ ₯에 μ£Όμ–΄μ§€λŠ” 고객의 λ³΄ν—˜ κ°€μž… 정보와 상담 기둝을 보고 κ³ κ°μ—κ²Œ λ„μ›€λ˜λŠ” 정보λ₯Ό μ œκ³΅ν•˜μ„Έμš”. μ°¨κ·Όμ°¨κ·Ό μƒκ°ν•˜μ—¬ λ‹΅λ³€ν•˜μ„Έμš”. 당신은 잘 ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

### μž…λ ₯:
[고객의 κ°€μž… 정보]
{customer_data_newline}

[상담 기둝]
{history[index]}
고객:{question}

### 응닡:
"""

            # Build the question-answering chain via RetrievalQA.from_chain_type
            qa = RetrievalQA.from_chain_type(
              llm=llm, 
              chain_type="stuff",
              retriever=compression_retriever, 
              return_source_documents=False,
              verbose=True, 
              chain_type_kwargs={"prompt": PromptTemplate(
                  input_variables=["context","question"],
                  template=prompt_template,
              )},
            )
            # Prefix the question with the customer's current enrollment, then run the chain
            query = f"λ‚˜λŠ” ν˜„μž¬ {customer_data}만 κ°€μž…ν•œ 상황이야. {x}"
            response = qa({"query": query})
            # Keep only the text before any stray "###" marker or zero-width space
            output_str = response['result'].split("###")[0].split("\u200b")[0]
            history[index] += f"고객:{x}\n상담원:{output_str}\n"
            return output_str


def reset_textbox():
    return gr.update(value='')
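
# reset_textbox is defined but never wired into the UI below; a minimal sketch of how it
# could be hooked up to clear the input box (the clear button is hypothetical, not in
# the original file):
#
# button_clear = gr.Button(value="Clear")
# button_clear.click(reset_textbox, None, user_text)
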
with gr.Blocks() as demo:
    gr.Markdown(
       "duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
    )

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='μž…λ ₯',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder='772727',
                label="User id"
            )
            customer_data = gr.Textbox(
                placeholder='(무)1λ…„λΆ€ν„°μ €μΆ•λ³΄ν—˜, (무)μˆ˜μˆ λΉ„λ³΄ν—˜',
                label="customer_data"
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)
    demo.queue().launch()