Phoenix21 commited on
Commit
bd00d5a
·
verified ·
1 Parent(s): 47b1df8

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +10 -11
pipeline.py CHANGED
@@ -1,5 +1,3 @@
1
- # pipeline.py
2
-
3
  import os
4
  import getpass
5
  import pandas as pd
@@ -139,12 +137,12 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
139
  chat_history = inputs.get("chat_history", [])
140
 
141
  # 1) Classification
142
- class_result = classification_chain.invoke({"query": user_query})
143
  classification = class_result.get("text", "").strip()
144
 
145
  if classification == "OutOfScope":
146
- refusal_text = refusal_chain.run({})
147
- final_refusal = tailor_chain.run({"response": refusal_text})
148
  return {"answer": final_refusal.strip()}
149
 
150
  if classification == "Wellness":
@@ -154,8 +152,8 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
154
  })
155
  csv_answer = rag_result["result"].strip()
156
  web_answer = do_web_search(user_query) if not csv_answer else ""
157
- final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
158
- final_answer = tailor_chain.run({"response": final_merged}).strip()
159
  return {"answer": final_answer}
160
 
161
  if classification == "Brand":
@@ -164,13 +162,14 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
164
  "chat_history": chat_history # Pass history here
165
  })
166
  csv_answer = rag_result["result"].strip()
167
- final_merged = cleaner_chain.merge(kb=csv_answer, web="")
168
- final_answer = tailor_chain.run({"response": final_merged}).strip()
169
  return {"answer": final_answer}
170
 
171
- refusal_text = refusal_chain.run({})
172
- final_refusal = tailor_chain.run({"response": refusal_text}).strip()
173
  return {"answer": final_refusal}
 
174
  ###############################################################################
175
  # 7) Build a "Runnable" wrapper so .with_listeners() works
176
  ###############################################################################
 
 
 
1
  import os
2
  import getpass
3
  import pandas as pd
 
137
  chat_history = inputs.get("chat_history", [])
138
 
139
  # 1) Classification
140
+ class_result = classification_chain.invoke({"query": user_query, "chat_history": chat_history})
141
  classification = class_result.get("text", "").strip()
142
 
143
  if classification == "OutOfScope":
144
+ refusal_text = refusal_chain.run({"chat_history": chat_history})
145
+ final_refusal = tailor_chain.run({"response": refusal_text, "chat_history": chat_history})
146
  return {"answer": final_refusal.strip()}
147
 
148
  if classification == "Wellness":
 
152
  })
153
  csv_answer = rag_result["result"].strip()
154
  web_answer = do_web_search(user_query) if not csv_answer else ""
155
+ final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer, chat_history=chat_history)
156
+ final_answer = tailor_chain.run({"response": final_merged, "chat_history": chat_history}).strip()
157
  return {"answer": final_answer}
158
 
159
  if classification == "Brand":
 
162
  "chat_history": chat_history # Pass history here
163
  })
164
  csv_answer = rag_result["result"].strip()
165
+ final_merged = cleaner_chain.merge(kb=csv_answer, web="", chat_history=chat_history)
166
+ final_answer = tailor_chain.run({"response": final_merged, "chat_history": chat_history}).strip()
167
  return {"answer": final_answer}
168
 
169
+ refusal_text = refusal_chain.run({"chat_history": chat_history})
170
+ final_refusal = tailor_chain.run({"response": refusal_text, "chat_history": chat_history}).strip()
171
  return {"answer": final_refusal}
172
+
173
  ###############################################################################
174
  # 7) Build a "Runnable" wrapper so .with_listeners() works
175
  ###############################################################################