kkata committed on
Commit
a65fa77
·
verified ·
1 Parent(s): b66fd61

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +764 -0
app.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Import necessary libraries
import json  # Parsing and handling JSON data
import os  # Interacting with the operating system (reading/writing files)
import re  # Regular expressions (extracting numeric scores from LLM output)
from datetime import datetime
from typing import Dict, List, Tuple, Any, TypedDict  # Python typing for function annotations

import chromadb  # High-performance vector database for storing/querying dense vectors
from dotenv import load_dotenv  # Loading environment variables from a .env file

# LangChain imports
from langchain_core.documents import Document  # Document data structures
from langchain_core.runnables import RunnablePassthrough  # LangChain core library for running pipelines
from langchain_core.output_parsers import StrOutputParser  # String output parser
from langchain.prompts import ChatPromptTemplate  # Template for chat prompts
from langchain.chains.query_constructor.base import AttributeInfo  # Base classes for query construction
from langchain.retrievers.self_query.base import SelfQueryRetriever  # Base classes for self-querying retrievers
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker  # Document compressors
from langchain.retrievers import ContextualCompressionRetriever  # Contextual compression retrievers

# LangChain community & experimental imports
from langchain_community.vectorstores import Chroma  # Implementations of vector stores like Chroma
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader  # Document loaders for PDFs
from langchain_community.cross_encoders import HuggingFaceCrossEncoder  # Cross-encoders from HuggingFace
from langchain_experimental.text_splitter import SemanticChunker  # Experimental text splitting methods
from langchain.text_splitter import (
    CharacterTextSplitter,  # Splitting text by characters
    RecursiveCharacterTextSplitter  # Recursive splitting of text by characters
)
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate

# LangChain OpenAI imports
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI  # OpenAI embeddings and models
from langchain_openai import ChatOpenAI  # Chat model client used for `llm` and NutritionBot
from langchain.embeddings.openai import OpenAIEmbeddings  # OpenAI embeddings for text vectors

# LlamaParse & LlamaIndex imports
from llama_parse import LlamaParse  # Document parsing library
from llama_index.core import Settings, SimpleDirectoryReader  # Core functionalities of the LlamaIndex

# LangGraph import
from langgraph.graph import StateGraph, END, START  # State graph for managing states in LangChain

# Pydantic import
from pydantic import BaseModel  # Pydantic for data validation

# Other utilities
import numpy as np  # Numpy for numerical operations
from groq import Groq
from mem0 import MemoryClient
import streamlit as st
54
+
55
#====================================SETUP=====================================#
# Fetch secrets from the process environment. The OpenAI-compatible key and
# endpoint are optional here (None when unset); the Groq and mem0 keys are
# required and raise KeyError at import time when missing.
api_key = os.environ.get("API_KEY")
endpoint = os.environ.get("OPENAI_API_BASE")
llama_api_key = os.environ['GROQ_API_KEY']
MEM0_api_key = os.environ['mem0']

# Initialize the OpenAI embedding function for Chroma, using the endpoint and
# API key read above.
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base=endpoint,
    api_key=api_key,
    model_name='text-embedding-ada-002'  # This is a fixed value and does not need modification
)

# Initialize the OpenAI Embeddings (LangChain wrapper; this is the object the
# Chroma vector store below is configured with).
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002'
)

# Initialize the Chat OpenAI model with the endpoint and API key read above.
# No temperature is passed, so the client's default applies.
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False
)

# Set the LLM and embedding model in the LlamaIndex settings.
Settings.llm = llm
# NOTE(review): LlamaIndex's documented attribute appears to be
# `Settings.embed_model`; `Settings.embedding` may be a silent no-op. Left
# unchanged because switching may require an extra adapter package — confirm.
Settings.embedding = embedding_model
91
+
92
+ #================================Creating Langgraph agent======================#
93
+
94
class AgentState(TypedDict):
    # Shared state dictionary passed between the LangGraph workflow nodes.
    query: str  # The current user query
    expanded_query: str  # The expanded version of the user query
    context: List[Dict[str, Any]]  # Retrieved documents (content and metadata)
    # The generated response to the user query. NOTE: craft_response stores the
    # raw object returned by the LLM chain here, while max_iterations_reached
    # stores a plain string.
    response: str
    precision_score: float  # The precision score of the response
    groundedness_score: float  # The groundedness score of the response
    groundedness_loop_count: int  # Counter for groundedness refinement loops
    precision_loop_count: int  # Counter for precision refinement loops
    feedback: str  # Written by refine_response, read by craft_response
    query_feedback: str  # Written by refine_query, read by expand_query
    groundedness_check: bool  # Not referenced by any node visible in this file
    loop_max_iter: int  # Upper bound the loop counters are compared against
