Phoenix21 commited on
Commit
41770fd
·
verified ·
1 Parent(s): 3c2ff85

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +10 -5
pipeline.py CHANGED
@@ -24,6 +24,7 @@ from classification_chain import get_classification_chain
24
  from refusal_chain import get_refusal_chain
25
  from tailor_chain import get_tailor_chain
26
  from cleaner_chain import get_cleaner_chain
 
27
 
28
  from langchain.llms.base import LLM
29
 
@@ -101,6 +102,7 @@ classification_chain = get_classification_chain()
101
  refusal_chain = get_refusal_chain()
102
  tailor_chain = get_tailor_chain()
103
  cleaner_chain = get_cleaner_chain()
 
104
 
105
  ###############################################################################
106
  # 5) Build vectorstores & RAG
@@ -136,8 +138,11 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
136
  user_query = inputs["input"]
137
  chat_history = inputs.get("chat_history", [])
138
 
139
- # 1) Classification
140
- class_result = classification_chain.invoke({"query": user_query, "chat_history": chat_history})
 
 
 
141
  classification = class_result.get("text", "").strip()
142
 
143
  if classification == "OutOfScope":
@@ -147,18 +152,18 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
147
 
148
  if classification == "Wellness":
149
  rag_result = wellness_rag_chain.invoke({
150
- "query": user_query,
151
  "chat_history": chat_history # Pass history here
152
  })
153
  csv_answer = rag_result["result"].strip()
154
- web_answer = do_web_search(user_query) if not csv_answer else ""
155
  final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer, chat_history=chat_history)
156
  final_answer = tailor_chain.run({"response": final_merged, "chat_history": chat_history}).strip()
157
  return {"answer": final_answer}
158
 
159
  if classification == "Brand":
160
  rag_result = brand_rag_chain.invoke({
161
- "query": user_query,
162
  "chat_history": chat_history # Pass history here
163
  })
164
  csv_answer = rag_result["result"].strip()
 
24
  from refusal_chain import get_refusal_chain
25
  from tailor_chain import get_tailor_chain
26
  from cleaner_chain import get_cleaner_chain
27
+ from contextualize_chain import get_contextualize_chain # New Import for ContextualizeChain
28
 
29
  from langchain.llms.base import LLM
30
 
 
102
  refusal_chain = get_refusal_chain()
103
  tailor_chain = get_tailor_chain()
104
  cleaner_chain = get_cleaner_chain()
105
+ contextualize_chain = get_contextualize_chain() # New Chain for Contextualizing User Queries
106
 
107
  ###############################################################################
108
  # 5) Build vectorstores & RAG
 
138
  user_query = inputs["input"]
139
  chat_history = inputs.get("chat_history", [])
140
 
141
+ # 1) Contextualize the query (ensure it's relevant to the conversation history)
142
+ contextualized_query = contextualize_chain.invoke({"user_query": user_query, "chat_history": chat_history})
143
+
144
+ # 2) Classification (using the contextualized query)
145
+ class_result = classification_chain.invoke({"query": contextualized_query, "chat_history": chat_history})
146
  classification = class_result.get("text", "").strip()
147
 
148
  if classification == "OutOfScope":
 
152
 
153
  if classification == "Wellness":
154
  rag_result = wellness_rag_chain.invoke({
155
+ "query": contextualized_query,
156
  "chat_history": chat_history # Pass history here
157
  })
158
  csv_answer = rag_result["result"].strip()
159
+ web_answer = do_web_search(contextualized_query) if not csv_answer else ""
160
  final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer, chat_history=chat_history)
161
  final_answer = tailor_chain.run({"response": final_merged, "chat_history": chat_history}).strip()
162
  return {"answer": final_answer}
163
 
164
  if classification == "Brand":
165
  rag_result = brand_rag_chain.invoke({
166
+ "query": contextualized_query,
167
  "chat_history": chat_history # Pass history here
168
  })
169
  csv_answer = rag_result["result"].strip()