Phoenix21 committed
Commit 15969e9 · verified · 1 Parent(s): b752340

Update pipeline.py

Files changed (1):
  1. pipeline.py (+90, -19)
pipeline.py CHANGED
@@ -5,10 +5,8 @@ import getpass
 import pandas as pd
 from typing import Optional, Dict, Any
 
-try:
-    from langchain.runnables.base import Runnable
-except ImportError:
-    from langchain_core.runnables.base import Runnable
+# Correct import for Runnable
+from langchain.schema import Runnable
 
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -18,6 +16,7 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import litellm
 
+# Classification/Refusal/Tailor/Cleaner
 from classification_chain import get_classification_chain
 from refusal_chain import get_refusal_chain
 from tailor_chain import get_tailor_chain
@@ -25,27 +24,83 @@ from cleaner_chain import get_cleaner_chain
 
 from langchain.llms.base import LLM
 
-# Environment keys
+###############################################################################
+# 1) Environment keys
+###############################################################################
 if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
 if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
 
+###############################################################################
+# 2) Build or load VectorStore
+###############################################################################
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
-    # ... [unchanged code for building/loading vectorstore] ...
-    # Use your previously provided implementation here.
-    # For brevity, not repeating this section.
-    pass
-
+    if os.path.exists(store_dir):
+        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading from disk.")
+        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+        vectorstore = FAISS.load_local(store_dir, embeddings)
+        return vectorstore
+    else:
+        print(f"DEBUG: Building new store from CSV: {csv_path}")
+        df = pd.read_csv(csv_path)
+        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
+        df.columns = df.columns.str.strip()
+
+        if "Answer" in df.columns:
+            df.rename(columns={"Answer": "Answers"}, inplace=True)
+        if "Question" not in df.columns and "Question " in df.columns:
+            df.rename(columns={"Question ": "Question"}, inplace=True)
+
+        if "Question" not in df.columns or "Answers" not in df.columns:
+            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
+
+        docs = []
+        for _, row in df.iterrows():
+            q = str(row["Question"])
+            ans = str(row["Answers"])
+            doc = Document(page_content=ans, metadata={"question": q})
+            docs.append(doc)
+
+        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
+        vectorstore.save_local(store_dir)
+        return vectorstore
+
+###############################################################################
+# 3) Build RAG chain
+###############################################################################
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
-    # ... [unchanged code for building a RAG chain] ...
-    pass
-
+    class GeminiLangChainLLM(LLM):
+        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
+            messages = [{"role": "user", "content": prompt}]
+            return llm_model(messages, stop_sequences=stop)
+
+        @property
+        def _llm_type(self) -> str:
+            return "custom_gemini"
+
+    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+    gemini_as_llm = GeminiLangChainLLM()
+    rag_chain = RetrievalQA.from_chain_type(
+        llm=gemini_as_llm,
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=True
+    )
+    return rag_chain
+
+###############################################################################
+# 4) Initialize sub-chains
+###############################################################################
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
+###############################################################################
+# 5) Build vectorstores & RAG
+###############################################################################
 gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
 
 wellness_csv = "AIChatbot.csv"
@@ -70,14 +125,21 @@ def do_web_search(query: str) -> str:
     response = manager_agent.run(search_query)
     return response
 
+###############################################################################
+# 6) Orchestrator function: returns a dict => {"answer": "..."}
+###############################################################################
 def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
+    """
+    Called by the Runnable.
+    inputs: { "input": <user_query>, "chat_history": <list of messages> (optional) }
+    Output: { "answer": <final string> }
+    """
     user_query = inputs["input"]
     chat_history = inputs.get("chat_history", [])
 
-    print("DEBUG: Starting run_with_chain_context...")
+    # 1) Classification
     class_result = classification_chain.invoke({"query": user_query})
     classification = class_result.get("text", "").strip()
-    print("DEBUG: Classification =>", classification)
 
     if classification == "OutOfScope":
         refusal_text = refusal_chain.run({})
@@ -85,8 +147,7 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
         return {"answer": final_refusal.strip()}
 
     if classification == "Wellness":
-        # Use the correct key "query" instead of "input"
-        rag_result = wellness_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
+        rag_result = wellness_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
         csv_answer = rag_result["result"].strip()
         if not csv_answer:
             web_answer = do_web_search(user_query)
@@ -96,24 +157,34 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
             web_answer = do_web_search(user_query)
         else:
             web_answer = ""
+
        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
         final_answer = tailor_chain.run({"response": final_merged}).strip()
         return {"answer": final_answer}
 
     if classification == "Brand":
-        rag_result = brand_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
+        rag_result = brand_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
         csv_answer = rag_result["result"].strip()
         final_merged = cleaner_chain.merge(kb=csv_answer, web="")
         final_answer = tailor_chain.run({"response": final_merged}).strip()
         return {"answer": final_answer}
 
+    # fallback
     refusal_text = refusal_chain.run({})
     final_refusal = tailor_chain.run({"response": refusal_text}).strip()
     return {"answer": final_refusal}
 
-# Runnable wrapper for my_memory_logic.py
+###############################################################################
+# 7) Build a "Runnable" wrapper so .with_listeners() works
+###############################################################################
+
 class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
+    """
+    Wraps run_with_chain_context(...) in a Runnable
+    so that RunnableWithMessageHistory can attach listeners.
+    """
     def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
        return run_with_chain_context(input)
 
+# Export an instance of PipelineRunnable for use in my_memory_logic.py
 pipeline_runnable = PipelineRunnable()
 
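For readers on a different LangChain version: Runnable has moved between packages across releases, so a guarded import along these lines (a sketch, not part of this commit) keeps the module importable either way:

# Version-tolerant Runnable import (assumption, not from the commit).
try:
    from langchain_core.runnables import Runnable   # newer releases
except ImportError:
    from langchain.schema.runnable import Runnable  # legacy langchain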
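A minimal usage sketch for build_or_load_vectorstore; the store directory name below is illustrative, and newer langchain-community releases typically require an explicit opt-in flag when loading a pickled FAISS index from disk:

# Hedged sketch: the store_dir value is hypothetical; the CSV name is from the commit.
store = build_or_load_vectorstore("AIChatbot.csv", "faiss_wellness_store")
for doc in store.similarity_search("How can I sleep better?", k=3):
    print(doc.metadata.get("question"), "=>", doc.page_content[:80])

# Depending on the installed langchain-community version, loading may need:
# FAISS.load_local(store_dir, embeddings, allow_dangerous_deserialization=True)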
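Note that a RetrievalQA chain built with from_chain_type conventionally takes its question under the "query" key unless configured otherwise, and with return_source_documents=True it returns both "result" and "source_documents"; a direct-use sketch (the query text is illustrative):

# Hedged sketch: exercising the RAG chain on its own.
rag_chain = build_rag_chain(gemini_llm, store)
out = rag_chain.invoke({"query": "What products help with stress?"})
print(out["result"])                  # generated answer
print(len(out["source_documents"]))  # retrieved Documents backing it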
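Because the orchestrator is a plain function, it can be exercised before any Runnable wiring; a sketch with an illustrative query:

# Hedged sketch: calling the orchestrator directly.
result = run_with_chain_context({"input": "Tips for better sleep?", "chat_history": []})
print(result["answer"])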
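Given the docstring's mention of RunnableWithMessageHistory, this is one plausible way my_memory_logic.py could consume the export; the session store, helper, and key names below are assumptions, not taken from this repo:

# Hedged sketch of the consuming side; get_session_history is hypothetical.
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import InMemoryChatMessageHistory

_sessions = {}

def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    # One in-memory history per session id (illustrative only).
    if session_id not in _sessions:
        _sessions[session_id] = InMemoryChatMessageHistory()
    return _sessions[session_id]

chain_with_history = RunnableWithMessageHistory(
    pipeline_runnable,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

reply = chain_with_history.invoke(
    {"input": "Hi there!"},
    config={"configurable": {"session_id": "demo-session"}},
)
print(reply["answer"])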