import os
import json
import glob
from pathlib import Path
import torch
import streamlit as st
from dotenv import load_dotenv
from langchain_groq import ChatGroq
#from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import numpy as np
from sentence_transformers import util
import time
# Set device for model (CUDA if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load environment variables - works for both local and Hugging Face Spaces
load_dotenv()
# Set up the clinical assistant LLM
# Try to get the API key from Hugging Face Spaces secrets first, then fall back to the .env file
try:
    # Works for Hugging Face Spaces and local runs alike: check the environment first
    groq_api_key = os.environ.get('GROQ_API_KEY')
    # If not found in the environment, try st.secrets (Streamlit Cloud/Spaces)
    if not groq_api_key and hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
        groq_api_key = st.secrets['GROQ_API_KEY']
    if not groq_api_key:
        st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")

        # For UI demonstration without an API key
        class MockLLM:
            def invoke(self, prompt):
                return {"answer": "This is a placeholder response. Please set up your GROQ_API_KEY to get real responses."}

        llm = MockLLM()
    else:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
except Exception as e:
    st.error(f"Error setting up LLM: {str(e)}")
    # Capture the message now: the exception variable is unbound once the
    # except block exits, so the fallback class must not close over `e`.
    error_message = str(e)

    class MockLLM:
        def invoke(self, prompt):
            return {"answer": f"Error setting up LLM: {error_message}. Please check your API key configuration."}

    llm = MockLLM()
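# Illustrative direct call (hedged example; assumes a valid GROQ_API_KEY).
# ChatGroq.invoke returns an AIMessage with a .content attribute, while the
# MockLLM fallback above returns a plain dict:
#   msg = llm.invoke("Summarize the red flags for acute chest pain.")
#   text = getattr(msg, "content", msg)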
# Set up embeddings for clinical context (a BioBERT model fine-tuned on
# NLI/STS and medical-inference datasets for sentence similarity)
embeddings = HuggingFaceEmbeddings(
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
    model_kwargs={"device": device}
)
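# Quick sanity check for the embedding setup (illustrative only, not executed
# at startup; the 768-dimensional output assumes a BERT-base backbone):
#   vec = embeddings.embed_query("chest pain radiating to the left arm")
#   assert len(vec) == 768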
def load_clinical_data():
    """Load both flowcharts and patient cases"""
    docs = []
    # Get the absolute path to the current script
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # Handle potential errors with file loading
    try:
        # Load diagnosis flowcharts
        flowchart_dir = os.path.join(current_dir, "Diagnosis_flowchart")
        if os.path.exists(flowchart_dir):
            for fpath in glob.glob(os.path.join(flowchart_dir, "*.json")):
                try:
                    with open(fpath, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    content = f"""
                    DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
                    Diagnostic Path: {data.get('diagnostic', 'N/A')}
                    Key Criteria: {data.get('knowledge', 'N/A')}
                    """
                    docs.append(Document(
                        page_content=content,
                        metadata={"source": fpath, "type": "flowchart"}
                    ))
                except Exception as e:
                    st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
        else:
            st.warning(f"Flowchart directory not found at {flowchart_dir}")

        # Load patient cases
        finished_dir = os.path.join(current_dir, "Finished")
        if os.path.exists(finished_dir):
            for category_dir in glob.glob(os.path.join(finished_dir, "*")):
                if os.path.isdir(category_dir):
                    for case_file in glob.glob(os.path.join(category_dir, "*.json")):
                        try:
                            with open(case_file, 'r', encoding='utf-8') as f:
                                case_data = json.load(f)
                            notes = "\n".join(
                                f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
                            )
                            docs.append(Document(
                                page_content=f"""
                                PATIENT CASE: {Path(case_file).stem}
                                Category: {Path(category_dir).name}
                                Notes: {notes}
                                """,
                                metadata={"source": case_file, "type": "patient_case"}
                            ))
                        except Exception as e:
                            st.warning(f"Error loading case file {case_file}: {str(e)}")
        else:
            st.warning(f"Finished directory not found at {finished_dir}")

        # If no documents were loaded, add a sample document for testing
        if not docs:
            st.warning("No clinical data files found. Using sample data for demonstration.")
            docs.append(Document(
                page_content="""SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.
                This application requires clinical data files to be present in the correct directories.
                Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files.""",
                metadata={"source": "sample", "type": "sample"}
            ))
    except Exception as e:
        st.error(f"Error loading clinical data: {str(e)}")
        # Add a fallback document
        docs.append(Document(
            page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
            metadata={"source": "error", "type": "error"}
        ))
    return docs
def build_vectorstore():
    """Build and return the vectorstore using FAISS"""
    documents = load_clinical_data()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = splitter.split_documents(documents)
    vectorstore = FAISS.from_documents(splits, embeddings)
    return vectorstore
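# Illustrative direct query against the index (hypothetical query text; the
# similarity_search API is standard for LangChain FAISS stores):
#   vs = build_vectorstore()
#   hits = vs.similarity_search("suspected pulmonary embolism", k=3)
#   types = [doc.metadata["type"] for doc in hits]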
# Path for saving/loading the vectorstore
def get_vectorstore_path():
    """Get the path for saving/loading the vectorstore"""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(current_dir, "vectorstore")
# Initialize vectorstore with disk persistence
@st.cache_resource(show_spinner="Loading clinical knowledge base...")
def get_vectorstore():
    """Get or create the vectorstore with disk persistence"""
    vectorstore_path = get_vectorstore_path()
    # Try to load from disk first
    try:
        if os.path.exists(vectorstore_path):
            st.info("Loading vectorstore from disk...")
            # allow_dangerous_deserialization is True because we trust our own vectorstore files
            return FAISS.load_local(vectorstore_path, embeddings, allow_dangerous_deserialization=True)
    except Exception as e:
        st.warning(f"Could not load vectorstore from disk: {str(e)}. Building new vectorstore.")
    # If loading fails or the index doesn't exist, build a new one
    st.info("Building new vectorstore...")
    vectorstore = build_vectorstore()
    # Save to disk for future use
    try:
        os.makedirs(vectorstore_path, exist_ok=True)
        vectorstore.save_local(vectorstore_path)
        st.success("Vectorstore saved to disk for future use")
    except Exception as e:
        st.warning(f"Could not save vectorstore to disk: {str(e)}")
    return vectorstore
def run_rag_chat(query, vectorstore):
    """Run Retrieval-Augmented Generation (RAG) for clinical questions"""
    try:
        # Configure top-k at retriever construction; the retriever itself
        # takes only the query at invocation time
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        prompt_template = ChatPromptTemplate.from_template("""
        You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.

        <context>
        {context}
        </context>

        Question: {input}

        Answer:
        """)
        retrieved_docs = retriever.invoke(query)
        # Create the document chain first
        document_chain = create_stuff_documents_chain(llm, prompt_template)
        # Then create the retrieval chain
        chain = create_retrieval_chain(retriever, document_chain)
        # Invoke the chain
        response = chain.invoke({"input": query})
        # Attach the explicitly retrieved documents to the response for transparency
        response["context"] = retrieved_docs
        return response
    except Exception as e:
        st.error(f"Error in RAG processing: {str(e)}")
        # Return a fallback response
        return {
            "answer": f"I encountered an error processing your query: {str(e)}",
            "context": [],
            "input": query
        }
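# For reference, create_retrieval_chain returns a dict shaped like
# {"input": <query>, "context": [Document, ...], "answer": <generated text>},
# so a hypothetical call looks like:
#   result = run_rag_chat("What findings support heart failure?", vectorstore)
#   print(result["answer"])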
def calculate_hit_rate(retriever, query, expected_docs, k=3):
    """Calculate the hit rate over the top-k retrieved documents"""
    # get_relevant_documents is deprecated in newer LangChain releases;
    # invoke() returns the retriever's configured results, truncated to k here
    retrieved_docs = retriever.invoke(query)[:k]
    retrieved_contents = [doc.page_content for doc in retrieved_docs]
    hits = 0
    for expected in expected_docs:
        if any(expected in retrieved for retrieved in retrieved_contents):
            hits += 1
    return hits / len(expected_docs) if expected_docs else 0.0
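# Worked toy example (made-up strings): with expected_docs = ["alpha", "beta"]
# and retrieved contents ["...alpha...", "...gamma..."], only "alpha" matches,
# so the function returns 1 / 2 = 0.5.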
def evaluate_rag_response(response, embeddings):
    """Evaluate the RAG response for faithfulness and hit rate"""
    scores = {}
    # Faithfulness: cosine similarity between the answer and each retrieved context chunk
    answer_embed = embeddings.embed_query(response["answer"])
    context_embeds = [embeddings.embed_query(doc.page_content) for doc in response["context"]]
    similarities = [util.cos_sim(answer_embed, ctx_embed).item() for ctx_embed in context_embeds]
    scores["faithfulness"] = float(np.mean(similarities)) if similarities else 0.0
    # Hit rate: note that the expected documents here are the documents that
    # were already retrieved, so this measures retrieval consistency across
    # calls rather than recall against an independent gold standard
    retriever = response["retriever"]
    scores["hit_rate"] = calculate_hit_rate(
        retriever,
        query=response["input"],
        expected_docs=[doc.page_content for doc in response["context"]],
        k=3
    )
    return scores
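# Toy illustration of the faithfulness score (values chosen by hand):
#   util.cos_sim([1.0, 0.0], [1.0, 0.0]).item()  # -> 1.0, identical direction
#   util.cos_sim([1.0, 0.0], [0.0, 1.0]).item()  # -> 0.0, orthogonal
# Scores near 1.0 suggest the answer stays close to the retrieved context.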
def main():
    """Main function to run the Streamlit app"""
    # Set page configuration
    st.set_page_config(
        page_title="DiReCT - Clinical Diagnostic Assistant",
        page_icon="🩺",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Load the vectorstore only once, via session state
    if 'vectorstore' not in st.session_state:
        with st.spinner("Loading clinical knowledge base... This may take a minute."):
            try:
                st.session_state.vectorstore = get_vectorstore()
                # Use a custom styled message without the success icon
                st.markdown("<div style='padding:10px 15px;background-color:rgba(40,167,69,0.2);border-radius:5px;border-left:5px solid rgba(40,167,69,0.8);'>Clinical knowledge base loaded successfully!</div>", unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Error loading knowledge base: {str(e)}")
                st.session_state.vectorstore = None
    # Custom CSS for a modern look with dark-theme compatibility
    st.markdown("""
    <style>
    .stApp {max-width: 1200px; margin: 0 auto;}
    .css-18e3th9 {padding-top: 2rem;}
    .stButton>button {background-color: #3498db; color: white;}
    .stButton>button:hover {background-color: #2980b9;}
    .chat-message {border-radius: 10px; padding: 10px; margin-bottom: 10px;}
    .chat-message-user {background-color: rgba(52, 152, 219, 0.2); color: inherit;}
    .chat-message-assistant {background-color: rgba(240, 240, 240, 0.2); color: inherit;}
    .source-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-bottom: 10px; border-left: 5px solid #3498db;}
    .metrics-box {background-color: rgba(255, 255, 255, 0.1); color: inherit; border-radius: 5px; padding: 15px; margin-top: 20px;}
    .features-container {display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; margin-top: 30px;}
    .feature-item {flex: 1 1 calc(50% - 20px); min-width: 300px; display: flex; align-items: center; padding: 20px; border-radius: 10px; background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2)); transition: transform 0.3s, box-shadow 0.3s; border: 1px solid rgba(255, 255, 255, 0.1);}
    .feature-item:hover {transform: translateY(-5px); box-shadow: 0 10px 20px rgba(0, 0, 0, 0.1);}
    .feature-icon {width: 60px; height: 60px; border-radius: 50%; background: linear-gradient(135deg, #3498db, #2980b9); display: flex; align-items: center; justify-content: center; margin-right: 20px; box-shadow: 0 5px 15px rgba(52, 152, 219, 0.3);}
    .feature-icon i {font-size: 24px; color: white;}
    .feature-content {flex: 1;}
    .feature-content h3 {margin-top: 0; margin-bottom: 10px; color: inherit;}
    .feature-content p {margin: 0; font-size: 0.9em; color: inherit; opacity: 0.8;}
    .input-container {margin-bottom: 20px; padding: 15px; border-radius: 10px; background-color: rgba(255, 255, 255, 0.05); border: 1px solid rgba(255, 255, 255, 0.1);}
    </style>
    """, unsafe_allow_html=True)
    # App states
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'page' not in st.session_state:
        st.session_state.page = 'cover'
    # Sidebar
    with st.sidebar:
        st.image("https://img.icons8.com/color/96/000000/caduceus.png", width=80)
        st.title("DiReCT")
        st.markdown("### Diagnostic Reasoning for Clinical Text")
        st.markdown("---")
        if st.button("Home", key="home_btn"):
            st.session_state.page = 'cover'
        if st.button("Diagnostic Assistant", key="assistant_btn"):
            st.session_state.page = 'chat'
        if st.button("About", key="about_btn"):
            st.session_state.page = 'about'
        st.markdown("---")
        st.markdown("### Model Information")
        st.markdown("**Embedding Model:** BioBERT (sentence-similarity fine-tune)")
        st.markdown("**LLM:** Llama-3.3-70B")
        st.markdown("**Vector Store:** FAISS")
    # Cover page
    if st.session_state.page == 'cover':
        # Hero section
        col1, col2 = st.columns([2, 1])
        with col1:
            st.markdown("<h1 style='font-size:3.5em;'>DiReCT</h1>", unsafe_allow_html=True)
            st.markdown("<h2 style='font-size:1.8em;color:#3498db;'>Diagnostic Reasoning for Clinical Text</h2>", unsafe_allow_html=True)
            st.markdown("""<p style='font-size:1.2em;'>A powerful RAG-based clinical diagnostic assistant that leverages the MIMIC-IV-Ext dataset to provide accurate medical insights and diagnostic reasoning.</p>""", unsafe_allow_html=True)
            st.markdown("""<br>""", unsafe_allow_html=True)
            if st.button("Get Started", key="get_started"):
                st.session_state.page = 'chat'
                st.rerun()
        with col2:
            # Medical icon
            st.markdown("""
            <div style='display:flex;justify-content:center;align-items:center;height:100%;'>
                <img src="https://img.icons8.com/color/240/000000/healthcare-and-medical.png" style='max-width:90%;'>
            </div>
            """, unsafe_allow_html=True)
        # Features section built with Streamlit native components
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("<h2 style='text-align:center;'>Key Features</h2>", unsafe_allow_html=True)
        # Create a 2x2 grid for the features using Streamlit columns
        col1, col2 = st.columns(2)
        # Feature 1
        with col1:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 20px;">
                <h3>🔍 Intelligent Retrieval</h3>
                <p>Finds the most relevant clinical information from the MIMIC-IV-Ext dataset</p>
            </div>
            """, unsafe_allow_html=True)
        # Feature 2
        with col2:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 20px;">
                <h3>🧠 Advanced Reasoning</h3>
                <p>Applies clinical knowledge to generate accurate diagnostic insights</p>
            </div>
            """, unsafe_allow_html=True)
        # Feature 3
        with col1:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1);">
                <h3>📄 Source Transparency</h3>
                <p>Provides references to all clinical sources used in generating responses</p>
            </div>
            """, unsafe_allow_html=True)
        # Feature 4
        with col2:
            st.markdown("""
            <div style="background: linear-gradient(135deg, rgba(72, 126, 176, 0.1), rgba(72, 126, 176, 0.2));
                        padding: 20px; border-radius: 10px; height: 100%;
                        border: 1px solid rgba(255, 255, 255, 0.1);">
                <h3>🌓 Dark/Light Theme Compatible</h3>
                <p>Optimized interface that works seamlessly in both dark and light themes</p>
            </div>
            """, unsafe_allow_html=True)
    # Chat interface
    elif st.session_state.page == 'chat':
        # Initialize session state for the input box if it doesn't exist
        if 'user_input' not in st.session_state:
            st.session_state.user_input = ""
        # Header with clear button
        col1, col2 = st.columns([3, 1])
        with col1:
            st.markdown("<h1>Clinical Diagnostic Assistant</h1>", unsafe_allow_html=True)
        with col2:
            # Add a clear button in the header
            if st.button("🗑️ Clear Chat"):
                st.session_state.chat_history = []
                st.session_state.user_input = ""
                st.rerun()
        st.markdown("Ask any clinical diagnostic question and get insights based on medical knowledge and patient cases.")
        # Fixed input area at the top
        with st.container():
            st.markdown("<div class='input-container'>", unsafe_allow_html=True)
            user_input = st.text_area("Ask a clinical question:", st.session_state.user_input, height=100, key="question_input")
            col1, col2 = st.columns([1, 5])
            with col1:
                submit_button = st.button("Submit")
            st.markdown("</div>", unsafe_allow_html=True)
        # Create a container for the chat history
        chat_container = st.container()
        # Process the query
        if submit_button and user_input:
            if st.session_state.vectorstore is None:
                st.error("Knowledge base not loaded. Please refresh the page and try again.")
            else:
                with st.spinner("Analyzing clinical data..."):
                    try:
                        # Add a small delay for UX
                        time.sleep(0.5)
                        # Run RAG
                        response = run_rag_chat(user_input, st.session_state.vectorstore)
                        response["retriever"] = st.session_state.vectorstore.as_retriever()
                        # Clear previous chat history and keep only the current response
                        st.session_state.chat_history = [(user_input, response)]
                        # Clear the input field
                        st.session_state.user_input = ""
                        # Rerun to update the UI
                        st.rerun()
                    except Exception as e:
                        st.error(f"Error processing query: {str(e)}")
        # Display the chat history in its container
        with chat_container:
            for i, (query, response) in enumerate(st.session_state.chat_history):
                st.markdown(f"<div class='chat-message chat-message-user'><b>🧑‍⚕️ You:</b> {query}</div>", unsafe_allow_html=True)
                st.markdown(f"<div class='chat-message chat-message-assistant'><b>🩺 DiReCT:</b> {response['answer']}</div>", unsafe_allow_html=True)
                with st.expander("View Sources"):
                    for doc in response["context"]:
                        st.markdown(f"<div class='source-box'>"
                                    f"<b>Source:</b> {Path(doc.metadata['source']).stem}<br>"
                                    f"<b>Type:</b> {doc.metadata['type']}<br>"
                                    f"<b>Content:</b> {doc.page_content[:300]}...</div>",
                                    unsafe_allow_html=True)
                # Show evaluation metrics if available
                try:
                    eval_scores = evaluate_rag_response(response, embeddings)
                    with st.expander("View Evaluation Metrics"):
                        col1, col2 = st.columns(2)
                        with col1:
                            st.metric("Hit Rate (Top-3)", f"{eval_scores['hit_rate']:.2f}")
                        with col2:
                            st.metric("Faithfulness", f"{eval_scores['faithfulness']:.2f}")
                except Exception as e:
                    st.warning(f"Evaluation metrics unavailable: {str(e)}")
    # About page
    elif st.session_state.page == 'about':
        st.markdown("<h1>About DiReCT</h1>", unsafe_allow_html=True)
        st.markdown("""
        ### Project Overview
        DiReCT (Diagnostic Reasoning for Clinical Text) is a Retrieval-Augmented Generation (RAG) system designed to assist medical professionals with diagnostic reasoning based on clinical notes and medical knowledge.

        ### Data Sources
        This application uses the MIMIC-IV-Ext dataset, which contains de-identified clinical notes and medical records. The system processes:
        - Diagnostic flowcharts
        - Patient cases
        - Clinical guidelines

        ### Technical Implementation
        - **Embedding Model**: a BioBERT sentence-similarity model for domain-specific text understanding
        - **Vector Database**: FAISS for efficient similarity search
        - **LLM**: Llama-3.3-70B for generating medically accurate responses
        - **Framework**: built with LangChain and Streamlit

        ### Evaluation Metrics
        The system evaluates responses using:
        - **Hit Rate**: measures how many of the expected documents appear in the top-k retrieved set
        - **Faithfulness**: measures how well the response aligns with the retrieved context

        ### Ethical Considerations
        This system is designed as a clinical decision support tool and not as a replacement for professional medical judgment. All patient data used has been properly de-identified in compliance with healthcare privacy regulations.
        """)
        st.markdown("<br>", unsafe_allow_html=True)
        st.markdown("### Developers")
        st.markdown("This project was developed as part of an academic assignment on RAG systems for clinical applications.")
if __name__ == "__main__":
    main()