107
+
108
def expand_query(state):
    """
    Rewrites the user query into a recall-oriented search query.

    Args:
        state (Dict): Workflow state; reads 'query' and 'query_feedback'.

    Returns:
        Dict: The same state object with 'expanded_query' filled in.
    """
    print("---------Expanding Query---------")
    instructions = '''You are a query-expansion assistant for a medical RAG system about nutritional disorders.
Rewrite the user's query to maximize recall while preserving intent.
- Add clinical synonyms, lay terms, and abbreviations (e.g., deficiency/insufficiency, toxicity, sx/symptoms, eval/evaluation, dx/diagnosis, mgmt/management, labs/biomarkers).
- Include related key phrases (etiology, risk factors, signs/symptoms, workup, treatment, complications) when relevant.
- Keep any explicit filters if present (e.g., Category:<...>, DisorderType:<...>, page:<N>).
- Remove filler words; prefer noun phrases and medically relevant keywords.
- Do not invent facts; do not answer the question.
Return ONLY the expanded query as a single line, no quotes or explanations.'''

    messages = [
        ("system", instructions),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}"),
    ]
    # prompt -> model -> plain string
    pipeline = ChatPromptTemplate.from_messages(messages) | llm | StrOutputParser()

    template_values = {
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    }
    rewritten = pipeline.invoke(template_values)
    print("expanded_query", rewritten)

    state["expanded_query"] = rewritten
    return state
140
+
141
+
142
# Initialize the Chroma vector store for retrieving documents.
# The collection name and on-disk location default to the values this app was
# written with, but can be overridden with the CHROMA_COLLECTION and
# CHROMA_PERSIST_DIR environment variables so the app is not tied to one
# machine's filesystem layout.
# NOTE(review): the default path looks like a Colab / Google Drive mount;
# confirm it exists in the deployment environment.
vector_store = Chroma(
    collection_name=os.environ.get("CHROMA_COLLECTION", 'semantic_chunks'),
    persist_directory=os.environ.get("CHROMA_PERSIST_DIR", '/content/drive/MyDrive/research_db'),
    embedding_function=embedding_model
)

