import os
import json
import glob
from pathlib import Path
import torch
import streamlit as st
from dotenv import load_dotenv
from langchain_groq import ChatGroq
#from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import numpy as np
from sentence_transformers import util
import time
# Pick the torch device for the embedding model: the GPU when PyTorch reports
# CUDA support, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Populate os.environ from a local .env file when one exists. By default
# load_dotenv does not override variables that are already set, so keys
# injected by the hosting platform (e.g. Hugging Face Spaces secrets) win.
load_dotenv()
# Set up the clinical assistant LLM.
# The Groq API key is read from the environment first (this covers Hugging Face
# Spaces secrets and anything loaded from .env above), then from st.secrets.
# Without a usable key a placeholder LLM is installed so the UI still runs.
_PLACEHOLDER_ANSWER = (
    "This is a placeholder response. Please set up your GROQ_API_KEY to get real responses."
)


class MockLLM:
    """Placeholder LLM that answers every prompt with a fixed message.

    Installed as ``llm`` when no Groq API key is configured or the Groq client
    could not be created.
    """

    def __init__(self, message=_PLACEHOLDER_ANSWER):
        # The message is stored on the instance so it stays available after
        # the code that created the mock has finished.
        self.message = message

    def invoke(self, prompt):
        """Ignore *prompt* and return ``{"answer": <message>}``."""
        return {"answer": self.message}


def _get_groq_api_key():
    """Return the Groq API key, or None if it is not configured anywhere.

    The GROQ_API_KEY environment variable takes precedence; st.secrets is only
    consulted as a fallback.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if api_key:
        return api_key
    try:
        if "GROQ_API_KEY" in st.secrets:
            return st.secrets["GROQ_API_KEY"]
    except Exception:
        # Reading st.secrets raises when no secrets.toml exists (the exact
        # exception type varies across Streamlit versions). For an optional
        # fallback source that simply means "not configured".
        return None
    return None


groq_api_key = _get_groq_api_key()
if not groq_api_key:
    st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")
    llm = MockLLM()
else:
    try:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
    except Exception as e:
        # Format the message now: Python unbinds `e` at the end of this except
        # clause, so referring to it later (e.g. inside invoke) would raise
        # NameError instead of returning the error text.
        setup_error = f"Error setting up LLM: {str(e)}"
        st.error(setup_error)
        llm = MockLLM(f"{setup_error}. Please check your API key configuration.")
# Sentence-embedding model for the clinical knowledge base. Judging by its name,
# this is a BioBERT checkpoint fine-tuned on NLI/STS data (including MedNLI),
# not Bio_ClinicalBERT.
# NOTE: a vectorstore saved to disk holds vectors from whichever model built it,
# so changing this name requires deleting/rebuilding the saved index.
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"

embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    model_kwargs={"device": device},  # "cuda" or "cpu", chosen above
)
def _read_json_object(path):
    """Read *path* as UTF-8 JSON and return it, requiring a top-level object.

    Raises:
        OSError: the file cannot be opened or read.
        ValueError: the bytes are not valid UTF-8 / JSON (UnicodeDecodeError and
            json.JSONDecodeError are ValueError subclasses), or the top-level
            value is not a JSON object.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    return data


def _load_flowchart_docs(flowchart_dir):
    """Build one "flowchart" Document per *.json file directly in *flowchart_dir*.

    Unreadable or malformed files are reported with st.warning and skipped.
    """
    docs = []
    # glob's result order depends on the filesystem; sort so the documents (and
    # the index built from them) come out in the same order on every run.
    for fpath in sorted(glob.glob(os.path.join(flowchart_dir, "*.json"))):
        try:
            data = _read_json_object(fpath)
        except (OSError, ValueError) as e:
            st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
            continue
        content = (
            "\n"
            f"DIAGNOSTIC FLOWCHART: {Path(fpath).stem}\n"
            f"Diagnostic Path: {data.get('diagnostic', 'N/A')}\n"
            f"Key Criteria: {data.get('knowledge', 'N/A')}\n"
        )
        docs.append(Document(
            page_content=content,
            metadata={"source": fpath, "type": "flowchart"},
        ))
    return docs


def _load_patient_case_docs(finished_dir):
    """Build one "patient_case" Document per *.json file in each subdirectory of *finished_dir*.

    The subdirectory name is used as the case category, and only top-level keys
    starting with "input" are included in the notes. Unreadable or malformed
    files are reported with st.warning and skipped.
    """
    docs = []
    for category_dir in sorted(glob.glob(os.path.join(finished_dir, "*"))):
        if not os.path.isdir(category_dir):
            continue
        category = Path(category_dir).name
        for case_file in sorted(glob.glob(os.path.join(category_dir, "*.json"))):
            try:
                case_data = _read_json_object(case_file)
            except (OSError, ValueError) as e:
                st.warning(f"Error loading case file {case_file}: {str(e)}")
                continue
            notes = "\n".join(
                f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
            )
            content = (
                "\n"
                f"PATIENT CASE: {Path(case_file).stem}\n"
                f"Category: {category}\n"
                f"Notes: {notes}\n"
            )
            docs.append(Document(
                page_content=content,
                metadata={"source": case_file, "type": "patient_case"},
            ))
    return docs


