# medrag / app.py — Streamlit chat UI for the MedRAG multi-modal assistant.
# Author: geekyrakshit (Hugging Face Space, commit 716efd2)
import streamlit as st
from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
from medrag_multi_modal.retrieval.text_retrieval import (
BM25sRetriever,
ContrieverRetriever,
MedCPTRetriever,
NVEmbed2Retriever,
)
# Define constants
# Chat models selectable in the sidebar: Google Gemini 1.5 and OpenAI GPT-4o families.
_GEMINI_MODELS = ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"]
_OPENAI_MODELS = ["gpt-4o", "gpt-4o-mini"]
ALL_AVAILABLE_MODELS = _GEMINI_MODELS + _OPENAI_MODELS
# Sidebar for configuration settings.
# NOTE: widget creation order matters — Streamlit renders sidebar widgets in
# the order they are instantiated, so the sequence below is preserved exactly.
st.sidebar.title("Configuration Settings")

# W&B project used for tracking; format is "<entity>/<project>".
project_name = st.sidebar.text_input(
    label="Project Name",
    value="ml-colabs/medrag-multi-modal",
    placeholder="wandb project name",
    help="format: wandb_username/wandb_project_name",
)

# Hugging Face dataset holding the pre-chunked text corpus.
_CHUNK_DATASET_OPTIONS = ["ashwiniai/medrag-text-corpus-chunks"]
chunk_dataset_id = st.sidebar.selectbox(
    label="Chunk Dataset ID", options=_CHUNK_DATASET_OPTIONS
)

# Name of the chat model that will answer the question.
llm_model = st.sidebar.selectbox(label="LLM Model", options=ALL_AVAILABLE_MODELS)

# How many retrieved chunks feed the query and the answer options, respectively.
top_k_chunks_for_query = st.sidebar.slider(
    label="Top K Chunks for Query", min_value=1, max_value=20, value=5
)
top_k_chunks_for_options = st.sidebar.slider(
    label="Top K Chunks for Options", min_value=1, max_value=20, value=3
)

# When checked, the assistant is asked to answer strictly from retrieved context.
rely_only_on_context = st.sidebar.checkbox(label="Rely Only on Context", value=False)

# The empty string is the "nothing selected yet" sentinel — the app below waits
# until a concrete retriever backend is picked.
_RETRIEVER_OPTIONS = ["", "BM25S", "Contriever", "MedCPT", "NV-Embed-v2"]
retriever_type = st.sidebar.selectbox(
    label="Retriever Type", options=_RETRIEVER_OPTIONS
)
# Build the RAG pipeline only once a retriever has been chosen in the sidebar
# (the blank option is the "nothing selected yet" sentinel).
if retriever_type != "":
    # Wrap the selected model name in an LLM client. Bound to a fresh name so
    # the sidebar's `llm_model` string is not clobbered by the client object.
    llm_client = LLMClient(model_name=llm_model)

    # Load the retrieval index matching the chosen backend. BM25S is loaded
    # from its index repo alone; the dense retrievers are also given the chunk
    # dataset id selected in the sidebar.
    retriever = None
    if retriever_type == "BM25S":
        retriever = BM25sRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
        )
    elif retriever_type == "Contriever":
        retriever = ContrieverRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
            chunk_dataset=chunk_dataset_id,
        )
    elif retriever_type == "MedCPT":
        retriever = MedCPTRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
            chunk_dataset=chunk_dataset_id,
        )
    elif retriever_type == "NV-Embed-v2":
        retriever = NVEmbed2Retriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
            chunk_dataset=chunk_dataset_id,
        )

    # NOTE(review): the sidebar's `rely_only_on_context` and `project_name`
    # values are currently unused — confirm whether MedQAAssistant should
    # receive them before wiring them through.
    medqa_assistant = MedQAAssistant(
        llm_client=llm_client,
        retriever=retriever,
        top_k_chunks_for_query=top_k_chunks_for_query,
        top_k_chunks_for_options=top_k_chunks_for_options,
    )

    # Static greeting, re-rendered on every Streamlit rerun.
    with st.chat_message("assistant"):
        st.markdown(
            """
Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.
**Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
Please consult a medical professional for any medical advice.
In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
""",
            unsafe_allow_html=True,
        )

    # Chat loop: echo the user's question, run the assistant, render its answer.
    query = st.chat_input("Enter your question here")
    if query:
        with st.chat_message("user"):
            st.markdown(query)
        response = medqa_assistant.predict(query=query)
        with st.chat_message("assistant"):
            st.markdown(response.response)