JUNGU commited on
Commit
2763b46
Β·
verified Β·
1 Parent(s): 729d217

Update rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +35 -97
rag_system.py CHANGED
@@ -8,11 +8,6 @@ from langchain.docstore.document import Document
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  import pdfplumber
10
  from concurrent.futures import ThreadPoolExecutor
11
- from langchain.retrievers import ContextualCompressionRetriever
12
- from langchain.retrievers.document_compressors import LLMChainExtractor
13
- from langgraph.graph import Graph
14
- from langchain_core.runnables import RunnablePassthrough, RunnableLambda
15
- from langchain.prompts import PromptTemplate
16
 
17
  # Load environment variables
18
  load_dotenv()
@@ -33,40 +28,11 @@ def load_retrieval_qa_chain():
33
  # Initialize ChatOpenAI model
34
  llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0) # "gpt-4o-mini
35
 
36
- # Create a compressor for re-ranking
37
- compressor = LLMChainExtractor.from_llm(llm)
38
-
39
- # Create a ContextualCompressionRetriever
40
- compression_retriever = ContextualCompressionRetriever(
41
- base_compressor=compressor,
42
- base_retriever=vectorstore.as_retriever()
43
- )
44
-
45
- # Define your instruction/prompt
46
- instruction = """당신은 RAG(Retrieval-Augmented Generation) 기반 AI μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€. λ‹€μŒ 지침을 따라 μ‚¬μš©μž μ§ˆλ¬Έμ— λ‹΅ν•˜μ„Έμš”:
47
-
48
- 1. ��색 κ²°κ³Ό ν™œμš©: 제곡된 검색 κ²°κ³Όλ₯Ό λΆ„μ„ν•˜κ³  κ΄€λ ¨ 정보λ₯Ό μ‚¬μš©ν•΄ λ‹΅λ³€ν•˜μ„Έμš”.
49
- 2. μ •ν™•μ„± μœ μ§€: μ •λ³΄μ˜ 정확성을 ν™•μΈν•˜κ³ , λΆˆν™•μ‹€ν•œ 경우 이λ₯Ό λͺ…μ‹œν•˜μ„Έμš”.
50
- 3. κ°„κ²°ν•œ 응닡: μ§ˆλ¬Έμ— 직접 λ‹΅ν•˜κ³  핡심 λ‚΄μš©μ— μ§‘μ€‘ν•˜μ„Έμš”.
51
- 4. μΆ”κ°€ 정보 μ œμ•ˆ: κ΄€λ ¨λœ μΆ”κ°€ 정보가 μžˆλ‹€λ©΄ μ–ΈκΈ‰ν•˜μ„Έμš”.
52
- 5. μœ€λ¦¬μ„± κ³ λ €: 객관적이고 쀑립적인 νƒœλ„λ₯Ό μœ μ§€ν•˜μ„Έμš”.
53
- 6. ν•œκ³„ 인정: λ‹΅λ³€ν•  수 μ—†λŠ” 경우 μ†”μ§νžˆ μΈμ •ν•˜μ„Έμš”.
54
- 7. λŒ€ν™” μœ μ§€: μžμ—°μŠ€λŸ½κ²Œ λŒ€ν™”λ₯Ό 이어가고, ν•„μš”μ‹œ 후속 μ§ˆλ¬Έμ„ μ œμ•ˆν•˜μ„Έμš”.
55
-
56
- 항상 μ •ν™•ν•˜κ³  μœ μš©ν•œ 정보λ₯Ό μ œκ³΅ν•˜λŠ” 것을 λͺ©ν‘œλ‘œ ν•˜μ„Έμš”."""
57
-
58
- # Create a prompt template
59
- prompt_template = PromptTemplate(
60
- input_variables=["context", "question"],
61
- template=instruction + "\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
62
- )
63
-
64
- # Create ConversationalRetrievalChain with the new retriever and prompt
65
  qa_chain = ConversationalRetrievalChain.from_llm(
66
  llm,
67
- retriever=compression_retriever,
68
- return_source_documents=True,
69
- combine_docs_chain_kwargs={"prompt": prompt_template}
70
  )
71
 
72
  return qa_chain
@@ -116,69 +82,41 @@ def update_embeddings():
116
  documents.extend(result)
117
  vectorstore.add_documents(documents)
118
 
119
- def create_rag_graph():
120
- qa_chain = load_retrieval_qa_chain()
121
-
122
- def retrieve_and_generate(inputs):
123
- question = inputs["question"]
124
- chat_history = inputs["chat_history"]
125
- result = qa_chain({"question": question, "chat_history": chat_history})
126
-
127
- # Ensure source documents have the correct metadata
128
- sources = []
129
- for doc in result.get("source_documents", []):
130
- if "source" in doc.metadata and "page" in doc.metadata:
131
- sources.append(f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})")
132
- else:
133
- print(f"Warning: Document missing metadata: {doc.metadata}")
134
-
135
- return {
136
- "answer": result["answer"],
137
- "sources": sources
138
- }
139
-
140
- workflow = Graph()
141
- workflow.add_node("retrieve_and_generate", retrieve_and_generate)
142
- workflow.set_entry_point("retrieve_and_generate")
143
-
144
- chain = workflow.compile()
145
- return chain
146
-
147
- rag_chain = create_rag_graph()
148
-
149
- def get_answer(query, chat_history):
150
- try:
151
- response = rag_chain({"question": query, "chat_history": chat_history})
152
-
153
- if not response or "answer" not in response:
154
- return {
155
- "answer": "μ£„μ†‘ν•©λ‹ˆλ‹€. 닡변을 생성할 수 μ—†μ—ˆμŠ΅λ‹ˆλ‹€. μ§ˆλ¬Έμ„ λ‹€μ‹œ ν‘œν˜„ν•΄ μ£Όμ‹œκ² μŠ΅λ‹ˆκΉŒ?",
156
- "sources": []
157
- }
158
-
159
- sources = response.get("sources", [])
160
-
161
- return {
162
- "answer": response["answer"],
163
- "sources": sources
164
- }
165
- except Exception as e:
166
- print(f"Error in get_answer: {str(e)}")
167
- return {
168
- "answer": "λ‹΅λ³€ 생성 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€. λ‹€μ‹œ μ‹œλ„ν•΄ μ£Όμ„Έμš”.",
169
- "sources": []
170
- }
171
 
