import streamlit as st import weave from medrag_multi_modal.assistant import ( FigureAnnotatorFromPageImage, LLMClient, MedQAAssistant, ) from medrag_multi_modal.assistant.llm_client import ( GOOGLE_MODELS, MISTRAL_MODELS, OPENAI_MODELS, ) from medrag_multi_modal.retrieval import MedCPTRetriever # Define constants ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS # Sidebar for configuration settings st.sidebar.title("Configuration Settings") project_name = st.sidebar.text_input( label="Project Name", value="ml-colabs/medrag-multi-modal", placeholder="wandb project name", help="format: wandb_username/wandb_project_name", ) chunk_dataset_name = st.sidebar.text_input( label="Text Chunk WandB Dataset Name", value="grays-anatomy-chunks:v0", placeholder="wandb dataset name", help="format: wandb_dataset_name:version", ) index_artifact_address = st.sidebar.text_input( label="WandB Index Artifact Address", value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0", placeholder="wandb artifact address", help="format: wandb_username/wandb_project_name/wandb_artifact_name:version", ) image_artifact_address = st.sidebar.text_input( label="WandB Image Artifact Address", value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6", placeholder="wandb artifact address", help="format: wandb_username/wandb_project_name/wandb_artifact_name:version", ) llm_client_model_name = st.sidebar.selectbox( label="LLM Client Model Name", options=ALL_AVAILABLE_MODELS, index=ALL_AVAILABLE_MODELS.index("gemini-1.5-flash"), help="select a model from the list", ) figure_extraction_model_name = st.sidebar.selectbox( label="Figure Extraction Model Name", options=ALL_AVAILABLE_MODELS, index=ALL_AVAILABLE_MODELS.index("pixtral-12b-2409"), help="select a model from the list", ) structured_output_model_name = st.sidebar.selectbox( label="Structured Output Model Name", options=ALL_AVAILABLE_MODELS, index=ALL_AVAILABLE_MODELS.index("gpt-4o"), help="select a model from the list", ) # Streamlit app layout st.title("MedQA Assistant App") # Initialize Weave weave.init(project_name=project_name) # Initialize clients and assistants llm_client = LLMClient(model_name=llm_client_model_name) retriever = MedCPTRetriever.from_wandb_artifact( chunk_dataset_name=chunk_dataset_name, index_artifact_address=index_artifact_address, ) figure_annotator = FigureAnnotatorFromPageImage( figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name), structured_output_llm_client=LLMClient(model_name=structured_output_model_name), image_artifact_address=image_artifact_address, ) medqa_assistant = MedQAAssistant( llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator ) query = st.chat_input("Enter your question here") if query: with st.chat_message("user"): st.markdown(query) response = medqa_assistant.predict(query=query) with st.chat_message("assistant"): st.markdown(response)