# medirag/app.py
import dspy
import gradio as gr
from dotenv import load_dotenv
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI

from medirag.cache.local import SemanticCaching
from medirag.index.kdbai import KDBAIDailyMedIndexer
from medirag.rag.qa import RAG, DailyMedRetrieve
from medirag.rag.wf import RAGWorkflow
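
# Load environment variables from a local .env file (OPENAI_API_KEY and,
# presumably, KDB.AI credentials for the indexer)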
load_dotenv()

# Initialize the retriever: load the prebuilt DailyMed vector index from KDB.AI
indexer = KDBAIDailyMedIndexer()
indexer.load_index()
rm = DailyMedRetrieve(indexer=indexer)
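# Configure DSPy to use GPT-3.5 Turbo as the LM and DailyMed as the retriever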
turbo = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
dspy.settings.configure(lm=turbo, rm=rm)
# Set the default LLM for LlamaIndex components
Settings.llm = OpenAI(model="gpt-3.5-turbo")
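
# Semantic cache: embeds questions with all-mpnet-base-v2 (768-dim vectors)
# and reuses stored answers for sufficiently similar queries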
sm = SemanticCaching(
model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="rag_test_cache.json"
)

# Non-streaming RAG module: answers from the top-5 retrieved passages
rag = RAG(k=5)
# Streaming RAG workflow over the same index (reranker disabled)
streaming_rag = RAGWorkflow(indexer=indexer, timeout=60, with_reranker=False, top_k=5, top_n=3)


def clear_cache():
    sm.clear()
    gr.Info("Cache cleared", duration=1)


async def ask_med_question(query: str, enable_stream: bool):
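    """Answer a medication question, serving from the semantic cache when possible.

    Implemented as an async generator so Gradio streams each yielded value
    into the output textbox, enabling incremental display of the answer.
    """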
    # Check the semantic cache first for a sufficiently similar question
response = sm.lookup(question=query, cosine_threshold=0.9)
if response:
# Return cached response if found
yield response
else:
if enable_stream:
# Stream response using RAGWorkflow
result = await streaming_rag.run(query=query)
            # Streaming responses (LlamaIndex style) expose async_response_gen()
            if hasattr(result, "async_response_gen"):
accumulated_response = ""
async for chunk in result.async_response_gen():
accumulated_response += chunk
yield accumulated_response # Accumulate and yield the updated response
# Save the accumulated response to the cache after streaming is complete
sm.save(query, accumulated_response)
elif isinstance(result, str):
# Handle non-streaming string response
yield result
sm.save(query, result)
else:
# Handle unexpected response types
print("Unexpected response type:", result)
yield "An unexpected error occurred."
else:
# Use RAG without streaming
response = rag(query).answer
yield response
# Save the response in the cache
if response:
sm.save(query, response)
css = """
h1 {
text-align: center;
display:block;
}
#md {margin-top: 70px}
"""
# Set up the Gradio interface with a checkbox for enabling streaming
with gr.Blocks(css=css) as app:
gr.Markdown("# DailyMed RAG")
with gr.Row():
with gr.Column(scale=1, min_width=100):
gr.Image(
"doc/images/MediRag.png",
width=100,
min_width=100,
show_label=False,
show_download_button=False,
show_share_button=False,
show_fullscreen_button=False,
)
with gr.Column(scale=10):
gr.Markdown(
"### Ask any question about medication usage and get answers based on DailyMed data.", elem_id="md"
)
with gr.Row():
enable_stream_chk = gr.Checkbox(label="Enable Streaming", value=False)
clear_cache_bt = gr.Button("Clear Cache")
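    # Question input, answer output, and submit control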
input_text = gr.Textbox(lines=2, label="Question", placeholder="Enter your question about a drug...")
output_text = gr.Textbox(interactive=False, label="Response", lines=10)
submit_bt = gr.Button("Submit")
    # Pass the question and the streaming flag to the handler; stream output to the textbox
    submit_bt.click(fn=ask_med_question, inputs=[input_text, enable_stream_chk], outputs=output_text)
    # Wire the Clear Cache button
    clear_cache_bt.click(fn=clear_cache)

app.launch()