anpigon commited on
Commit
7935c3a
1 Parent(s): 7b40096

Refactored the app.py file to improve structure and readability.

Browse files
Files changed (2) hide show
  1. app.py +95 -150
  2. prompt_template.py +28 -0
app.py CHANGED
@@ -1,159 +1,99 @@
1
  import os
2
  import gradio as gr
3
-
4
  from langchain_community.document_loaders import ObsidianLoader
5
  from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
6
-
7
  from langchain.embeddings import CacheBackedEmbeddings
8
  from langchain.storage import LocalFileStore
9
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
10
  from langchain_community.vectorstores import FAISS
11
-
12
  from langchain_community.retrievers import BM25Retriever
13
  from langchain.retrievers import EnsembleRetriever
14
-
15
  from langchain_cohere import CohereRerank
16
  from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
17
-
18
- from langchain_core.prompts import PromptTemplate
19
-
20
- from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
21
- from langchain_core.callbacks.manager import CallbackManager
22
- from langchain_core.runnables import ConfigurableField
23
- from langchain.callbacks.base import BaseCallbackHandler
24
  from langchain_core.output_parsers import StrOutputParser
25
- from langchain_core.runnables import RunnablePassthrough
26
  from langchain_groq import ChatGroq
27
- from langchain_community.llms import HuggingFaceHub
28
  from langchain_google_genai import GoogleGenerativeAI
29
- import platform
30
 
 
31
 
32
- directories = ["./docs/obsidian-help", "./docs/obsidian-developer"]
33
-
34
-
35
- # 1. 문서 로더를 사용하여 모든 .md 파일을 로드합니다.
36
- md_docs = []
37
- for directory in directories:
38
- try:
39
- loader = ObsidianLoader(directory, encoding="utf-8")
40
- md_docs.extend(loader.load())
41
- except Exception:
42
- pass
43
-
44
-
45
- # 2. 청크 분할기를 생성합니다.
46
- # 청크 크기는 2000, 청크간 겹치는 부분은 200 문자로 설정합니다.
47
- md_splitter = RecursiveCharacterTextSplitter.from_language(
48
- language=Language.MARKDOWN,
49
- chunk_size=2000,
50
- chunk_overlap=200,
51
- )
52
- splitted_docs = md_splitter.split_documents(md_docs)
53
-
54
-
55
- # 3. 임베딩 모델을 사용하여 문서의 임베딩을 계산합니다.
56
- # 허깅페이스 임베딩 모델 인스턴스를 생성합니다. 모델명으로 "BAAI/bge-m3 "을 사용합니다.
57
- if platform.system() == "Darwin":
58
- model_kwargs = {"device": "mps"}
59
- else:
60
- model_kwargs = {"device": "cpu"}
61
- model_name = "BAAI/bge-m3"
62
- encode_kwargs = {"normalize_embeddings": True}
63
- embeddings = HuggingFaceBgeEmbeddings(
64
- model_name=model_name,
65
- model_kwargs=model_kwargs,
66
- encode_kwargs=encode_kwargs,
67
- )
68
-
69
- # CacheBackedEmbeddings를 사용하여 임베딩 계산 결과를 캐시합니다.
70
- store = LocalFileStore("./.cache/")
71
- cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
72
- embeddings,
73
- store,
74
- namespace=embeddings.model_name,
75
- )
76
-
77
- # 4. FAISS 벡터 데이터베이스 인덱스를 생성하고 저장합니다.
78
  FAISS_DB_INDEX = "db_index"
79
 
80
- if os.path.exists(FAISS_DB_INDEX):
81
- # 저장된 데이터베이스 인덱스가 이미 존재하는 경우, 해당 인덱스를 로드합니다.
82
- db = FAISS.load_local(
83
- FAISS_DB_INDEX, # 로드할 FAISS 인덱스의 디렉토리 이름
84
- cached_embeddings, # 임베딩 정보를 제공
85
- allow_dangerous_deserialization=True, # 역직렬화를 허용하는 옵션
 
 
 
 
 
 
 
86
  )
