Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,50 +7,55 @@ import re
|
|
7 |
import asyncio
|
8 |
import requests
|
9 |
import shutil
|
10 |
-
from langchain.llms import LlamaCpp
|
11 |
from langchain import PromptTemplate, LLMChain
|
12 |
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
13 |
from langchain.retrievers import ContextualCompressionRetriever
|
14 |
from langchain.chains import RetrievalQA
|
15 |
from langchain.vectorstores import FAISS
|
16 |
from langchain.embeddings import HuggingFaceEmbeddings
|
17 |
-
from langchain.prompts import PromptTemplate
|
18 |
from langchain.text_splitter import CharacterTextSplitter
|
19 |
from langchain.document_loaders import PyPDFLoader
|
20 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
21 |
-
|
22 |
-
from langchain.
|
|
|
|
|
23 |
|
24 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
print("Running on device:", torch_device)
|
26 |
print("CPU threads:", torch.get_num_threads())
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
f16_kv=True
|
37 |
-
)
|
38 |
|
39 |
# μλ² λ© λͺ¨λΈ λ‘λ
|
40 |
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
# faiss_db λ‘ λ‘컬μ λ‘λνκΈ°
|
43 |
docsearch = FAISS.load_local("", embeddings)
|
44 |
|
45 |
embeddings_filter = EmbeddingsFilter(
|
46 |
-
embeddings=embeddings,
|
47 |
similarity_threshold=0.7,
|
48 |
k = 2,
|
49 |
)
|
50 |
# μμΆ κ²μκΈ° μμ±
|
51 |
compression_retriever = ContextualCompressionRetriever(
|
52 |
# embeddings_filter μ€μ
|
53 |
-
base_compressor=embeddings_filter,
|
54 |
# retriever λ₯Ό νΈμΆνμ¬ κ²μ쿼리μ μ μ¬ν ν
μ€νΈλ₯Ό μ°Ύμ
|
55 |
base_retriever=docsearch.as_retriever()
|
56 |
)
|
@@ -58,7 +63,7 @@ compression_retriever = ContextualCompressionRetriever(
|
|
58 |
|
59 |
id_list = []
|
60 |
history = []
|
61 |
-
|
62 |
context = "{context}"
|
63 |
question = "{question}"
|
64 |
|
@@ -76,29 +81,31 @@ def gen(x, id, customer_data):
|
|
76 |
if matched == 0:
|
77 |
index = len(id_list)
|
78 |
id_list.append(id)
|
|
|
79 |
history.append('μλ΄μ:무μμ λμλ릴κΉμ?\n')
|
80 |
|
81 |
bot_str = f"νμ¬ κ³ κ°λκ»μ κ°μ
λ 보νμ {customer_data}μ
λλ€.\n\nκΆκΈνμ κ²μ΄ μμΌμ κ°μ?"
|
82 |
return bot_str
|
83 |
else:
|
84 |
if x == "μ΄κΈ°ν":
|
|
|
85 |
history[index] = 'μλ΄μ:무μμ λμλ릴κΉμ?\n'
|
86 |
-
bot_str = f"λνκΈ°λ‘μ΄ μ΄κΈ°νλμμ΅λλ€.\n\nνμ¬ κ³ κ°λκ»μ κ°μ
λ 보νμ {customer_data}μ
λλ€.\n\nκΆκΈνμ κ²μ΄ μμΌμ κ°μ?"
|
87 |
return bot_str
|
88 |
elif x == "κ°μ
μ 보":
|
89 |
-
bot_str = f"νμ¬ κ³ κ°λκ»μ κ°μ
λ 보νμ {
|
90 |
return bot_str
|
91 |
else:
|
92 |
context = "{context}"
|
93 |
question = "{question}"
|
94 |
-
customer_data_newline =
|
95 |
|
96 |
prompt_template = f"""λΉμ μ 보ν μλ΄μμ
λλ€. μλμ μ§λ¬Έκ³Ό κ΄λ ¨λ μ½κ΄ μ 보, μλ΅ μ§μΉ¨κ³Ό κ³ κ°μ 보ν κ°μ
μ 보, κ³ κ°κ³Όμ μλ΄κΈ°λ‘μ΄ μ£Όμ΄μ§λλ€. μμ²μ μ μ ν μλ£νλ μλ΅μ μμ±νμΈμ.
|
97 |
|
98 |
{context}
|
99 |
|
100 |
### λͺ
λ Ήμ΄:
|
101 |
-
λ€μ μ§μΉ¨μ μ°Έκ³ νμ¬ μλ΄μμΌλ‘μ κ³ κ°μκ² νμν μλ΅μ μ 곡νμΈμ.
|
102 |
[μ§μΉ¨]
|
103 |
1.κ³ κ°μ κ°μ
μ 보λ₯Ό κΌ νμΈνμ¬ κ³ κ°μ΄ κ°μ
ν 보νμ λν λ΄μ©λ§ μ 곡νμΈμ.
|
104 |
2.κ³ κ°μ΄ κ°μ
ν 보νμ΄λΌλ©΄ κ³ κ°μ μ§λ¬Έμ λν΄ μ μ ν λ΅λ³νμΈμ.
|
@@ -119,19 +126,20 @@ def gen(x, id, customer_data):
|
|
119 |
|
120 |
# RetrievalQA ν΄λμ€μ from_chain_typeμ΄λΌλ ν΄λμ€ λ©μλλ₯Ό νΈμΆνμ¬ μ§μμοΏ½οΏ½ κ°μ²΄λ₯Ό μμ±
|
121 |
qa = RetrievalQA.from_chain_type(
|
122 |
-
llm=llm,
|
123 |
chain_type="stuff",
|
124 |
-
retriever=compression_retriever,
|
125 |
return_source_documents=False,
|
126 |
-
verbose=True,
|
127 |
chain_type_kwargs={"prompt": PromptTemplate(
|
128 |
input_variables=["context","question"],
|
129 |
template=prompt_template,
|
130 |
)},
|
131 |
)
|
132 |
-
query=f"λλ νμ¬ {
|
133 |
response = qa({"query":query})
|
134 |
-
output_str = response['result']
|
|
|
135 |
history[index] += f"κ³ κ°:{x}\nμλ΄μ:{output_str}\n"
|
136 |
return output_str
|
137 |
def reset_textbox():
|
|
|
7 |
import asyncio
|
8 |
import requests
|
9 |
import shutil
|
|
|
10 |
from langchain import PromptTemplate, LLMChain
|
11 |
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
12 |
from langchain.retrievers import ContextualCompressionRetriever
|
13 |
from langchain.chains import RetrievalQA
|
14 |
from langchain.vectorstores import FAISS
|
15 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
16 |
from langchain.text_splitter import CharacterTextSplitter
|
17 |
from langchain.document_loaders import PyPDFLoader
|
18 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
19 |
+
import os
|
20 |
+
from langchain.llms import OpenAI
|
21 |
+
|
22 |
+
llm = OpenAI(model_name='text-davinci-003')
|
23 |
|
24 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
print("Running on device:", torch_device)
|
26 |
print("CPU threads:", torch.get_num_threads())
|
27 |
|
28 |
+
loader = PyPDFLoader("total.pdf")
|
29 |
+
pages = loader.load()
|
30 |
+
|
31 |
+
# λ°μ΄ν°λ₯Ό λΆλ¬μμ ν
μ€νΈλ₯Ό μΌμ ν μλ‘ λλκ³ κ΅¬λΆμλ‘ μ°κ²°νλ μμ
|
32 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
|
33 |
+
texts = text_splitter.split_documents(pages)
|
34 |
+
|
35 |
+
print(f"λ¬Έμμ {len(texts)}κ°μ λ¬Έμλ₯Ό κ°μ§κ³ μμ΅λλ€.")
|
|
|
|
|
36 |
|
37 |
# μλ² λ© λͺ¨λΈ λ‘λ
|
38 |
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
|
39 |
+
# λ¬Έμμ μλ ν
μ€νΈλ₯Ό μλ² λ©νκ³ FAISS μ μΈλ±μ€λ₯Ό ꡬμΆν¨
|
40 |
+
index = FAISS.from_documents(
|
41 |
+
documents=texts,
|
42 |
+
embedding=embeddings,
|
43 |
+
)
|
44 |
+
|
45 |
+
# faiss_db λ‘ λ‘컬μ μ μ₯νκΈ°
|
46 |
+
index.save_local("")
|
47 |
# faiss_db λ‘ λ‘컬μ λ‘λνκΈ°
|
48 |
docsearch = FAISS.load_local("", embeddings)
|
49 |
|
50 |
embeddings_filter = EmbeddingsFilter(
|
51 |
+
embeddings=embeddings,
|
52 |
similarity_threshold=0.7,
|
53 |
k = 2,
|
54 |
)
|
55 |
# μμΆ κ²μκΈ° μμ±
|
56 |
compression_retriever = ContextualCompressionRetriever(
|
57 |
# embeddings_filter μ€μ
|
58 |
+
base_compressor=embeddings_filter,
|
59 |
# retriever λ₯Ό νΈμΆνμ¬ κ²μ쿼리μ μ μ¬ν ν
μ€νΈλ₯Ό μ°Ύμ
|
60 |
base_retriever=docsearch.as_retriever()
|
61 |
)
|
|
|
63 |
|
64 |
id_list = []
|
65 |
history = []
|
66 |
+
customer_data_list = []
|
67 |
context = "{context}"
|
68 |
question = "{question}"
|
69 |
|
|
|
81 |
if matched == 0:
|
82 |
index = len(id_list)
|
83 |
id_list.append(id)
|
84 |
+
customer_data_list.append(customer_data)
|
85 |
history.append('μλ΄μ:무μμ λμλ릴κΉμ?\n')
|
86 |
|
87 |
bot_str = f"νμ¬ κ³ κ°λκ»μ κ°μ
λ 보νμ {customer_data}μ
λλ€.\n\nκΆκΈνμ κ²μ΄ μμΌμ κ°μ?"
|
88 |
return bot_str
|
89 |
else:
|
90 |
if x == "μ΄κΈ°ν":
|
91 |
+
customer_data_list[index] = customer_data
|
92 |
history[index] = 'μλ΄μ:무μμ λμλ릴κΉμ?\n'
|
93 |
+
bot_str = f"λνκΈ°λ‘μ΄ λͺ¨λ μ΄κΈ°νλμμ΅λλ€.\n\nνμ¬ κ³ κ°λκ»μ κ°μ
λ 보νμ {customer_data}μ
λλ€.\n\nκΆκΈνμ κ²μ΄ μμΌμ κ°μ?"
|
94 |
return bot_str
|
95 |
elif x == "κ°μ
μ 보":
|
96 |
+
bot_str = f"νμ¬ κ³ κ°λκ»μ κ°μ
λ 보νμ {customer_data_list[index]}μ
λλ€.\n\nκΆκΈνμ κ²μ΄ μμΌμ κ°μ?"
|
97 |
return bot_str
|
98 |
else:
|
99 |
context = "{context}"
|
100 |
question = "{question}"
|
101 |
+
customer_data_newline = customer_data_list[index].replace(",","\n")
|
102 |
|
103 |
prompt_template = f"""λΉμ μ 보ν μλ΄μμ
λλ€. μλμ μ§λ¬Έκ³Ό κ΄λ ¨λ μ½κ΄ μ 보, μλ΅ μ§μΉ¨κ³Ό κ³ κ°μ 보ν κ°μ
μ 보, κ³ κ°κ³Όμ μλ΄κΈ°λ‘μ΄ μ£Όμ΄μ§λλ€. μμ²μ μ μ ν μλ£νλ μλ΅μ μμ±νμΈμ.
|
104 |
|
105 |
{context}
|
106 |
|
107 |
### λͺ
λ Ήμ΄:
|
108 |
+
λ€μ μ§μΉ¨μ μ°Έκ³ νμ¬ μλ΄μμΌλ‘μ κ³ κ°μκ² νμν μλ΅μ κ°κ²°νκ² μ 곡νμΈμ.
|
109 |
[μ§μΉ¨]
|
110 |
1.κ³ κ°μ κ°μ
μ 보λ₯Ό κΌ νμΈνμ¬ κ³ κ°μ΄ κ°μ
ν 보νμ λν λ΄μ©λ§ μ 곡νμΈμ.
|
111 |
2.κ³ κ°μ΄ κ°μ
ν 보νμ΄λΌλ©΄ κ³ κ°μ μ§λ¬Έμ λν΄ μ μ ν λ΅λ³νμΈμ.
|
|
|
126 |
|
127 |
# RetrievalQA ν΄λμ€μ from_chain_typeμ΄λΌλ ν΄λμ€ λ©μλλ₯Ό νΈμΆνμ¬ μ§μμοΏ½οΏ½ κ°μ²΄λ₯Ό μμ±
|
128 |
qa = RetrievalQA.from_chain_type(
|
129 |
+
llm=llm,
|
130 |
chain_type="stuff",
|
131 |
+
retriever=compression_retriever,
|
132 |
return_source_documents=False,
|
133 |
+
verbose=True,
|
134 |
chain_type_kwargs={"prompt": PromptTemplate(
|
135 |
input_variables=["context","question"],
|
136 |
template=prompt_template,
|
137 |
)},
|
138 |
)
|
139 |
+
query=f"λλ νμ¬ {customer_data_list[index]}λ§ κ°μ
ν μν©μ΄μΌ. {x}"
|
140 |
response = qa({"query":query})
|
141 |
+
output_str = response['result']
|
142 |
+
print(prompt_template + output_str)
|
143 |
history[index] += f"κ³ κ°:{x}\nμλ΄μ:{output_str}\n"
|
144 |
return output_str
|
145 |
def reset_textbox():
|