# Create a retriever from the vector store (top-5 similarity search)
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
168
+
169
+
170
def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    The expanded query is preferred; when it is missing or empty the original
    user query is used instead, so the retriever is never called with an empty
    string.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with the retrieved context stored under 'context'
        as a list of {"content": ..., "metadata": ...} dictionaries.
    """
    print("---------retrieve_context---------")
    query = state.get('expanded_query') or state['query']

    # Retrieve documents from the vector store
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: Print the raw docs object

    # Keep both the text and the metadata of every retrieved document
    context = [
        {"content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]
    state['context'] = context
    print("Extracted context with metadata:", context)  # Debugging: Print the extracted context
    return state
200
+
201
+
202
def craft_response(state: Dict) -> Dict:
    """
    Answers the user query from the retrieved context, taking prior feedback into account.

    Args:
        state (Dict): Workflow state; reads 'query', 'context' and 'feedback'.

    Returns:
        Dict: The same state with 'response' set to the object returned by the
        prompt | llm chain (no output parser is applied here).
    """
    print("---------craft_response---------")
    instructions = '''You are a concise clinical explainer focused on nutritional disorders.
Use ONLY the provided context to answer. If the answer is not in the context, say you don't know.
Be specific and organized (e.g., Key features, Evaluation, Management). Avoid speculation.'''

    prompt = ChatPromptTemplate.from_messages([
        ("system", instructions),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}"),
    ])

    # Flatten the retrieved chunks into one newline-separated text block
    context_text = "\n".join(item["content"] for item in state['context'])

    answer = (prompt | llm).invoke({
        "query": state['query'],
        "context": context_text,
        "feedback": state['feedback'],  # suggestions from refine_response, if any
    })
    state['response'] = answer
    print("intermediate response: ", answer)

    return state
232
+
233
+
234
def score_groundedness(state: Dict) -> Dict:
    """
    Checks whether the response is grounded in the retrieved context.

    The LLM is asked for a single number in [0.0, 1.0]. The first number found
    in its reply is used and clamped to that range; when no number can be found
    the score falls back to 0.0 so the workflow takes the (bounded) refinement
    path instead of crashing.

    Args:
        state (Dict): The current state of the workflow, containing the response and context.

    Returns:
        Dict: The updated state with 'groundedness_score' set and
        'groundedness_loop_count' incremented by one.
    """
    print("---------check_groundedness---------")
    system_message = '''You are a strict evaluator. Given the Context and the assistant Response,
output a single numeric score between 0.0 and 1.0 indicating how well the response is supported by the context.
0.0 = not supported at all; 1.0 = fully supported. Output ONLY the number with no extra text'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    # The response may be an LLM message object (craft_response) or a plain
    # string; accept both.
    response = state['response']
    response_text = getattr(response, "content", response)

    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": response_text
    })

    # Tolerate replies such as "Score: 0.8" instead of failing in float()
    match = re.search(r"[-+]?\d*\.?\d+", str(raw_score))
    if match is None:
        print("Could not parse groundedness score from:", raw_score)
        groundedness_score = 0.0
    else:
        groundedness_score = min(1.0, max(0.0, float(match.group())))

    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score

    return state
265
+
266
+
267
def check_precision(state: Dict) -> Dict:
    """
    Checks whether the response precisely addresses the user's query.

    The LLM is asked for a single number in [0.0, 1.0]. The first number found
    in its reply is used and clamped to that range; when no number can be found
    the score falls back to 0.0 so the workflow takes the (bounded) refinement
    path instead of crashing.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with 'precision_score' set and
        'precision_loop_count' incremented by one.
    """
    print("---------check_precision---------")
    system_message = '''You are a strict evaluator. Given the user Query and the assistant Response,
output a single numeric score between 0.0 and 1.0 indicating how precisely the response answers the query.
Consider specificity, relevance, and avoidance of unrelated information. Output ONLY the number.'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    # The response may be an LLM message object (craft_response) or a plain
    # string; accept both.
    response = state['response']
    response_text = getattr(response, "content", response)

    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": response_text
    })

    # Tolerate replies such as "Score: 0.8" instead of failing in float()
    match = re.search(r"[-+]?\d*\.?\d+", str(raw_score))
    if match is None:
        print("Could not parse precision score from:", raw_score)
        precision_score = 0.0
    else:
        precision_score = min(1.0, max(0.0, float(match.group())))

    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
297
+
298
+
299
def refine_response(state: Dict) -> Dict:
    """
    Suggests improvements for the generated response.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with 'feedback' set to the previous response
        text followed by the LLM's improvement suggestions.
    """
    print("---------refine_response---------")

    system_message = '''You are a clinical editor for a nutritional-disorders RAG system.
Suggest concrete, actionable improvements to make the response more accurate, complete, and well-grounded.
- Point out missing key features, evaluation steps (labs/biomarkers), management, and red flags.
- Flag any statements that are not grounded in likely context and suggest adding inline citations [source p.X].
- Be concise; return a short bullet list of suggestions, not a rewritten answer.'''

    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    chain = refine_response_prompt | llm | StrOutputParser()

    # craft_response stores the raw LLM message object; use its text so the
    # prompt and the feedback contain the answer itself rather than the
    # message object's repr. Plain strings pass through unchanged.
    previous_response = state['response']
    response_text = getattr(previous_response, "content", previous_response)

    suggestions = chain.invoke({'query': state['query'], 'response': response_text})

    # Store response suggestions in a structured format
    feedback = f"Previous Response: {response_text}\nSuggestions: {suggestions}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state
331
+
332
+
333
+
334
def refine_query(state: Dict) -> Dict:
    """
    Asks the LLM for suggestions that would sharpen the expanded query.

    Args:
        state (Dict): Workflow state; reads 'query', 'expanded_query' and
            'groundedness_loop_count' (the latter only for logging).

    Returns:
        Dict: The same state with 'query_feedback' set; 'expanded_query' itself is left untouched.
    """
    print("---------refine_query---------")
    instructions = '''You are a query-refinement assistant for a nutritional-disorders RAG system.
Improve the EXPANDED QUERY by proposing structured, actionable suggestions that enhance search precision.
Do NOT rewrite or replace the expanded query; only suggest additions/refinements.

Return a short, markdown-style list with these sections (omit any that are not applicable):
- Missing keywords (clinical terms, labs/biomarkers, complications, evaluation/management terms)
- Synonyms/variants (lay terms, abbreviations, alternate spellings)
- Scope refinement (population, severity, acuity, age/sex, pregnancy/lactation)
- Filters (Category, DisorderType, page, timeframe) — preserve any existing filters
- Exclusions (terms to avoid that cause drift)

Rules:
- Preserve the user’s intent and any explicit filters already present (e.g., Category:<...>, DisorderType:<...>, page:<N>).
- Be concise (≤ 6 bullets total). No prose explanations, no answers to the query, and no invented facts.'''

    user_template = (
        "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
        "What improvements can be made for a better search?"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", instructions),
        ("user", user_template),
    ])
    pipeline = prompt | llm | StrOutputParser()

    current_expansion = state['expanded_query']
    suggestions = pipeline.invoke({'query': state['query'], 'expanded_query': current_expansion})

    # Record the suggestions next to the expansion they refer to
    query_feedback = f"Previous Expanded Query: {current_expansion}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state
375
+
376
+
377
def should_continue_groundedness(state):
    """Route after groundedness scoring.

    Returns "check_precision" when the groundedness score is at least 0.7,
    "max_iterations_reached" when the groundedness loop counter exceeds
    'loop_max_iter', and "refine_response" otherwise.
    """
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])

    # Good enough: hand over to the precision check
    if state['groundedness_score'] >= 0.7:
        print("Moving to precision")
        return "check_precision"

    # Refinement budget exhausted
    if state["groundedness_loop_count"] > state['loop_max_iter']:
        return "max_iterations_reached"

    print("---------Groundedness Score Threshold Not met. Refining Response-----------")
    return "refine_response"
390
+
391
+
392
def should_continue_precision(state: Dict) -> str:
    """Decides if precision is sufficient or needs improvement.

    Args:
        state (Dict): Workflow state; reads 'precision_score',
            'precision_loop_count' and 'loop_max_iter'.

    Returns:
        str: "pass" when the precision score meets the threshold,
        "max_iterations_reached" when the precision loop counter exceeds
        'loop_max_iter', and "refine_query" otherwise.
    """
    # Same cut-off the groundedness router uses. A bare truthiness test would
    # accept any non-zero score and make the refinement branch unreachable.
    precision_threshold = 0.7

    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])
    if state['precision_score'] >= precision_threshold:
        return "pass"  # Complete the workflow

    if state["precision_loop_count"] > state['loop_max_iter']:  # Maximum allowed loops
        return "max_iterations_reached"

    print("---------Precision Score Threshold Not met. Refining Query-----------")  # Debugging
    return "refine_query"  # Refine the query
404
+
405
+
406
+
407
+
408
+
409
def max_iterations_reached(state: Dict) -> Dict:
    """Handles the case when the maximum number of iterations is reached.

    Args:
        state (Dict): The current state of the workflow.

    Returns:
        Dict: The same state with 'response' replaced by a fixed fallback
        message (a plain string).
    """
    print("---------max_iterations_reached---------")
    state['response'] = "I'm unable to refine the response further. Please provide more context or clarify your question."
    return state
416
+
417
+ from langgraph.graph import END, StateGraph, START
418
+
419
def create_workflow() -> StateGraph:
    """Creates the updated workflow for the AI nutrition agent.

    Graph shape:
        START -> expand_query -> retrieve_context -> craft_response -> score_groundedness
        score_groundedness -> check_precision | refine_response | max_iterations_reached
        refine_response -> craft_response
        check_precision -> END | refine_query | max_iterations_reached
        refine_query -> expand_query
        max_iterations_reached -> END

    Returns:
        StateGraph: The uncompiled graph; callers compile it before invoking.
    """
    workflow = StateGraph(AgentState)

    # Add processing nodes
    workflow.add_node("expand_query", expand_query)              # Step 1: Expand user query
    workflow.add_node("retrieve_context", retrieve_context)      # Step 2: Retrieve relevant documents
    workflow.add_node("craft_response", craft_response)          # Step 3: Generate a response based on retrieved data
    workflow.add_node("score_groundedness", score_groundedness)  # Step 4: Evaluate response grounding
    workflow.add_node("refine_response", refine_response)        # Step 5: Improve response if it's weakly grounded
    workflow.add_node("check_precision", check_precision)        # Step 6: Evaluate response precision
    workflow.add_node("refine_query", refine_query)              # Step 7: Improve query if response lacks precision
    workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: Handle max iterations

    # Main flow edges
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Conditional edges based on groundedness check
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",  # If well-grounded, proceed to precision check.
            "refine_response": "refine_response",  # If not, refine the response.
            "max_iterations_reached": "max_iterations_reached"  # If max loops reached, exit.
        }
    )

    # refine_response only changes 'feedback'; the expanded query and the
    # retrieved context are still in the state, so go straight back to
    # generation instead of repeating an identical retrieval.
    workflow.add_edge("refine_response", "craft_response")

    # Conditional edges based on precision check
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,  # If precise, complete the workflow.
            "refine_query": "refine_query",  # If imprecise, refine the query.
            "max_iterations_reached": "max_iterations_reached"  # If max loops reached, exit.
        }
    )

    workflow.add_edge("refine_query", "expand_query")  # Refined queries go through expansion again.

    workflow.add_edge("max_iterations_reached", END)

    return workflow
468
+
469
+
470
+
471
+
472
+
473
+
474
+ #=========================== Defining the agentic rag tool ====================#
475
# Compile the graph once at import time; every tool call reuses it.
WORKFLOW_APP = create_workflow().compile()


# NOTE: @tool exposes the function docstring to the agent as the tool
# description, so the docstring text below is kept exactly as it was.
@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agent with conversation history for context-aware responses.

    Args:
    query (str): The current user query.

    Returns:
    Dict[str, Any]: The updated state with the generated response and conversation history.
    """
    # Fresh workflow state for this call: empty working fields, zeroed scores
    # and counters, and a budget of 2 for the refinement loops.
    initial_state = dict(
        query=query,
        expanded_query="",
        context=[],
        response="",
        precision_score=0.0,
        groundedness_score=0.0,
        groundedness_loop_count=0,
        precision_loop_count=0,
        feedback="",
        query_feedback="",
        loop_max_iter=2,
    )

    # The final workflow state is handed back unchanged
    return WORKFLOW_APP.invoke(initial_state)
