Update pipeline.py
Browse files- pipeline.py +4 -3
pipeline.py
CHANGED
@@ -29,6 +29,7 @@ from cleaner_chain import get_cleaner_chain
|
|
29 |
|
30 |
from langchain.llms.base import LLM
|
31 |
|
|
|
32 |
###############################################################################
|
33 |
# 1) Environment keys
|
34 |
###############################################################################
|
@@ -137,7 +138,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
137 |
user_query = inputs["input"]
|
138 |
chat_history = inputs.get("chat_history", [])
|
139 |
|
140 |
-
# Classification
|
141 |
class_result = classification_chain.invoke({"query": user_query})
|
142 |
classification = class_result.get("text", "").strip()
|
143 |
|
@@ -147,7 +148,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
147 |
return {"answer": final_refusal.strip()}
|
148 |
|
149 |
if classification == "Wellness":
|
150 |
-
rag_result = wellness_rag_chain.invoke({"
|
151 |
csv_answer = rag_result["result"].strip()
|
152 |
if not csv_answer:
|
153 |
web_answer = do_web_search(user_query)
|
@@ -162,7 +163,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
162 |
return {"answer": final_answer}
|
163 |
|
164 |
if classification == "Brand":
|
165 |
-
rag_result = brand_rag_chain.invoke({"
|
166 |
csv_answer = rag_result["result"].strip()
|
167 |
final_merged = cleaner_chain.merge(kb=csv_answer, web="")
|
168 |
final_answer = tailor_chain.run({"response": final_merged}).strip()
|
|
|
29 |
|
30 |
from langchain.llms.base import LLM
|
31 |
|
32 |
+
|
33 |
###############################################################################
|
34 |
# 1) Environment keys
|
35 |
###############################################################################
|
|
|
138 |
user_query = inputs["input"]
|
139 |
chat_history = inputs.get("chat_history", [])
|
140 |
|
141 |
+
# 1) Classification
|
142 |
class_result = classification_chain.invoke({"query": user_query})
|
143 |
classification = class_result.get("text", "").strip()
|
144 |
|
|
|
148 |
return {"answer": final_refusal.strip()}
|
149 |
|
150 |
if classification == "Wellness":
|
151 |
+
rag_result = wellness_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
|
152 |
csv_answer = rag_result["result"].strip()
|
153 |
if not csv_answer:
|
154 |
web_answer = do_web_search(user_query)
|
|
|
163 |
return {"answer": final_answer}
|
164 |
|
165 |
if classification == "Brand":
|
166 |
+
rag_result = brand_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
|
167 |
csv_answer = rag_result["result"].strip()
|
168 |
final_merged = cleaner_chain.merge(kb=csv_answer, web="")
|
169 |
final_answer = tailor_chain.run({"response": final_merged}).strip()
|