import streamlit as st

from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
from medrag_multi_modal.retrieval.text_retrieval import (
    BM25sRetriever,
    ContrieverRetriever,
    MedCPTRetriever,
    NVEmbed2Retriever,
)
# Define constants
ALL_AVAILABLE_MODELS = [
    "gemini-1.5-flash-latest",
    "gemini-1.5-pro-latest",
    "gpt-4o",
    "gpt-4o-mini",
]
# Sidebar for configuration settings
st.sidebar.title("Configuration Settings")
project_name = st.sidebar.text_input(
    label="Project Name",
    value="ml-colabs/medrag-multi-modal",
    placeholder="wandb project name",
    help="format: wandb_username/wandb_project_name",
)
chunk_dataset_id = st.sidebar.selectbox(
    label="Chunk Dataset ID",
    options=["ashwiniai/medrag-text-corpus-chunks"],
)
llm_model = st.sidebar.selectbox(
    label="LLM Model",
    options=ALL_AVAILABLE_MODELS,
)
top_k_chunks_for_query = st.sidebar.slider(
    label="Top K Chunks for Query",
    min_value=1,
    max_value=20,
    value=5,
)
top_k_chunks_for_options = st.sidebar.slider(
    label="Top K Chunks for Options",
    min_value=1,
    max_value=20,
    value=3,
)
rely_only_on_context = st.sidebar.checkbox(
    label="Rely Only on Context",
    value=False,
)
retriever_type = st.sidebar.selectbox(
    label="Retriever Type",
    options=[
        "",
        "BM25S",
        "Contriever",
        "MedCPT",
        "NV-Embed-v2",
    ],
)
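# Once a retriever type is selected in the sidebar, build the RAG pipeline:
# wrap the chosen model name in an LLMClient, load the matching retriever
# from its prebuilt index on the Hugging Face Hub, and hand both to a
# MedQAAssistant that answers chat queries from the retrieved chunks.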
if retriever_type != "":
    llm_model = LLMClient(model_name=llm_model)

    retriever = None
    if retriever_type == "BM25S":
        retriever = BM25sRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
        )
    elif retriever_type == "Contriever":
        retriever = ContrieverRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
            chunk_dataset=chunk_dataset_id,
        )
    elif retriever_type == "MedCPT":
        retriever = MedCPTRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
            chunk_dataset=chunk_dataset_id,
        )
    elif retriever_type == "NV-Embed-v2":
        retriever = NVEmbed2Retriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
            chunk_dataset=chunk_dataset_id,
        )
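    # Assemble the assistant from the LLM client, the selected retriever,
    # and the top-k settings chosen in the sidebar.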
    medqa_assistant = MedQAAssistant(
        llm_client=llm_model,
        retriever=retriever,
        top_k_chunks_for_query=top_k_chunks_for_query,
        top_k_chunks_for_options=top_k_chunks_for_options,
    )
    with st.chat_message("assistant"):
        st.markdown(
            """
Hi! I am Medrag, your medical assistant. You can ask me any questions about medicine and the life sciences.
I am currently a work-in-progress, so please bear with my occasional mistakes and gaps in knowledge.

**Note:** I am not a medical professional, so please do not rely on my answers for medical decisions.
Please consult a medical professional for any medical advice.

To learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
""",
            unsafe_allow_html=True,
        )
    query = st.chat_input("Enter your question here")
    if query:
        with st.chat_message("user"):
            st.markdown(query)
        response = medqa_assistant.predict(query=query)
        with st.chat_message("assistant"):
            st.markdown(response.response)
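# Assuming this script is saved as app.py, the demo can be launched locally with:
#   streamlit run app.py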