"""Gradio front end for asking medication questions against DailyMed data.

Importing/running this module loads the environment, loads the KDB.AI
DailyMed index, configures the DSPy and LlamaIndex LLMs, builds the UI and
launches the Gradio app.
"""

import dspy
import gradio as gr
from dotenv import load_dotenv

from medirag.cache.local import LocalSemanticCache
from medirag.index.kdbai import KDBAIDailyMedIndexer
from medirag.rag.dspy import DspyRAG, DailyMedRetrieve
from medirag.rag.llama_index import WorkflowRAG
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings

from medirag.rag.qa_rag import QuestionAnswerRunner

# Load environment variables (e.g. API keys) from a .env file, if present.
load_dotenv()

# Initialize the retriever on top of the loaded index.
indexer = KDBAIDailyMedIndexer()
indexer.load_index()
rm = DailyMedRetrieve(indexer=indexer)

# Set the LLM model and retriever for DSPy.
llm = dspy.OpenAI(model="gpt-4o-mini", max_tokens=4000)
dspy.settings.configure(lm=llm, rm=rm)

# Set the LLM model for LlamaIndex.
Settings.llm = OpenAI(model="gpt-4o-mini")

# One cache instance is created at import time and shared by every request.
sm = LocalSemanticCache(
    model_name="sentence-transformers/all-mpnet-base-v2",
    dimension=768,
    json_file="rag_cache.json",
)


def clear_cache() -> None:
    """Clear the shared semantic cache and show a short confirmation toast."""
    sm.clear()
    gr.Info("Cache is cleared", duration=1)


async def ask_med_question(query: str, enable_stream: bool, enable_reranking: bool, top_k: int):
    """Answer ``query`` and yield the response text as it grows.

    Args:
        query: The user's question.
        enable_stream: When true, answer with the LlamaIndex ``WorkflowRAG``;
            otherwise answer with the DSPy ``DspyRAG``. The flag is also
            forwarded to ``QuestionAnswerRunner.ask``.
        enable_reranking: Forwarded to the selected RAG as ``with_reranker``.
        top_k: Number of documents to retrieve.

    Yields:
        The full response accumulated so far, once per chunk received from
        the runner (not the individual chunk).
    """
    # A fresh RAG pipeline is built per request so the UI options apply.
    if enable_stream:
        rag = WorkflowRAG(indexer=indexer, timeout=60, top_k=top_k, with_reranker=enable_reranking)
    else:
        rag = DspyRAG(k=top_k, with_reranker=enable_reranking)
    qa = QuestionAnswerRunner(sm=sm, rag=rag)

    accumulated_response = ""
    async for chunk in qa.ask(query, enable_stream=enable_stream):
        accumulated_response += chunk
        # Yield the whole text so far rather than only the newest chunk.
        yield accumulated_response


css = """ h1 { text-align: center; display:block; } #md {margin-top: 70px} """

# Set up the Gradio interface with a checkbox for enabling streaming
with gr.Blocks(css=css) as app:
    gr.Markdown("# DailyMed RAG")

    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            gr.Image(
                "doc/images/MediRag.png",
                width=100,
                min_width=100,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
            )
        with gr.Column(scale=10):
            gr.Markdown(
                "### Ask any question about medication usage and get answers based on DailyMed data.",
                elem_id="md",
            )

    with gr.Row():
        enable_stream_chk = gr.Checkbox(label="Enable Streaming", value=False)
        enable_reranking_chk = gr.Checkbox(label="Enable ReRanking", value=False)
        top_k_dropdown = gr.Dropdown(
            [3, 5, 7],
            label="Top K",
            info="Documents to Retrieve!",
            min_width=100,
            value=3,
        )
        clear_cache_bt = gr.Button("Clear Cache")

    input_text = gr.Textbox(lines=2, label="Question", placeholder="Enter your question about a drug...")
    output_text = gr.Textbox(interactive=False, label="Response", lines=10)
    submit_bt = gr.Button("Submit")

    # Submit sends the question plus the streaming, reranking and top-k
    # settings to the handler; its yielded text goes to the response box.
    submit_bt.click(
        fn=ask_med_question,
        inputs=[input_text, enable_stream_chk, enable_reranking_chk, top_k_dropdown],
        outputs=output_text,
    )

    # The clear-cache button takes no inputs and updates no components.
    clear_cache_bt.click(fn=clear_cache)

app.launch()