505
+
506
+
507
+
508
+ #================================ Guardrails ===========================#
509
llama_guard_client = Groq(api_key=llama_api_key)


# Function to filter user input with Llama Guard
def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
    """
    Sends the user input to a Llama Guard model and returns the model's reply.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used (default is "meta-llama/llama-guard-4-12b").

    Returns:
    - The model's reply text with surrounding whitespace stripped, or None when
      the request fails or the reply is empty.
      NOTE(review): Llama Guard models typically answer with a safe/unsafe
      verdict rather than a rewritten input — confirm callers treat the return
      value accordingly.
    """
    try:
        # Ask Llama Guard to assess the user input
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        choices = response.choices
        content = choices[0].message.content if choices else None
        if content is None:
            print("Error with Llama Guard: empty response")
            return None
        return content.strip()
    except Exception as e:
        # Best-effort boundary around a network call: report and signal failure
        # with None instead of crashing the app.
        print(f"Error with Llama Guard: {e}")
        return None
534
+
535
+
536
+ #============================= Adding Memory to the agent using mem0 ===============================#
537
+
538
class NutritionBot:
    """Chat agent that combines mem0 memory with the agentic_rag tool."""

    def __init__(self):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.

        Reads the 'mem0', 'API_KEY' and 'OPENAI_API_BASE' environment variables.

        Raises:
            KeyError: If the 'mem0' environment variable is not set.
        """

        # Initialize a memory client to store and retrieve customer interactions
        self.memory = MemoryClient(api_key=os.environ["mem0"])

        # Initialize the OpenAI client using the provided credentials
        self.client = ChatOpenAI(
            model="gpt-4o-mini",  # Specify the model to use
            openai_api_key=os.environ.get("API_KEY"),  # API key for authentication
            openai_api_base=os.environ.get("OPENAI_API_BASE"),  # OpenAI-compatible endpoint
            temperature=0  # Controls randomness in responses; 0 ensures deterministic results
        )

        # Tools available to the chatbot: the agentic RAG workflow
        tools = [agentic_rag]

        # Define the system prompt to set the behavior of the chatbot
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
Guidelines for Interaction:
Maintain a polite, professional, and reassuring tone.
Show genuine empathy for customer concerns and health challenges.
Reference past interactions to provide personalized and consistent advice.
Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
Ensure consistent and accurate information across conversations.
If any detail is unclear or missing, proactively ask for clarification.
Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
Keep track of ongoing issues and follow-ups to ensure continuity in support.
Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.

"""

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),  # System instructions
            ("human", "{input}"),  # Placeholder for human input
            ("placeholder", "{agent_scratchpad}")  # Placeholder for intermediate reasoning steps
        ])

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)

        # Wrap the agent in an executor to manage tool interactions and execution flow
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
                The caller's dictionary is not modified.
        """
        # Work on a copy so the timestamp does not leak into the caller's dict
        metadata = dict(metadata) if metadata else {}

        # Add a timestamp to the metadata for tracking purposes
        metadata["timestamp"] = datetime.now().isoformat()

        # Format the conversation for storage
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        # Store the interaction in the memory client
        self.memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions (at most 5 requested).
        """
        results = self.memory.search(
            query=query,  # Search for interactions related to the query
            user_id=user_id,  # Restrict search to the specific user
            limit=5
        )
        # NOTE(review): some mem0 output formats appear to wrap the hits in a
        # {"results": [...]} envelope; unwrap it so callers always get a list.
        if isinstance(results, dict):
            results = results.get("results", [])
        return results or []

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, taking into account past interactions.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """

        # Retrieve relevant past interactions for context
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from the relevant history. Each hit carries a
        # single 'memory' text, so it is listed once rather than being repeated
        # under both a "Customer" and a "Support" label.
        context = "Previous relevant interactions:\n" + "".join(
            f"Memory: {memory['memory']}\n---\n" for memory in relevant_history
        )

        # Print context for debugging purposes
        print("Context: ", context)

        # Prepare a prompt combining past context and the current query
        prompt = f"""