87
- else:
88
- # combined_documents 문서들과 cached_embeddings 임베딩을 사용하여
89
- # FAISS 데이터베이스 인스턴스를 생성합니다.
90
- db = FAISS.from_documents(splitted_docs, cached_embeddings)
91
- # 생성된 데이터베이스 인스턴스를 지정한 폴더에 로컬로 저장합니다.
92
- db.save_local(folder_path=FAISS_DB_INDEX)
93
-
94
-
95
- # 5. Retrieval를 생성합니다.
96
- faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
97
-
98
- # 문서 컬렉션을 사용하여 BM25 검색 모델 인스턴스를 생성합니다.
99
- bm25_retriever = BM25Retriever.from_documents(splitted_docs) # 초기화에 사용할 문서 컬렉션
100
- bm25_retriever.k = 10 # 검색 시 최대 10개의 결과를 반환하도록 합니다.
101
-
102
- # EnsembleRetriever 인스턴스를 생성합니다.
103
- ensemble_retriever = EnsembleRetriever(
104
- retrievers=[bm25_retriever, faiss_retriever], # 사용할 검색 모델의 리스트
105
- weights=[0.5, 0.5], # 각 검색 모델의 결과에 적용할 가중치
106
- search_type="mmr", # 검색 결과의 다양성을 증진시키는 MMR 방식을 사용
107
- )
108
-
109
- # 6. CohereRerank 모델을 사용하여 재정렬을 수행합니다.
110
- compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
111
- compression_retriever = ContextualCompressionRetriever(
112
- base_compressor=compressor,
113
- base_retriever=ensemble_retriever,
114
- )
115
-
116
- # 7. Prompt를 생성합니다.
117
- prompt = PromptTemplate.from_template(
118
- """당신은 20년 경력의 옵시디언 노트앱 및 플러그인 개발 전문가로, 옵시디언 노트앱 사용법, 플러그인 및 테마 개발에 대한 깊은 지식을 가지고 있습니다. 당신의 주된 임무는 제공된 문서를 바탕으로 질문에 최대한 정확하고 상세하게 답변하는 것입니다.
119
- 문서에는 옵시디언 노트앱의 기본 사용법, 고급 기능, 플러그인 개발 방법, 테마 개발 가이드 등 옵시디언 노트앱을 깊이 있게 사용하고 확장하는 데 필요한 정보가 포함되어 있습니다.
120
- 귀하의 답변은 다음 지침에 따라야 합니다:
121
- 1. 모든 답변은 명확하고 이해하기 쉬운 한국어로 제공되어야 합니다.
122
- 2. 답변은 문서의 내용을 기반으로 해야 하며, 가능한 한 구체적인 정보를 포함해야 합니다.
123
- 3. 문서 내에서 직접적인 답변을 찾을 수 없는 경우, "문서에는 해당 질문에 대한 구체적인 답변이 없습니다."라고 명시해 주세요.
124
- 4. 가능한 경우, 답변과 관련된 문서의 구체적인 부분(예: 섹션 이름, 페이지 번호 등)을 출처로서 명시해 주세요.
125
- 5. 질문에 대한 답변이 문서에 부분적으로만 포함되어 있는 경우, 가능한 한 많은 정보를 종합하여 답변해 주세요. 또한, 추가적인 연구나 참고자료가 필요할 수 있음을 언급해 주세요.
126
-
127
- #참고문서:
128
- \"\"\"
129
- {context}
130
- \"\"\"
131
-
132
- #질문:
133
- {question}
134
-
135
- #답변:
136
-
137
- 출처:
138
- - source1
139
- - source2
140
- - ...
141
- """
142
- )
143
-
144
-
145
- # 7. chain를 생성합니다.
146
- llm = ChatGroq(
147
- model_name="llama3-70b-8192",
148
- temperature=0,
149
- ).configurable_alternatives(
150
- ConfigurableField(id="llm"),
151
- default_key="llama3",
152
- gemini=GoogleGenerativeAI(
153
- model="gemini-pro",
154
  temperature=0,
155
- ),
156
- )
 
 
 
 
 
 
157
 
158
  def format_docs(docs):
159
  formatted_docs = []