172
  # Example usage
173
  if __name__ == "__main__":
174
  update_embeddings() # Update embeddings with new documents
175
- question = "RAG μ‹œμŠ€ν…œμ— λŒ€ν•΄ μ„€λͺ…ν•΄μ£Όμ„Έμš”."
176
- response = get_answer(question, [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  print(f"Question: {question}")
178
  print(f"Answer: {response['answer']}")
179
- print(f"Sources: {response['sources']}")
180
-
181
- # Validate source format
182
- for source in response['sources']:
183
- if not (source.endswith(')') and ' (Page ' in source):
184
- print(f"Warning: Unexpected source format: {source}")
 
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  import pdfplumber
10
  from concurrent.futures import ThreadPoolExecutor
 
 
 
 
 
11
 
12
  # Load environment variables
13
  load_dotenv()
 
28
  # Initialize ChatOpenAI model
29
  llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0) # "gpt-4o-mini
30
 
31
+ # Create ConversationalRetrievalChain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  qa_chain = ConversationalRetrievalChain.from_llm(
33
  llm,
34
+ vectorstore.as_retriever(),
35
+ return_source_documents=True
 
36
  )
37
 
38
  return qa_chain
 
82
  documents.extend(result)
83
  vectorstore.add_documents(documents)
84
 
85
+ # Generate answer for a query
86
+ def get_answer(qa_chain, query, chat_history):
87
+ formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]
88
+
89
+ response = qa_chain.invoke({"question": query, "chat_history": formatted_history})
90
+
91
+ answer = response["answer"]
92
+
93
+ source_docs = response.get("source_documents", [])
94
+ source_texts = [f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})" for doc in source_docs]
95
+
96
+ return {"answer": answer, "sources": source_texts}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # Example usage
99
  if __name__ == "__main__":
100
  update_embeddings() # Update embeddings with new documents
101
+ qa_chain = load_retrieval_qa_chain()
102
+ question = """당신은 RAG(Retrieval-Augmented Generation) 기반 AI μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€. λ‹€μŒ 지침을 따라 μ‚¬μš©μž μ§ˆλ¬Έμ— λ‹΅ν•˜μ„Έμš”:
103
+
104
+ 1. 검색 κ²°κ³Ό ν™œμš©: 제곡된 검색 κ²°κ³Όλ₯Ό λΆ„μ„ν•˜κ³  κ΄€λ ¨ 정보λ₯Ό μ‚¬μš©ν•΄ λ‹΅λ³€ν•˜μ„Έμš”.
105
+
106
+ 2. μ •ν™•μ„± μœ μ§€: μ •λ³΄μ˜ 정확성을 ν™•μΈν•˜κ³ , λΆˆν™•μ‹€ν•œ 경우 이λ₯Ό λͺ…μ‹œν•˜μ„Έμš”.
107
+
108
+ 3. κ°„κ²°ν•œ 응닡: μ§ˆλ¬Έμ— 직접 λ‹΅ν•˜κ³  핡심 λ‚΄μš©μ— μ§‘μ€‘ν•˜μ„Έμš”.
109
+
110
+ 4. μΆ”κ°€ 정보 μ œμ•ˆ: κ΄€λ ¨λœ μΆ”κ°€ 정보가 μžˆλ‹€λ©΄ μ–ΈκΈ‰ν•˜μ„Έμš”.
111
+
112
+ 5. μœ€λ¦¬μ„± κ³ λ €: 객관적이고 쀑립적인 νƒœλ„λ₯Ό μœ μ§€ν•˜μ„Έμš”.
113
+
114
+ 6. ν•œκ³„ 인정: λ‹΅λ³€ν•  수 μ—†λŠ” 경우 μ†”μ§νžˆ μΈμ •ν•˜μ„Έμš”.
115
+
116
+ 7. λŒ€ν™” μœ μ§€: μžμ—°μŠ€λŸ½κ²Œ λŒ€ν™”λ₯Ό 이어가고, ν•„μš”μ‹œ 후속 μ§ˆλ¬Έμ„ μ œμ•ˆν•˜μ„Έμš”.
117
+ 항상 μ •ν™•ν•˜κ³  μœ μš©ν•œ 정보λ₯Ό μ œκ³΅ν•˜λŠ” 것을 λͺ©ν‘œλ‘œ ν•˜μ„Έμš”."""
118
+
119
+ response = get_answer(qa_chain, question, [])
120
  print(f"Question: {question}")
121
  print(f"Answer: {response['answer']}")
122
+ print(f"Sources: {response['sources']}")