def load_clinical_data():
    """Load both flowcharts and patient cases.

    Reads diagnosis flowcharts from ``Diagnosis_flowchart/*.json`` and patient
    cases from ``Finished/<category>/*.json``, both relative to this script.

    Returns:
        list[Document]: never empty. If no files could be loaded, a single
        "sample" placeholder document is returned; if an unexpected error
        aborts loading, an "error" placeholder document is appended instead.
        Missing directories and bad files are reported via st.warning.
    """
    docs = []
    # Resolve data directories relative to this script, not the working directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        flowchart_dir = os.path.join(current_dir, "Diagnosis_flowchart")
        if os.path.exists(flowchart_dir):
            docs.extend(_load_flowchart_docs(flowchart_dir))
        else:
            st.warning(f"Flowchart directory not found at {flowchart_dir}")

        finished_dir = os.path.join(current_dir, "Finished")
        if os.path.exists(finished_dir):
            docs.extend(_load_patient_case_docs(finished_dir))
        else:
            st.warning(f"Finished directory not found at {finished_dir}")

        # An empty corpus cannot be indexed, so fall back to a sample document.
        if not docs:
            st.warning("No clinical data files found. Using sample data for demonstration.")
            docs.append(Document(
                page_content=(
                    "SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.\n"
                    "This application requires clinical data files to be present in the correct directories.\n"
                    "Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files."
                ),
                metadata={"source": "sample", "type": "sample"},
            ))
    except Exception as e:
        # Top-level boundary: keep the app usable with a placeholder document.
        st.error(f"Error loading clinical data: {str(e)}")
        docs.append(Document(
            page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
            metadata={"source": "error", "type": "error"},
        ))
    return docs
def build_vectorstore():
    """Create a fresh FAISS index over the clinical documents.

    Every document from load_clinical_data() is split with a
    RecursiveCharacterTextSplitter (chunk_size=1000, chunk_overlap=200) and
    each chunk is embedded with the module-level ``embeddings`` model.

    Returns:
        The newly built FAISS vectorstore (not saved to disk).
    """
    source_docs = load_clinical_data()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    ).split_documents(source_docs)
    return FAISS.from_documents(chunks, embeddings)
# Location of the persisted FAISS index.
def get_vectorstore_path():
    """Return the absolute path of the ``vectorstore`` directory next to this script.

    The path is derived from this file's location rather than the current
    working directory, so it is stable however the app is launched.
    """
    script_dir, _script_name = os.path.split(os.path.abspath(__file__))
    return os.path.join(script_dir, "vectorstore")
# Initialize the vectorstore once, persisting it to disk between app restarts.
@st.cache_resource(show_spinner="Loading clinical knowledge base...")
def get_vectorstore():
    """Return the clinical FAISS vectorstore, preferring the copy saved on disk.

    If the directory from get_vectorstore_path() exists and loads cleanly, that
    index is returned as-is. Otherwise a new index is built with
    build_vectorstore() and a best-effort attempt is made to save it there.
    Load and save failures are shown as UI warnings rather than raised.

    NOTE(review): a saved index is reused even after the source JSON files
    change. An index built from the sample/fallback placeholder document is
    persisted too. Delete the directory to force a rebuild.
    """
    store_dir = get_vectorstore_path()

    try:
        if os.path.exists(store_dir):
            st.info("Loading vectorstore from disk...")
            # load_local unpickles the docstore. Allowing that is acceptable only
            # because these files are written by this app's own save_local call
            # below; never point store_dir at files from an untrusted source.
            return FAISS.load_local(
                store_dir,
                embeddings,
                allow_dangerous_deserialization=True,
            )
    except Exception as load_err:
        st.warning(f"Could not load vectorstore from disk: {str(load_err)}. Building new vectorstore.")

    # No usable copy on disk: build from the raw clinical data.
    st.info("Building new vectorstore...")
    fresh_store = build_vectorstore()

    # Persisting is best-effort; the in-memory index is returned either way.
    try:
        os.makedirs(store_dir, exist_ok=True)
        fresh_store.save_local(store_dir)
        st.success("Vectorstore saved to disk for future use")
    except Exception as save_err:
        st.warning(f"Could not save vectorstore to disk: {str(save_err)}")

    return fresh_store
def run_rag_chat(query, vectorstore):
"""Run the Retrieval-Augmented Generation (RAG) for clinical questions"""
try:
retriever = vectorstore.as_retriever()
prompt_template = ChatPromptTemplate.from_template("""
You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.
A powerful RAG-based clinical diagnostic assistant that leverages the MIMIC-IV-Ext dataset to provide accurate medical insights and diagnostic reasoning.
""", unsafe_allow_html=True) st.markdown("""Finds the most relevant clinical information from the MIMIC-IV-Ext dataset
Applies clinical knowledge to generate accurate diagnostic insights
Provides references to all clinical sources used in generating responses
Optimized interface that works seamlessly in both dark and light themes