ldhldh committed
Commit 42efc58
1 Parent(s): 59b1a71

Update app.py

Files changed (1):
  1. app.py +35 -27
app.py CHANGED
@@ -7,50 +7,55 @@ import re
 import asyncio
 import requests
 import shutil
-from langchain.llms import LlamaCpp
 from langchain import PromptTemplate, LLMChain
 from langchain.retrievers.document_compressors import EmbeddingsFilter
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.chains import RetrievalQA
 from langchain.vectorstores import FAISS
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.prompts import PromptTemplate
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import FAISS
-from langchain.embeddings import HuggingFaceEmbeddings
+import os
+from langchain.llms import OpenAI
+
+llm = OpenAI(model_name='text-davinci-003')
 
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Running on device:", torch_device)
 print("CPU threads:", torch.get_num_threads())
 
-llm = LlamaCpp(
-    model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin',
-    temperature=0.5,
-    top_p=0.9,
-    max_tokens=80,
-    verbose=True,
-    n_ctx=2048,
-    n_gpu_layers=-1,
-    f16_kv=True
-)
+loader = PyPDFLoader("total.pdf")
+pages = loader.load()
+
+# Load the data, split the text into fixed-size chunks, and join them with a separator
+text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
+texts = text_splitter.split_documents(pages)
+
+print(f"λ¬Έμ„œμ— {len(texts)}개의 λ¬Έμ„œλ₯Ό 가지고 μžˆμŠ΅λ‹ˆλ‹€.")
 
 # Load the embedding model
 embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
-
+# Embed the document text and build a FAISS index over it
+index = FAISS.from_documents(
+    documents=texts,
+    embedding=embeddings,
+)
+
+# Save the FAISS index to local storage
+index.save_local("")
 # Load the FAISS index from local storage
 docsearch = FAISS.load_local("", embeddings)
 
 embeddings_filter = EmbeddingsFilter(
-    embeddings=embeddings,
+    embeddings=embeddings,
     similarity_threshold=0.7,
     k = 2,
 )
 # Create the compression retriever
 compression_retriever = ContextualCompressionRetriever(
     # set the embeddings_filter
-    base_compressor=embeddings_filter,
+    base_compressor=embeddings_filter,
     # call the retriever to find text similar to the search query
     base_retriever=docsearch.as_retriever()
 )
@@ -58,7 +63,7 @@ compression_retriever = ContextualCompressionRetriever(
 
 id_list = []
 history = []
-customer_data = ""
+customer_data_list = []
 context = "{context}"
 question = "{question}"
 
@@ -76,29 +81,31 @@ def gen(x, id, customer_data):
     if matched == 0:
         index = len(id_list)
         id_list.append(id)
+        customer_data_list.append(customer_data)
         history.append('상담원:무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?\n')
 
         bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
         return bot_str
     else:
         if x == "μ΄ˆκΈ°ν™”":
+            customer_data_list[index] = customer_data
             history[index] = '상담원:무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?\n'
-            bot_str = f"λŒ€ν™”κΈ°λ‘μ΄ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.\n\nν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
+            bot_str = f"λŒ€ν™”κΈ°λ‘μ΄ λͺ¨λ‘ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.\n\nν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
             return bot_str
         elif x == "κ°€μž…μ •λ³΄":
-            bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
+            bot_str = f"ν˜„μž¬ κ³ κ°λ‹˜κ»˜μ„œ κ°€μž…λœ λ³΄ν—˜μ€ {customer_data_list[index]}μž…λ‹ˆλ‹€.\n\nκΆκΈˆν•˜μ‹  것이 μžˆμœΌμ‹ κ°€μš”?"
             return bot_str
         else:
            context = "{context}"
            question = "{question}"
-            customer_data_newline = customer_data.replace(",","\n")
+            customer_data_newline = customer_data_list[index].replace(",","\n")
 
            prompt_template = f"""당신은 λ³΄ν—˜ μƒλ‹΄μ›μž…λ‹ˆλ‹€. μ•„λž˜μ— 질문과 κ΄€λ ¨λœ μ•½κ΄€ 정보, 응닡 지침과 고객의 λ³΄ν—˜ κ°€μž… 정보, 고객과의 상담기둝이 μ£Όμ–΄μ§‘λ‹ˆλ‹€. μš”μ²­μ„ 적절히 μ™„λ£Œν•˜λŠ” 응닡을 μž‘μ„±ν•˜μ„Έμš”.
 
 {context}
 
 ### λͺ…λ Ήμ–΄:
-λ‹€μŒ 지침을 μ°Έκ³ ν•˜μ—¬ μƒλ‹΄μ›μœΌλ‘œμ„œ κ³ κ°μ—κ²Œ ν•„μš”ν•œ 응닡을 μ œκ³΅ν•˜μ„Έμš”.
+λ‹€μŒ 지침을 μ°Έκ³ ν•˜μ—¬ μƒλ‹΄μ›μœΌλ‘œμ„œ κ³ κ°μ—κ²Œ ν•„μš”ν•œ 응닡을 κ°„κ²°ν•˜κ²Œ μ œκ³΅ν•˜μ„Έμš”.
 [지침]
 1.고객의 κ°€μž… 정보λ₯Ό κΌ­ ν™•μΈν•˜μ—¬ 고객이 κ°€μž…ν•œ λ³΄ν—˜μ— λŒ€ν•œ λ‚΄μš©λ§Œ μ œκ³΅ν•˜μ„Έμš”.
 2.고객이 κ°€μž…ν•œ λ³΄ν—˜μ΄λΌλ©΄ 고객의 μ§ˆλ¬Έμ— λŒ€ν•΄ 적절히 λ‹΅λ³€ν•˜μ„Έμš”.
@@ -119,19 +126,20 @@ def gen(x, id, customer_data):
 
            # Call the RetrievalQA class method from_chain_type to create a question-answering object
            qa = RetrievalQA.from_chain_type(
-                llm=llm,
+                llm=llm,
                chain_type="stuff",
-                retriever=compression_retriever,
+                retriever=compression_retriever,
                return_source_documents=False,
-                verbose=True,
+                verbose=True,
                chain_type_kwargs={"prompt": PromptTemplate(
                    input_variables=["context","question"],
                    template=prompt_template,
                )},
            )
-            query=f"λ‚˜λŠ” ν˜„μž¬ {customer_data}만 κ°€μž…ν•œ 상황이야. {x}"
+            query=f"λ‚˜λŠ” ν˜„μž¬ {customer_data_list[index]}만 κ°€μž…ν•œ 상황이야. {x}"
            response = qa({"query":query})
-            output_str = response['result'].split("###")[0].split("\u200b")[0]
+            output_str = response['result']
+            print(prompt_template + output_str)
            history[index] += f"고객:{x}\n상담원:{output_str}\n"
            return output_str
 def reset_textbox():
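One thing the diff itself does not show: unlike the removed local LlamaCpp model, langchain's OpenAI wrapper needs an API key, which it reads from the OPENAI_API_KEY environment variable (or an explicit openai_api_key= argument); that is presumably what the added import os is for. A minimal sketch of a guard one might place before the llm = OpenAI(...) line; the error message is illustrative, not part of the commit:

import os
from langchain.llms import OpenAI

# OpenAI() picks up OPENAI_API_KEY from the environment; on a Hugging Face
# Space this would typically be injected as a repository secret.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("Set OPENAI_API_KEY before constructing the OpenAI LLM")

llm = OpenAI(model_name='text-davinci-003')  # same construction as the commit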
 
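For context on the customer_data to customer_data_list change: subscription data is now stored per session, keyed by the position of the caller's id in id_list, so follow-up turns no longer depend on the customer_data value passed with each call. A hedged sketch of the resulting behavior; the id and insurance strings here are hypothetical examples, not from this commit:

# First call with a new id registers the session: the id joins id_list,
# the subscription string is stored in customer_data_list, and a fresh
# history entry is created.
print(gen("μ•ˆλ…•ν•˜μ„Έμš”", "user-1", "μ•”λ³΄ν—˜,μš΄μ „μžλ³΄ν—˜"))

# "κ°€μž…μ •λ³΄" now answers from the stored customer_data_list[index], not from
# the customer_data argument of the current call.
print(gen("κ°€μž…μ •λ³΄", "user-1", ""))

# "μ΄ˆκΈ°ν™”" resets the session history and overwrites the stored data.
print(gen("μ΄ˆκΈ°ν™”", "user-1", "μ•”λ³΄ν—˜"))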