#!/usr/bin/env python3
"""
OncoLife Symptom & Triage Assistant
A medical chatbot that performs both symptom assessment and clinical triage for chemotherapy patients.
Updated: Using BioMistral-7B base model for medical conversations.
REBUILD: Simplified to use only the base model, no adapters.
RAG: Added document retrieval capabilities for PDFs and other reference materials (optional).
"""
import gradio as gr
import os
import json
from pathlib import Path
from transformers import AutoTokenizer, MistralForCausalLM
import torch
from spaces import GPU
# RAG imports (optional)
try:
    import chromadb
    from sentence_transformers import SentenceTransformer
    import PyPDF2
    import pdfplumber
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.embeddings import HuggingFaceEmbeddings
    import fitz  # PyMuPDF for better PDF handling
    RAG_AVAILABLE = True
except ImportError:
    print("⚠️ RAG libraries not available, running in instruction-only mode")
    RAG_AVAILABLE = False
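# The optional RAG stack above presumes chromadb, sentence-transformers, PyPDF2,
# pdfplumber, langchain and PyMuPDF are available in the Space's environment
# (presumably via requirements.txt). HuggingFaceEmbeddings is imported here but not
# referenced elsewhere in this file.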
# Force GPU detection for HF Spaces
def force_gpu_detection():
    """Force GPU detection for Hugging Face Spaces"""
    return torch.cuda.is_available()
class OncoLifeAssistant:
    def __init__(self):
        # BioMistral base model configuration
        BASE = "BioMistral/BioMistral-7B"
        print("🚀 Initializing OncoLife Symptom & Triage Assistant")
        print(f"📦 Loading base model: {BASE}")
        # Force GPU detection first
        try:
            gpu_available = force_gpu_detection()
            print(f"🖥️ GPU Detection: {gpu_available}")
        except Exception as e:
            print(f"⚠️ GPU detection error: {e}")
            gpu_available = torch.cuda.is_available()
        self._load_model(BASE, gpu_available)
        # Load the OncoLife instructions
        self._load_instructions()
        # Initialize RAG system (optional)
        self.rag_enabled = False
        if RAG_AVAILABLE:
            try:
                self._initialize_rag()
                self.rag_enabled = True
                print("✅ RAG system initialized successfully")
            except Exception as e:
                print(f"⚠️ RAG initialization failed: {e}")
                print("📝 Continuing with instruction-only mode")
        else:
            print("📝 Running in instruction-only mode (no RAG)")
    def _load_instructions(self):
        """Load the OncoLife instructions from the text file"""
        try:
            instructions_file = Path(__file__).parent / "oncolifebot_instructions.txt"
            if instructions_file.exists():
                with open(instructions_file, 'r') as f:
                    self.instructions = f.read()
                print("✅ Loaded oncolifebot_instructions.txt")
            else:
                print("⚠️ oncolifebot_instructions.txt not found")
                self.instructions = ""
        except Exception as e:
            print(f"❌ Error loading instructions: {e}")
            self.instructions = ""
    def _initialize_rag(self):
        """Initialize the RAG system with document embeddings (lightweight version)"""
        try:
            print("🔄 Initializing lightweight RAG system...")
            # Use a smaller embedding model
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            print("✅ Loaded embedding model")
            # Initialize ChromaDB with persistence disabled for memory efficiency
            self.chroma_client = chromadb.Client()
            self.collection = self.chroma_client.create_collection(
                name="oncolife_documents",
                metadata={"description": "OncoLife reference documents"}
            )
            print("✅ Initialized ChromaDB collection")
            # Load and process documents (limited to essential files)
            self._load_documents_lightweight()
        except Exception as e:
            print(f"❌ Error initializing RAG: {e}")
            self.embedding_model = None
            self.collection = None
            raise
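    # chromadb.Client() is an in-memory client: the collection above lives in RAM and is
    # rebuilt from the guideline documents on every startup, matching the
    # "persistence disabled" note above.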
    def _load_documents_lightweight(self):
        """Load only essential documents to save memory"""
        try:
            docs_path = Path(__file__).parent / "guideline-docs"
            print(f"📂 Loading essential documents from: {docs_path}")
            if not docs_path.exists():
                print("⚠️ guideline-docs directory not found")
                return
            # Text splitter for chunking documents
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,  # Smaller chunks to save memory
                chunk_overlap=100,
                separators=["\n\n", "\n", ". ", " ", ""]
            )
            documents_loaded = 0
            # Process PDF files (essential medical guidelines)
            for pdf_file in docs_path.glob("*.pdf"):
                try:
                    print(f"📄 Processing PDF: {pdf_file.name}")
                    text = self._extract_pdf_text(pdf_file)
                    if text:
                        chunks = text_splitter.split_text(text)
                        self._add_chunks_to_db(chunks, pdf_file.name)
                        documents_loaded += 1
                        print(f"✅ Added {len(chunks)} chunks from {pdf_file.name}")
                    else:
                        print(f"⚠️ No text extracted from {pdf_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {pdf_file.name}: {e}")
            # Process JSON files (lightweight)
            for json_file in docs_path.glob("*.json"):
                try:
                    print(f"📋 Processing JSON: {json_file.name}")
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    # Convert JSON to text representation
                    text = json.dumps(data, indent=2)
                    chunks = text_splitter.split_text(text)
                    self._add_chunks_to_db(chunks, json_file.name)
                    documents_loaded += 1
                    print(f"✅ Added {len(chunks)} chunks from {json_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {json_file.name}: {e}")
            # Process text files (lightweight)
            for txt_file in docs_path.glob("*.txt"):
                try:
                    print(f"📝 Processing TXT: {txt_file.name}")
                    with open(txt_file, 'r', encoding='utf-8') as f:
                        text = f.read()
                    chunks = text_splitter.split_text(text)
                    self._add_chunks_to_db(chunks, txt_file.name)
                    documents_loaded += 1
                    print(f"✅ Added {len(chunks)} chunks from {txt_file.name}")
                except Exception as e:
                    print(f"❌ Error processing {txt_file.name}: {e}")
            print(f"✅ RAG system initialized with {documents_loaded} documents")
        except Exception as e:
            print(f"❌ Error loading documents: {e}")
    def _extract_pdf_text(self, pdf_path):
        """Extract text from PDF using multiple methods"""
        try:
            # Try PyMuPDF first (better for complex PDFs)
            try:
                doc = fitz.open(pdf_path)
                text = ""
                for page in doc:
                    text += page.get_text()
                doc.close()
                if text.strip():
                    return text
            except Exception as e:
                print(f"PyMuPDF failed for {pdf_path.name}: {e}")
            # Fallback to pdfplumber
            try:
                with pdfplumber.open(pdf_path) as pdf:
                    text = ""
                    for page in pdf.pages:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                # Only return if pdfplumber actually found text; otherwise try PyPDF2
                if text.strip():
                    return text
            except Exception as e:
                print(f"pdfplumber failed for {pdf_path.name}: {e}")
            # Final fallback to PyPDF2
            try:
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    text = ""
                    for page in reader.pages:
                        text += page.extract_text() + "\n"
                return text
            except Exception as e:
                print(f"PyPDF2 failed for {pdf_path.name}: {e}")
            return None
        except Exception as e:
            print(f"❌ Error extracting text from {pdf_path.name}: {e}")
            return None
    def _add_chunks_to_db(self, chunks, source_name):
        """Add document chunks to the vector database"""
        try:
            if not chunks or not self.collection:
                return
            # Generate embeddings
            embeddings = self.embedding_model.encode(chunks)
            # Add to ChromaDB
            self.collection.add(
                embeddings=embeddings.tolist(),
                documents=chunks,
                metadatas=[{"source": source_name, "chunk_id": i} for i in range(len(chunks))],
                ids=[f"{source_name}_chunk_{i}" for i in range(len(chunks))]
            )
        except Exception as e:
            print(f"❌ Error adding chunks to database: {e}")
    def _retrieve_relevant_documents(self, query, top_k=3):
        """Retrieve relevant document chunks for a query"""
        try:
            if not self.collection or not self.embedding_model or not self.rag_enabled:
                return []
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])
            # Search for similar documents
            results = self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=top_k
            )
            # Format results
            relevant_docs = []
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    relevant_docs.append({
                        'content': doc,
                        'source': results['metadatas'][0][i]['source'],
                        'similarity': results['distances'][0][i] if 'distances' in results else None
                    })
            return relevant_docs
        except Exception as e:
            print(f"❌ Error retrieving documents: {e}")
            return []
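    # Note: ChromaDB's query() reports distances rather than similarities (smaller means
    # a closer match), so the 'similarity' field above actually holds a distance score.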
    def _load_model(self, model_id, gpu_available):
        """Load the BioMistral base model with memory optimization"""
        try:
            print("🚀 Loading BioMistral base model...")
            # Determine device strategy
            if gpu_available and torch.cuda.is_available():
                device = "cuda"
                dtype = torch.float16
                print("🖥️ Loading BioMistral model on GPU...")
            else:
                device = "cpu"
                dtype = torch.float32
                print("💻 Loading BioMistral model on CPU...")
            # Load tokenizer
            print(f"📝 Loading tokenizer: {model_id}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )
            # Load the model with memory optimization
            print(f"📦 Loading model: {model_id}")
            self.model = MistralForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                device_map="auto",
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                # Add memory optimization
                max_memory={0: "8GB", "cpu": "16GB"} if gpu_available else {"cpu": "8GB"}
            )
            # Add pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            print(f"✅ BioMistral base model loaded successfully on {device.upper()}!")
        except Exception as e:
            print(f"❌ Error loading BioMistral model: {e}")
            self.model = None
            self.tokenizer = None
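    # device_map="auto" with max_memory relies on the accelerate package and may offload
    # some layers to CPU when the GPU budget is exhausted, trading generation speed for
    # resilience against out-of-memory errors.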
    def generate_oncolife_response(self, user_input, conversation_history):
        """Generate response using OncoLife instructions and optional RAG"""
        try:
            if self.model is None or self.tokenizer is None:
                return """❌ **Model Loading Error**
The OncoLife assistant model failed to load. This could be due to:
1. Model not available
2. Memory constraints
3. Network issues
Please check the Space logs for details."""
print(f"π Generating OncoLife response for: {user_input}") | |
# Retrieve relevant documents using RAG (if available) | |
context_text = "" | |
if self.rag_enabled: | |
try: | |
relevant_docs = self._retrieve_relevant_documents(user_input, top_k=2) | |
if relevant_docs: | |
context_text = "\n\n**Relevant Reference Information:**\n" | |
for i, doc in enumerate(relevant_docs): | |
context_text += f"\n--- Source: {doc['source']} ---\n{doc['content'][:300]}...\n" | |
except Exception as e: | |
print(f"β οΈ RAG retrieval failed: {e}") | |
# Create prompt using the loaded instructions and retrieved context | |
system_prompt = f"""You are the OncoLife Symptom & Triage Assistant. Follow these instructions exactly: | |
{self.instructions} | |
{context_text} | |
Current user input: {user_input}""" | |
# Format conversation history | |
history_text = "" | |
if conversation_history: | |
for entry in conversation_history: | |
history_text += f"User: {entry['user']}\nAssistant: {entry['assistant']}\n\n" | |
# Create full prompt | |
prompt = f"{system_prompt}\n\nConversation History:\n{history_text}\nUser: {user_input}\nAssistant:" | |
# Tokenize | |
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True) | |
# Get the device the model is actually on | |
model_device = next(self.model.parameters()).device | |
print(f"π§ Model device: {model_device}") | |
# Move inputs to the same device as the model | |
for key in inputs: | |
if isinstance(inputs[key], torch.Tensor): | |
inputs[key] = inputs[key].to(model_device) | |
print(f"π¦ Moved {key} to {model_device}") | |
            # Ensure model is in eval mode
            self.model.eval()
            # Generate with proper device handling
            with torch.no_grad():
                try:
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=512,  # Longer responses for detailed medical assessment
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
                except RuntimeError as e:
                    if "device" in str(e).lower():
                        print("🔄 Device error detected, trying CPU fallback...")
                        # Move everything to CPU and try again
                        self.model = self.model.to("cpu")
                        for key in inputs:
                            if isinstance(inputs[key], torch.Tensor):
                                inputs[key] = inputs[key].to("cpu")
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=512,
                            temperature=0.7,
                            do_sample=True,
                            top_p=0.9,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                    else:
                        raise
            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract just the assistant's response
            if "Assistant:" in response:
                answer = response.split("Assistant:")[-1].strip()
            else:
                answer = response.strip()
            print("✅ OncoLife response generated successfully")
            return answer
        except Exception as e:
            print(f"❌ Error generating OncoLife response: {e}")
            return f"""❌ **Generation Error**
Error: {str(e)}
This could be due to:
1. Model compatibility issues
2. Memory constraints
3. Input format problems
Please try a simpler question or check the logs for more details."""
    def chat(self, message, history):
        """Main chat interface for OncoLife Assistant"""
        if not message.strip():
            return "Please describe your symptoms or concerns."
        # Convert history to the format expected by generate_oncolife_response
        conversation_history = []
        for user_msg, assistant_msg in history:
            conversation_history.append({
                "user": user_msg,
                "assistant": assistant_msg
            })
        # Generate response using OncoLife instructions and optional RAG
        response = self.generate_oncolife_response(message, conversation_history)
        return response
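# Note: OncoLifeAssistant.chat() iterates over history as (user, assistant) pairs, the
# tuple-style format gr.ChatInterface passes by default; newer Gradio versions may
# deliver message dictionaries instead (when type="messages" is used).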
# Create interface
assistant = OncoLifeAssistant()
interface = gr.ChatInterface(
    fn=assistant.chat,
    title="🏥 OncoLife Symptom & Triage Assistant",
    description="I'm here to help assess your symptoms and determine if you need to contact your care team. I can access your medical guidelines and reference documents to provide accurate information.",
    examples=[
        ["I'm feeling nauseous and tired"],
        ["I have a fever of 101"],
        ["My neuropathy is getting worse"],
        ["I'm having trouble eating"],
        ["I feel dizzy and lightheaded"]
    ],
    theme=gr.themes.Soft()
)
if __name__ == "__main__":
    print("=" * 60)
    print("OncoLife Symptom & Triage Assistant")
    print("=" * 60)
    interface.launch(server_name="0.0.0.0", server_port=7860)