Context:
{context}

Current customer query: {query}

Provide a helpful response that takes into account any relevant past interactions.
"""

        # Generate a response using the agent
        response = self.agent_executor.invoke({"input": prompt})

        # Store the current interaction for future reference
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"}
        )

        # Return the chatbot's response
        return response['output']
684
+
685
+
686
+ #=====================User Interface using streamlit ===========================#
687
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.

    Flow on each script run:
      1. Ensure `chat_history` and `user_id` exist in `st.session_state`.
      2. If nobody is logged in, show a login form that records the name,
         appends a welcome message, and reruns the script.
      3. Otherwise replay the chat history, read a new chat input, screen it
         with `filter_input_with_llama_guard`, and, if allowed, answer it via
         the `NutritionBot` instance kept in the session state.

    Typing 'exit' appends a goodbye message and logs the user out.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Login form: only shown while no user is logged in
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            # Ignore surrounding whitespace so a blank-looking name is not accepted
            user_id = (user_id or "").strip()
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                })
                st.session_state.login_submitted = True  # Set flag to trigger rerun
        # The rerun is triggered after the form block has been left
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Chat input with custom placeholder text
        user_query = st.chat_input("Type your question here (or 'exit' to end)...")
        if user_query:
            if user_query.strip().lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Filter input using Llama Guard
            filtered_result = filter_input_with_llama_guard(user_query)

            # Normalize the verdict: collapse newlines/extra spaces and ignore
            # case. A missing or non-string verdict becomes "" and is therefore
            # rejected below instead of crashing on .replace().
            # NOTE(review): filter_input_with_llama_guard is defined elsewhere;
            # Llama Guard is documented to answer "safe" or "unsafe\nS<n>".
            # Confirm the helper returns that text unmodified.
            verdict = " ".join(str(filtered_result or "").split()).lower()

            # Allowed verdicts: safe input plus categories S6 and S7. The
            # previously accepted spellings ("SAFE", "S6", "S7") still match
            # because the comparison is now case-insensitive.
            allowed_verdicts = {"safe", "s6", "s7", "unsafe s6", "unsafe s7"}

            if verdict in allowed_verdicts:
                try:
                    # Create the chatbot once per session and reuse it afterwards
                    if 'chatbot' not in st.session_state:
                        st.session_state.chatbot = NutritionBot()
                    reply = st.session_state.chatbot.handle_customer_query(
                        st.session_state.user_id, user_query
                    )
                except Exception as e:
                    # Top-level UI boundary: report the failure in the chat
                    # instead of letting the Streamlit script crash.
                    reply = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
            else:
                reply = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."

            # Render the reply in an assistant bubble so the live message looks
            # the same as it will when replayed from the history on rerun.
            with st.chat_message("assistant"):
                st.write(reply)
            st.session_state.chat_history.append({"role": "assistant", "content": reply})
762
+
763
# Script entry point: build the Streamlit UI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    nutrition_disorder_streamlit()