Phoenix21 commited on
Commit
2065cb4
·
verified ·
1 Parent(s): 5067009

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +4 -3
pipeline.py CHANGED
@@ -29,6 +29,7 @@ from cleaner_chain import get_cleaner_chain
29
 
30
  from langchain.llms.base import LLM
31
 
 
32
  ###############################################################################
33
  # 1) Environment keys
34
  ###############################################################################
@@ -137,7 +138,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
137
  user_query = inputs["input"]
138
  chat_history = inputs.get("chat_history", [])
139
 
140
- # Classification step
141
  class_result = classification_chain.invoke({"query": user_query})
142
  classification = class_result.get("text", "").strip()
143
 
@@ -147,7 +148,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
147
  return {"answer": final_refusal.strip()}
148
 
149
  if classification == "Wellness":
150
- rag_result = wellness_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
151
  csv_answer = rag_result["result"].strip()
152
  if not csv_answer:
153
  web_answer = do_web_search(user_query)
@@ -162,7 +163,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
162
  return {"answer": final_answer}
163
 
164
  if classification == "Brand":
165
- rag_result = brand_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
166
  csv_answer = rag_result["result"].strip()
167
  final_merged = cleaner_chain.merge(kb=csv_answer, web="")
168
  final_answer = tailor_chain.run({"response": final_merged}).strip()
 
29
 
30
  from langchain.llms.base import LLM
31
 
32
+
33
  ###############################################################################
34
  # 1) Environment keys
35
  ###############################################################################
 
138
  user_query = inputs["input"]
139
  chat_history = inputs.get("chat_history", [])
140
 
141
+ # 1) Classification
142
  class_result = classification_chain.invoke({"query": user_query})
143
  classification = class_result.get("text", "").strip()
144
 
 
148
  return {"answer": final_refusal.strip()}
149
 
150
  if classification == "Wellness":
151
+ rag_result = wellness_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
152
  csv_answer = rag_result["result"].strip()
153
  if not csv_answer:
154
  web_answer = do_web_search(user_query)
 
163
  return {"answer": final_answer}
164
 
165
  if classification == "Brand":
166
+ rag_result = brand_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
167
  csv_answer = rag_result["result"].strip()
168
  final_merged = cleaner_chain.merge(kb=csv_answer, web="")
169
  final_answer = tailor_chain.run({"response": final_merged}).strip()