# Spaces status: Sleeping
import dspy | |
import gradio as gr | |
from dotenv import load_dotenv | |
from medirag.cache.local import SemanticCaching | |
from medirag.index.kdbai import KDBAIDailyMedIndexer | |
from medirag.rag.qa import RAG, DailyMedRetrieve | |
from medirag.rag.wf import RAGWorkflow | |
from llama_index.llms.openai import OpenAI | |
from llama_index.core import Settings | |
# Load environment variables from a .env file before any client is created.
load_dotenv()

# Retrieval side: build the DailyMed indexer, load its index, and wrap it in
# the retriever that DSPy will use.
indexer = KDBAIDailyMedIndexer()
indexer.load_index()
rm = DailyMedRetrieve(indexer=indexer)

# DSPy configuration: language model plus retrieval model.
turbo = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
dspy.settings.configure(lm=turbo, rm=rm)

# LlamaIndex global LLM setting.
Settings.llm = OpenAI(model="gpt-3.5-turbo")

# Semantic cache of previously answered questions, backed by a JSON file.
sm = SemanticCaching(
    model_name="sentence-transformers/all-mpnet-base-v2",
    dimension=768,
    json_file="rag_test_cache.json",
)

# Non-streaming RAG module.
rag = RAG(k=5)

# Streaming workflow built on the same indexer.
streaming_rag = RAGWorkflow(
    indexer=indexer,
    timeout=60,
    with_reranker=False,
    top_k=5,
    top_n=3,
)
def clear_cache() -> None:
    """Clear the semantic cache and show a short confirmation toast in the UI."""
    sm.clear()
    gr.Info("Cache is cleared", duration=1)
async def ask_med_question(query: str, enable_stream: bool):
    """Answer a medication question, serving from the semantic cache when possible.

    This is an async generator: each yielded string is the full text to show
    in the response box at that moment.

    Args:
        query: The user's question.
        enable_stream: When true, answer through ``streaming_rag`` and yield
            the growing response as chunks arrive; otherwise answer through
            ``rag`` and yield once.

    Yields:
        The cached answer, the progressively accumulated streamed answer, the
        complete non-streamed answer, or a generic error message when the
        workflow returns an unexpected type.
    """
    # Check the cache first; a hit short-circuits all model calls.
    cached = sm.lookup(question=query, cosine_threshold=0.9)
    if cached:
        yield cached
        return

    if not enable_stream:
        # Use RAG without streaming.
        answer = rag(query).answer
        yield answer
        if answer:
            sm.save(query, answer)
        return

    # Stream the response using RAGWorkflow.
    result = await streaming_rag.run(query=query)

    if hasattr(result, "async_response_gen"):
        accumulated_response = ""
        async for chunk in result.async_response_gen():
            accumulated_response += chunk
            # Yield the whole text so far so the output box shows the full
            # answer rather than only the latest chunk.
            yield accumulated_response
        # Cache only after streaming completes, and only a non-empty answer,
        # matching the non-streaming path.
        if accumulated_response:
            sm.save(query, accumulated_response)
    elif isinstance(result, str):
        # The workflow returned a plain string instead of a streaming object.
        yield result
        if result:
            sm.save(query, result)
    else:
        # Unexpected response type: report it and do not cache anything.
        print("Unexpected response type:", result)
        yield "An unexpected error occurred."
# Page-level CSS: centre the title and push the tagline down beside the logo.
css = """
h1 {
text-align: center;
display:block;
}
#md {margin-top: 70px}
"""

# Build the Gradio UI: header row, options row, question/answer boxes, submit.
with gr.Blocks(css=css) as app:
    gr.Markdown("# DailyMed RAG")

    # Header: logo on the left, tagline on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            gr.Image(
                "doc/images/MediRag.png",
                width=100,
                min_width=100,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                show_fullscreen_button=False,
            )
        with gr.Column(scale=10):
            gr.Markdown(
                "### Ask any question about medication usage and get answers based on DailyMed data.",
                elem_id="md",
            )

    # Options: streaming toggle and cache reset.
    with gr.Row():
        enable_stream_chk = gr.Checkbox(label="Enable Streaming", value=False)
        clear_cache_bt = gr.Button("Clear Cache")

    input_text = gr.Textbox(
        lines=2,
        label="Question",
        placeholder="Enter your question about a drug...",
    )
    output_text = gr.Textbox(interactive=False, label="Response", lines=10)
    submit_bt = gr.Button("Submit")

    # Submit sends the question and the streaming checkbox value to the handler.
    submit_bt.click(
        fn=ask_med_question,
        inputs=[input_text, enable_stream_chk],
        outputs=output_text,
    )

    # The clear-cache button takes no inputs and updates no components.
    clear_cache_bt.click(fn=clear_cache)

app.launch()