#!/usr/bin/env python3
"""
OncoLife Symptom & Triage Assistant
A medical chatbot that performs both symptom assessment and clinical triage for chemotherapy patients.
Updated: Using BioMistral-7B base model for medical conversations.
REBUILD: Simplified to use only base model, no adapters.
RAG: Added document retrieval capabilities for PDFs and other reference materials (optional).
"""
import gradio as gr
import os
import json
from pathlib import Path
from transformers import AutoTokenizer, MistralForCausalLM
import torch
from spaces import GPU
# RAG imports (optional)
try:
import chromadb
from sentence_transformers import SentenceTransformer
import PyPDF2
import pdfplumber
from langchain.text_splitter import RecursiveCharacterTextSplitter
import fitz # PyMuPDF for better PDF handling
RAG_AVAILABLE = True
except ImportError:
print("⚠️ RAG libraries not available, running in instruction-only mode")
RAG_AVAILABLE = False
# Force GPU detection for HF Spaces
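# On ZeroGPU Spaces, the `spaces.GPU` decorator requests a GPU for the duration of the
# decorated call, so wrapping a trivial CUDA check confirms whether a GPU is actually
# reachable before the model is loaded.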
@GPU
def force_gpu_detection():
"""Force GPU detection for Hugging Face Spaces"""
return torch.cuda.is_available()
class OncoLifeAssistant:
def __init__(self):
# BioMistral base model configuration
BASE = "BioMistral/BioMistral-7B"
print("πŸ”„ Initializing OncoLife Symptom & Triage Assistant")
print(f"πŸ“¦ Loading base model: {BASE}")
# Force GPU detection first
try:
gpu_available = force_gpu_detection()
print(f"πŸ–₯️ GPU Detection: {gpu_available}")
except Exception as e:
print(f"⚠️ GPU detection error: {e}")
gpu_available = torch.cuda.is_available()
self._load_model(BASE, gpu_available)
# Load the OncoLife instructions
self._load_instructions()
# Initialize RAG system (optional)
self.rag_enabled = False
if RAG_AVAILABLE:
try:
self._initialize_rag()
self.rag_enabled = True
print("βœ… RAG system initialized successfully")
except Exception as e:
print(f"⚠️ RAG initialization failed: {e}")
print("πŸ”„ Continuing with instruction-only mode")
else:
print("πŸ”„ Running in instruction-only mode (no RAG)")
def _load_instructions(self):
"""Load the OncoLife instructions from the text file"""
try:
instructions_file = Path(__file__).parent / "oncolifebot_instructions.txt"
if instructions_file.exists():
with open(instructions_file, 'r') as f:
self.instructions = f.read()
print("βœ… Loaded oncolifebot_instructions.txt")
else:
print("⚠️ oncolifebot_instructions.txt not found")
self.instructions = ""
except Exception as e:
print(f"❌ Error loading instructions: {e}")
self.instructions = ""
def _initialize_rag(self):
"""Initialize the RAG system with document embeddings (lightweight version)"""
try:
print("πŸ” Initializing lightweight RAG system...")
# Use a smaller embedding model
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("βœ… Loaded embedding model")
# Initialize ChromaDB with persistence disabled for memory efficiency
self.chroma_client = chromadb.Client()
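            # chromadb.Client() is the ephemeral in-memory client, so the index lives only
            # for the lifetime of the process and is rebuilt from the documents on restart.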
self.collection = self.chroma_client.create_collection(
name="oncolife_documents",
metadata={"description": "OncoLife reference documents"}
)
print("βœ… Initialized ChromaDB collection")
# Load and process documents (limited to essential files)
self._load_documents_lightweight()
except Exception as e:
print(f"❌ Error initializing RAG: {e}")
self.embedding_model = None
self.collection = None
raise e
def _load_documents_lightweight(self):
"""Load only essential documents to save memory"""
try:
docs_path = Path(__file__).parent / "guideline-docs"
print(f"πŸ“š Loading essential documents from: {docs_path}")
if not docs_path.exists():
print("⚠️ guideline-docs directory not found")
return
# Text splitter for chunking documents
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500, # Smaller chunks to save memory
chunk_overlap=100,
separators=["\n\n", "\n", ". ", " ", ""]
)
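            # 500-character chunks with 100 characters of overlap keep each embedded passage
            # small (for memory) while preserving context across chunk boundaries.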
documents_loaded = 0
# Process PDF files (essential medical guidelines)
for pdf_file in docs_path.glob("*.pdf"):
try:
print(f"πŸ“„ Processing PDF: {pdf_file.name}")
text = self._extract_pdf_text(pdf_file)
if text:
chunks = text_splitter.split_text(text)
self._add_chunks_to_db(chunks, pdf_file.name)
documents_loaded += 1
print(f"βœ… Added {len(chunks)} chunks from {pdf_file.name}")
else:
print(f"⚠️ No text extracted from {pdf_file.name}")
except Exception as e:
print(f"❌ Error processing {pdf_file.name}: {e}")
# Process JSON files (lightweight)
for json_file in docs_path.glob("*.json"):
try:
print(f"πŸ“„ Processing JSON: {json_file.name}")
with open(json_file, 'r') as f:
data = json.load(f)
# Convert JSON to text representation
text = json.dumps(data, indent=2)
chunks = text_splitter.split_text(text)
self._add_chunks_to_db(chunks, json_file.name)
documents_loaded += 1
print(f"βœ… Added {len(chunks)} chunks from {json_file.name}")
except Exception as e:
print(f"❌ Error processing {json_file.name}: {e}")
# Process text files (lightweight)
for txt_file in docs_path.glob("*.txt"):
try:
print(f"πŸ“„ Processing TXT: {txt_file.name}")
with open(txt_file, 'r', encoding='utf-8') as f:
text = f.read()
chunks = text_splitter.split_text(text)
self._add_chunks_to_db(chunks, txt_file.name)
documents_loaded += 1
print(f"βœ… Added {len(chunks)} chunks from {txt_file.name}")
except Exception as e:
print(f"❌ Error processing {txt_file.name}: {e}")
print(f"βœ… RAG system initialized with {documents_loaded} documents")
except Exception as e:
print(f"❌ Error loading documents: {e}")
def _extract_pdf_text(self, pdf_path):
"""Extract text from PDF using multiple methods"""
try:
# Try PyMuPDF first (better for complex PDFs)
try:
doc = fitz.open(pdf_path)
text = ""
for page in doc:
text += page.get_text()
doc.close()
if text.strip():
return text
except Exception as e:
print(f"PyMuPDF failed for {pdf_path.name}: {e}")
# Fallback to pdfplumber
try:
                with pdfplumber.open(pdf_path) as pdf:
                    text = ""
                    for page in pdf.pages:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                # Only return if pdfplumber actually recovered text; otherwise fall through to PyPDF2
                if text.strip():
                    return text
except Exception as e:
print(f"pdfplumber failed for {pdf_path.name}: {e}")
# Final fallback to PyPDF2
try:
with open(pdf_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
text = ""
                    for page in reader.pages:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
return text
except Exception as e:
print(f"PyPDF2 failed for {pdf_path.name}: {e}")
return None
except Exception as e:
print(f"❌ Error extracting text from {pdf_path.name}: {e}")
return None
def _add_chunks_to_db(self, chunks, source_name):
"""Add document chunks to the vector database"""
try:
if not chunks or not self.collection:
return
# Generate embeddings
embeddings = self.embedding_model.encode(chunks)
# Add to ChromaDB
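            # ids must be unique within the collection, so each combines the source file name
            # with the chunk index; the metadata keeps the source so retrieval results can cite it.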
self.collection.add(
embeddings=embeddings.tolist(),
documents=chunks,
metadatas=[{"source": source_name, "chunk_id": i} for i in range(len(chunks))],
ids=[f"{source_name}_chunk_{i}" for i in range(len(chunks))]
)
except Exception as e:
print(f"❌ Error adding chunks to database: {e}")
def _retrieve_relevant_documents(self, query, top_k=3):
"""Retrieve relevant document chunks for a query"""
try:
if not self.collection or not self.embedding_model or not self.rag_enabled:
return []
# Generate query embedding
query_embedding = self.embedding_model.encode([query])
# Search for similar documents
results = self.collection.query(
query_embeddings=query_embedding.tolist(),
n_results=top_k
)
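            # Chroma returns parallel lists keyed by query; index [0] selects the results for
            # our single query, and position i lines up documents, metadatas, and distances.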
# Format results
relevant_docs = []
if results['documents']:
for i, doc in enumerate(results['documents'][0]):
relevant_docs.append({
'content': doc,
'source': results['metadatas'][0][i]['source'],
'similarity': results['distances'][0][i] if 'distances' in results else None
})
return relevant_docs
except Exception as e:
print(f"❌ Error retrieving documents: {e}")
return []
def _load_model(self, model_id, gpu_available):
"""Load the BioMistral base model with memory optimization"""
try:
print("πŸ”„ Loading BioMistral base model...")
# Determine device strategy
if gpu_available and torch.cuda.is_available():
device = "cuda"
dtype = torch.float16
print("πŸ–₯️ Loading BioMistral model on GPU...")
else:
device = "cpu"
dtype = torch.float32
print("πŸ’» Loading BioMistral model on CPU...")
# Load tokenizer
print(f"πŸ“ Loading tokenizer: {model_id}")
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
# Load the model with memory optimization
print(f"πŸ“¦ Loading model: {model_id}")
self.model = MistralForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
device_map="auto",
torch_dtype=dtype,
low_cpu_mem_usage=True,
                # Cap memory per device so accelerate can offload overflow weights instead of OOM-ing
max_memory={0: "8GB", "cpu": "16GB"} if gpu_available else {"cpu": "8GB"}
)
# Add pad token if not present
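            # Mistral tokenizers ship without a pad token, so EOS is reused to make
            # padding-enabled tokenization (and batched generation) work.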
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
print(f"βœ… BioMistral base model loaded successfully on {device.upper()}!")
except Exception as e:
print(f"❌ Error loading BioMistral model: {e}")
self.model = None
self.tokenizer = None
def generate_oncolife_response(self, user_input, conversation_history):
"""Generate response using OncoLife instructions and optional RAG"""
try:
if self.model is None or self.tokenizer is None:
return """❌ **Model Loading Error**
The OncoLife assistant model failed to load. This could be due to:
1. Model not available
2. Memory constraints
3. Network issues
Please check the Space logs for details."""
print(f"πŸ”„ Generating OncoLife response for: {user_input}")
# Retrieve relevant documents using RAG (if available)
context_text = ""
if self.rag_enabled:
try:
relevant_docs = self._retrieve_relevant_documents(user_input, top_k=2)
if relevant_docs:
context_text = "\n\n**Relevant Reference Information:**\n"
for i, doc in enumerate(relevant_docs):
context_text += f"\n--- Source: {doc['source']} ---\n{doc['content'][:300]}...\n"
except Exception as e:
print(f"⚠️ RAG retrieval failed: {e}")
# Create prompt using the loaded instructions and retrieved context
system_prompt = f"""You are the OncoLife Symptom & Triage Assistant. Follow these instructions exactly:
{self.instructions}
{context_text}
Current user input: {user_input}"""
# Format conversation history
history_text = ""
if conversation_history:
for entry in conversation_history:
history_text += f"User: {entry['user']}\nAssistant: {entry['assistant']}\n\n"
# Create full prompt
prompt = f"{system_prompt}\n\nConversation History:\n{history_text}\nUser: {user_input}\nAssistant:"
# Tokenize
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
# Get the device the model is actually on
model_device = next(self.model.parameters()).device
print(f"πŸ”§ Model device: {model_device}")
# Move inputs to the same device as the model
for key in inputs:
if isinstance(inputs[key], torch.Tensor):
inputs[key] = inputs[key].to(model_device)
print(f"πŸ“¦ Moved {key} to {model_device}")
# Ensure model is in eval mode
self.model.eval()
# Generate with proper device handling
with torch.no_grad():
try:
outputs = self.model.generate(
**inputs,
max_new_tokens=512, # Longer responses for detailed medical assessment
temperature=0.7,
do_sample=True,
top_p=0.9,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
except RuntimeError as e:
if "device" in str(e).lower():
print("πŸ”„ Device error detected, trying CPU fallback...")
# Move everything to CPU and try again
self.model = self.model.to("cpu")
for key in inputs:
if isinstance(inputs[key], torch.Tensor):
inputs[key] = inputs[key].to("cpu")
outputs = self.model.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
else:
raise e
# Decode response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract just the assistant's response
if "Assistant:" in response:
answer = response.split("Assistant:")[-1].strip()
else:
answer = response.strip()
print("βœ… OncoLife response generated successfully")
return answer
except Exception as e:
print(f"❌ Error generating OncoLife response: {e}")
return f"""❌ **Generation Error**
Error: {str(e)}
This could be due to:
1. Model compatibility issues
2. Memory constraints
3. Input format problems
Please try a simpler question or check the logs for more details."""
def chat(self, message, history):
"""Main chat interface for OncoLife Assistant"""
if not message.strip():
return "Please describe your symptoms or concerns."
# Convert history to the format expected by generate_oncolife_response
conversation_history = []
for user_msg, assistant_msg in history:
conversation_history.append({
"user": user_msg,
"assistant": assistant_msg
})
# Generate response using OncoLife instructions and optional RAG
response = self.generate_oncolife_response(message, conversation_history)
return response
# Create interface
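# Instantiated at module level so the model and document index are built once at startup,
# not on the first chat request.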
assistant = OncoLifeAssistant()
interface = gr.ChatInterface(
fn=assistant.chat,
title="πŸ₯ OncoLife Symptom & Triage Assistant",
description="I'm here to help assess your symptoms and determine if you need to contact your care team. I can access your medical guidelines and reference documents to provide accurate information.",
examples=[
["I'm feeling nauseous and tired"],
["I have a fever of 101"],
["My neuropathy is getting worse"],
["I'm having trouble eating"],
["I feel dizzy and lightheaded"]
],
theme=gr.themes.Soft()
)
if __name__ == "__main__":
print("=" * 60)
print("OncoLife Symptom & Triage Assistant")
print("=" * 60)
interface.launch(server_name="0.0.0.0", server_port=7860)