@@ -164,21 +104,26 @@ def format_docs(docs):
164
  formatted_docs.append(formatted_doc)
165
  return "\n---\n".join(formatted_docs)
166
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- rag_chain = (
169
- {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
170
- | prompt
171
- | llm
172
- | StrOutputParser()
173
- )
174
 
175
- # # 8. chain를 실행합니다.
176
- def predict(message, history=None):
177
- answer = rag_chain.invoke(message)
178
- return answer
 
179
 
180
- gr.ChatInterface(
181
- predict,
182
- title="옵시디언 노트앱 및 플러그인 개발에 대해서 물어보세요!",
183
- description="안녕하세요!\n저는 옵시디언 노트앱과 플러그인 개발에 대한 인공지능 QA봇입니다. 옵시디언 노트앱의 사용법, 고급 기능, 플러그인 및 테마 개발에 대해 깊은 지식을 가지고 있어요. 문서 작업, 정보 정리 또는 개발에 관한 도움이 필요하시면 언제든지 질문해주세요!",
184
- ).launch()
 
1
  import os
2
  import gradio as gr
3
+ import platform
4
  from langchain_community.document_loaders import ObsidianLoader
5
  from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
 
6
  from langchain.embeddings import CacheBackedEmbeddings
7
  from langchain.storage import LocalFileStore
8
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
9
  from langchain_community.vectorstores import FAISS
 
10
  from langchain_community.retrievers import BM25Retriever
11
  from langchain.retrievers import EnsembleRetriever
 
12
  from langchain_cohere import CohereRerank
13
  from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
14
+ from langchain_core.runnables import ConfigurableField, RunnablePassthrough
 
 
 
 
 
 
15
  from langchain_core.output_parsers import StrOutputParser
 
16
  from langchain_groq import ChatGroq
 
17
  from langchain_google_genai import GoogleGenerativeAI
 
18
 
19
+ from prompt_template import PROMPT_TEMPLATE
20
 
21
+ DIRECTORIES = ["./docs/obsidian-help", "./docs/obsidian-developer"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  FAISS_DB_INDEX = "db_index"
23
 
24
+ def load_and_process_documents(directories):
25
+ md_docs = []
26
+ for directory in directories:
27
+ try:
28
+ loader = ObsidianLoader(directory, encoding="utf-8")
29
+ md_docs.extend(loader.load())
30
+ except Exception:
31
+ pass
32
+
33
+ md_splitter = RecursiveCharacterTextSplitter.from_language(
34
+ language=Language.MARKDOWN,
35
+ chunk_size=2000,
36
+ chunk_overlap=200,
37
  )
38
+ return md_splitter.split_documents(md_docs)
39
+
40
+ def setup_retrieval_system(splitted_docs):
41
+ if platform.system() == "Darwin":
42
+ model_kwargs = {"device": "mps"}
43
+ else:
44
+ model_kwargs = {"device": "cpu"}
45
+
46
+ embeddings = HuggingFaceBgeEmbeddings(
47
+ model_name="BAAI/bge-m3",
48
+ model_kwargs=model_kwargs,
49
+ encode_kwargs={"normalize_embeddings": True},
50
+ )
51
+
52
+ store = LocalFileStore("./.cache/")
53
+ cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
54
+ embeddings,
55
+ store,
56
+ namespace=embeddings.model_name,
57
+ )
58
+
59
+ if os.path.exists(FAISS_DB_INDEX):
60
+ db = FAISS.load_local(
61
+ FAISS_DB_INDEX,
62
+ cached_embeddings,
63
+ allow_dangerous_deserialization=True,
64
+ )
65
+ else:
66
+ db = FAISS.from_documents(splitted_docs, cached_embeddings)
67
+ db.save_local(folder_path=FAISS_DB_INDEX)
68
+
69
+ faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
70
+ bm25_retriever = BM25Retriever.from_documents(splitted_docs)
71
+ bm25_retriever.k = 10
72
+
73
+ ensemble_retriever = EnsembleRetriever(
74
+ retrievers=[bm25_retriever, faiss_retriever],
75
+ weights=[0.5, 0.5],
76
+ search_type="mmr",
77
+ )
78
+
79
+ compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
80
+ return ContextualCompressionRetriever(
81
+ base_compressor=compressor,
82
+ base_retriever=ensemble_retriever,
83
+ )
84
+
85
+ def setup_language_model():
86
+ return ChatGroq(
87
+ model_name="llama3-70b-8192",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  temperature=0,
89
+ ).configurable_alternatives(
90
+ ConfigurableField(id="llm"),
91
+ default_key="llama3",
92
+ gemini=GoogleGenerativeAI(
93
+ model="gemini-pro",
94
+ temperature=0,
95
+ ),
96
+ )
97
 
98
  def format_docs(docs):
99
  formatted_docs = []
 
104
  formatted_docs.append(formatted_doc)
105
  return "\n---\n".join(formatted_docs)
106
 
107
+ def main():
108
+ splitted_docs = load_and_process_documents(DIRECTORIES)
109
+ compression_retriever = setup_retrieval_system(splitted_docs)
110
+ llm = setup_language_model()
111
+
112
+ rag_chain = (
113
+ {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
114
+ | PROMPT_TEMPLATE
115
+ | llm
116
+ | StrOutputParser()
117
+ )
118
 
119
+ def predict(message, history=None):
120
+ return rag_chain.invoke(message)
 
 
 
 
121
 
122
+ gr.ChatInterface(
123
+ predict,
124
+ title="옵시디언 노트앱 및 플러그인 개발에 대해서 물어보세요!",
125
+ description="안녕하세요!\n저는 옵시디언 노트앱과 플러그인 개발에 대한 인공지능 QA봇���니다. 옵시디언 노트앱의 사용법, 고급 기능, 플러그인 및 테마 개발에 대해 깊은 지식을 가지고 있어요. 문서 작업, 정보 정리 또는 개발에 관한 도움이 필요하시면 언제든지 질문해주세요!",
126
+ ).launch()
127
 
128
+ if __name__ == "__main__":
129
+ main()
 
 
 
prompt_template.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import PromptTemplate
2
+
3
+ PROMPT_TEMPLATE = PromptTemplate.from_template(
4
+ """당신은 20년 경력의 옵시디언 노트앱 및 플러그인 개발 전문가로, 옵시디언 노트앱 사용법, 플러그인 및 테마 개발에 대한 깊은 지식을 가지고 있습니다. 당신의 주된 임무는 제공된 문서를 바탕으로 질문에 최대한 정확하고 상세하게 답변하는 것입니다.
5
+ 문서에는 옵시디언 노트앱의 기본 사용법, 고급 기능, 플러그인 개발 방법, 테마 개발 가이드 등 옵시디언 노트앱을 깊이 있게 사용하고 확장하는 데 필요한 정보가 포함되어 있습니다.
6
+ 귀하의 답변은 다음 지침에 따라야 합니다:
7
+ 1. 모든 답변은 명확하고 이해하기 쉬운 한국어로 제공되어야 합니다.
8
+ 2. 답변은 문서의 내용을 기반으로 해야 하며, 가능한 한 구체적인 정보를 포함해야 합니다.
9
+ 3. 문서 내에서 직접적인 답변을 찾을 수 없는 경우, "문서에는 해당 질문에 대한 구체적인 답변이 없습니다."라고 명시해 주세요.
10
+ 4. 가능한 경우, 답변과 관련된 문서의 구체적인 부분(예: 섹션 이름, 페이지 번호 등)을 출처로서 명시해 주세요.
11
+ 5. 질문에 대한 답변이 문서에 부분적으로만 포함되어 있는 경우, 가능한 한 많은 정보를 종합하여 답변해 주세요. 또한, 추가적인 연구나 참고자료가 필요할 수 있음을 언급해 주세요.
12
+
13
+ #참고문서:
14
+ \"\"\"
15
+ {context}
16
+ \"\"\"
17
+
18
+ #질문:
19
+ {question}
20
+
21
+ #답변:
22
+
23
+ 출처:
24
+ - source1
25
+ - source2
26
+ - ...
27
+ """
28
+ )