# app.py
import os
import json
import gc
import time
import asyncio
import logging
import traceback
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Dict, Union, Tuple, Optional, Any

import numpy as np
import pandas as pd
import requests
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec

import keras
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

import faiss
import wandb
import pytest
from unittest.mock import Mock, patch
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm

import gradio as gr
from datasets import Dataset, load_dataset, concatenate_datasets
from dotenv import load_dotenv
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
    PeftModel,
)
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from google.colab import output
import IPython.display as display

# Module-level logger used by the Gradio callbacks below
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Ensure Hugging Face login
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Login successful!")
except Exception as e:
    print("Hugging Face Login failed:", e)

# CUDA and Memory Configurations
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


def prepare_initial_datasets(batch_size=8):
    print("Loading datasets with memory-optimized batch processing...")

    def process_medqa_batch(examples):
        results = []
        inputs = examples['input']
        instructions = examples['instruction']
        outputs = examples['output']
        for inp, inst, out in zip(inputs, instructions, outputs):
            results.append({
                "input": f"{inp} {inst}",
                "output": out
            })
        return results

    def process_meddia_batch(examples):
        results = []
        inputs = examples['input']
        outputs = examples['output']
        for inp, out in zip(inputs, outputs):
            results.append({
                "input": inp,
                "output": out
            })
        return results

    def process_persona_batch(examples):
        results = []
        personalities = examples['personality']
        utterances = examples['utterances']
        for pers, utts in zip(personalities, utterances):
            try:
                # Process personality list
                personality = ' '.join([p for p in pers if isinstance(p, str)])

                # Process utterances
                if utts and len(utts) > 0:
                    utterance = utts[0]
                    history = []

                    # Process history
                    if 'history' in utterance and utterance['history']:
                        history = [h for h in utterance['history'] if isinstance(h, str)]
                    history_text = ' '.join(history)

                    # Get candidate response
                    candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''

                    if personality or history_text:
                        results.append({
                            "input": f"{personality} {history_text}".strip(),
                            "output": candidate
                        })
            except Exception as e:
                print(f"Error processing persona batch item: {e}")
                continue
        return results

    # Load and process each dataset separately
print("Processing MedQA dataset...") medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]") medqa_processed = [] for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"): batch = medqa[i:i + batch_size] medqa_processed.extend(process_medqa_batch(batch)) if i % (batch_size * 5) == 0: torch.cuda.empty_cache() print("Processing MedDiagnosis dataset...") meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]") meddia_processed = [] for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"): batch = meddia[i:i + batch_size] meddia_processed.extend(process_meddia_batch(batch)) if i % (batch_size * 5) == 0: torch.cuda.empty_cache() print("Processing Persona-Chat dataset...") persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]") persona_processed = [] for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"): batch = persona[i:i + batch_size] persona_processed.extend(process_persona_batch(batch)) if i % (batch_size * 5) == 0: torch.cuda.empty_cache() torch.cuda.empty_cache() print("Creating final dataset...") all_processed = persona_processed + medqa_processed + meddia_processed valid_data = { "input": [], "output": [] } for item in all_processed: if item["input"].strip() and item["output"].strip(): valid_data["input"].append(item["input"]) valid_data["output"].append(item["output"]) final_dataset = Dataset.from_dict(valid_data) print(f"Final dataset size: {len(final_dataset)}") return final_dataset def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4): def tokenize_batch(examples): formatted_texts = [] for i in range(0, len(examples['input']), batch_size): sub_batch_inputs = examples['input'][i:i + batch_size] sub_batch_outputs = examples['output'][i:i + batch_size] for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs): try: formatted_text = f"""user {input_text} assistant {output_text} """ formatted_texts.append(formatted_text) except Exception as e: print(f"Error formatting text: {e}") continue tokenized = tokenizer( formatted_texts, padding="max_length", truncation=True, max_length=max_length, return_tensors=None ) tokenized["labels"] = tokenized["input_ids"].copy() return tokenized print(f"Tokenizing dataset in small batches (size={batch_size})...") tokenized_dataset = dataset.map( tokenize_batch, batched=True, batch_size=batch_size, remove_columns=dataset.column_names, desc="Tokenizing dataset", load_from_cache_file=False ) return tokenized_dataset def setup_model_and_tokenizer(model_name="google/gemma-2b"): tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token from transformers import BitsAndBytesConfig bnb_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16, llm_int8_enable_fp32_cpu_offload=True ) model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", quantization_config=bnb_config, torch_dtype=torch.float16, low_cpu_mem_usage=True ) model = prepare_model_for_kbit_training(model) lora_config = LoraConfig( r=4, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() return model, tokenizer def setup_training_arguments(output_dir="./pearly_fine_tuned"): return TrainingArguments( output_dir=output_dir, num_train_epochs=1, per_device_train_batch_size=1, gradient_accumulation_steps=16, warmup_steps=50, 
        logging_steps=10,
        save_steps=200,
        learning_rate=2e-4,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_8bit",
        max_grad_norm=0.3,
        weight_decay=0.001,
        logging_dir="./logs",
        save_total_limit=2,
        remove_unused_columns=False,
        dataloader_pin_memory=False,
        max_steps=500,
        report_to=["none"],
    )


def main():
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    print("Preparing initial datasets...")
    combined_dataset = prepare_initial_datasets(batch_size=4)

    print(f"\nDataset size: {len(combined_dataset)}")
    print(f"Column names: {combined_dataset.column_names}")

    if len(combined_dataset) > 0:
        print("\nSample input-output pair:")
        print(f"Input: {combined_dataset[0]['input'][:100]}...")
        print(f"Output: {combined_dataset[0]['output'][:100]}...")

    print("\nSetting up model and tokenizer...")
    model, tokenizer = setup_model_and_tokenizer()

    print("\nPreparing dataset for training...")
    processed_dataset = prepare_dataset(
        combined_dataset,
        tokenizer,
        max_length=256,
        batch_size=2
    )

    torch.cuda.empty_cache()

    training_args = setup_training_arguments()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,
        tokenizer=tokenizer,
    )

    print("\nStarting training...")
    try:
        trainer.train()
    except Exception as e:
        print(f"Training error: {e}")
        torch.cuda.empty_cache()
        raise e
    finally:
        torch.cuda.empty_cache()

    print("\nSaving model...")
    trainer.save_model()
    print("Training completed!")


DISCLAIMER = """
IMPORTANT MEDICAL DISCLAIMER:
Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.

Pearly DOES NOT:
- Make medical diagnoses
- Prescribe medications
- Provide specific treatment recommendations
- Replace professional medical advice

Always consult qualified healthcare professionals for medical advice and treatment.
In case of emergency, call 999 immediately.
""" class PearlyBot: def __init__(self, model_path="./pearly_fine_tuned", embedding_model="sentence-transformers/all-MiniLM-L6-v2"): print("Loading saved model...") print(DISCLAIMER) # Clean memory if torch.cuda.is_available(): torch.cuda.empty_cache() # Load tokenizer and model directly from saved path self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto" ) self.model.eval() # Set to evaluation mode # Initialize RAG components self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model) self.vector_store = None self.conversation_history = [] def initialize_rag(self, documents_path="./knowledge_base"): """Initialize RAG system""" print("Loading knowledge base...") text_splitter = RecursiveCharacterTextSplitter( chunk_size=300, chunk_overlap=100, separators=["\n\n", "\n", ".", "!", "?", ":"] ) documents = [] for filename in os.listdir(documents_path): if filename.endswith('.txt'): loader = TextLoader(os.path.join(documents_path, filename)) documents.extend(loader.load()) texts = text_splitter.split_documents(documents) self.vector_store = FAISS.from_documents(texts, self.embeddings) self.retriever = self.vector_store.as_retriever( search_type="similarity", search_kwargs={"k": 5} ) print("Knowledge base loaded successfully!") def get_relevant_context(self, user_input): if not self.retriever: return "" docs = self.retriever.get_relevant_documents(user_input) return "\n\n".join([doc.page_content for doc in docs]) def generate_response(self, user_input): context = self.get_relevant_context(user_input) history = "\n".join([ f"User: {turn['user']}\nAssistant: {turn['assistant']}\n" for turn in self.conversation_history[-3:] ]) prompt = f"""system As Pearly, I use the following medical guidelines to help triage patients: {context} Previous Conversation: {history} Based on these guidelines, I will: 1. Assess symptoms and severity 2. Ask relevant follow-up questions 3. Direct to appropriate care (999, 111, or GP) 4. Show empathy and cultural sensitivity 5. 
5. Never diagnose or recommend treatments<end_of_turn>
<start_of_turn>user
{user_input}<end_of_turn>
<start_of_turn>assistant"""

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                min_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("assistant")[-1].strip()
        if "<end_of_turn>" in response:
            response = response.split("<end_of_turn>")[0].strip()

        self.conversation_history.append({
            "user": user_input,
            "assistant": response
        })
        return response


def create_demo():
    """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
    try:
        bot = PearlyBot()

        def chat(message: str, history: List[Dict[str, str]]):
            try:
                if not message.strip():
                    return history

                bot_response = bot.generate_response(message)

                # Add user message
                history.append({
                    "role": "user",
                    "content": message
                })
                # Add bot response
                history.append({
                    "role": "assistant",
                    "content": bot_response
                })
                return history
            except Exception as e:
                logger.error(f"Chat error: {e}")
                history.append({
                    "role": "user",
                    "content": message
                })
                history.append({
                    "role": "assistant",
                    "content": "I apologize, but I'm experiencing technical difficulties. For emergencies, please call 999."
                })
                return history

        def process_feedback(is_positive: bool, comment: str, history: List[Dict[str, str]]):
            try:
                if not history:
                    return gr.update(value="")

                last_interaction = history[-2:]  # Get last user message and bot response
                if len(last_interaction) == 2:
                    user_msg = last_interaction[0]["content"]
                    bot_msg = last_interaction[1]["content"]

                    feedback_data = {
                        "user_message": user_msg,
                        "bot_response": bot_msg,
                        "feedback": 1 if is_positive else -1,
                        "comment": comment,
                        "timestamp": datetime.now().isoformat()
                    }
                    # PearlyBot has no online-feedback hook, so record the structured feedback
                    logger.info(f"Feedback received: {feedback_data}")

                # Clear feedback inputs
                return gr.update(value="")
            except Exception as e:
                logger.error(f"Error processing feedback: {e}")
                return gr.update(value="")

        # Create enhanced Gradio interface
        with gr.Blocks(theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="indigo",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter")
        )) as demo:
            # Emergency Banner
            gr.HTML("""
                <div class="emergency-banner" style="text-align:center; font-weight:bold;">
                    🚨 For medical emergencies, always call 999 immediately 🚨
                </div>
            """)

            # Header Section
            with gr.Row(elem_classes="header"):
                gr.Markdown("""
                # GP Medical Triage Assistant - Pearly
                Welcome to your personal medical triage assistant. I'm here to help assess your symptoms and guide you to appropriate care.
                """)

            # Main Features Grid
            gr.HTML("""
                <div class="features-grid" style="display:flex; gap:1rem; text-align:center;">
                    <div class="feature">🏥<br>GP Appointments</div>
                    <div class="feature">🔍<br>Symptom Assessment</div>
                    <div class="feature">Urgent Care Guide</div>
                    <div class="feature">💊<br>Medical Advice</div>
                </div>
            """)

            # Chat Interface
            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Chatbot(
                        value=[{
                            "role": "assistant",
                            "content": "Hello! I'm Pearly, your GP medical assistant. How can I help you today?"
                        }],
                        height=500,
                        elem_id="chatbot",
                        type="messages",
                        show_label=False
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            label="Your message",
                            placeholder="Type your message here...",
                            lines=2,
                            scale=4,
                            autofocus=True
                        )
                        submit = gr.Button("Send", variant="primary", scale=1)

                with gr.Column(scale=1):
                    # Quick Actions Panel
                    gr.Markdown("### Quick Actions")
                    emergency_btn = gr.Button("🚨 Emergency Info", variant="secondary")
                    nhs_111_btn = gr.Button("📞 NHS 111 Info", variant="secondary")
                    booking_btn = gr.Button("📅 GP Booking", variant="secondary")

                    # Conversation Controls
                    gr.Markdown("### Controls")
                    clear = gr.Button("🗑️ Clear Chat")

                    # Feedback Section
                    gr.Markdown("### Feedback")
                    with gr.Row():
                        feedback_positive = gr.Button("👍", elem_id="thumb-up")
                        feedback_negative = gr.Button("👎", elem_id="thumb-down")
                    feedback_text = gr.Textbox(
                        label="Additional comments (optional)",
                        placeholder="Tell us more about your experience...",
                        lines=2,
                        visible=True
                    )
                    feedback_submit = gr.Button("Submit Feedback", visible=True)

            # Examples Section
            with gr.Accordion("Example Messages", open=False):
                gr.Examples(
                    examples=[
                        ["I've been having severe headaches for the past week"],
                        ["I need to book a routine checkup"],
                        ["I'm feeling very anxious lately and need help"],
                        ["My child has had a fever for 2 days"],
                        ["I need information about COVID-19 testing"]
                    ],
                    inputs=msg
                )

            # Information Accordions
            with gr.Accordion("NHS Services Guide", open=False):
                gr.Markdown("""
                ### Emergency Services (999)
                - Life-threatening emergencies
                - Severe injuries
                - Suspected heart attack or stroke

                ### NHS 111
                - Urgent but non-emergency situations
                - Medical advice needed
                - Unsure where to go

                ### GP Services
                - Routine check-ups
                - Non-urgent medical issues
                - Prescription renewals
                """)

            # Event Handlers

            # Message submission handlers
            msg.submit(
                chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: gr.update(value=""),
                None,
                [msg]
            )
            submit.click(
                chat,
                inputs=[msg, chatbot],
                outputs=[chatbot]
            ).then(
                lambda: gr.update(value=""),
                None,
                [msg]
            )

            # Clear-chat handler
            clear.click(lambda: [], None, [chatbot])

            # Feedback handlers
            feedback_positive.click(
                lambda history, comment: process_feedback(True, comment, history),
                inputs=[chatbot, feedback_text],
                outputs=[feedback_text]
            )
            feedback_negative.click(
                lambda history, comment: process_feedback(False, comment, history),
                inputs=[chatbot, feedback_text],
                outputs=[feedback_text]
            )
            feedback_submit.click(
                lambda: gr.update(value=""),
                outputs=[feedback_text]
            )

            # Quick Action Button Handlers
            def show_emergency_info():
                return """🚨 Emergency Services (999)
- For life-threatening emergencies
- Severe chest pain
- Difficulty breathing
- Severe bleeding
- Loss of consciousness
"""

            def show_nhs_111_info():
                return """📞 NHS 111 Service
- Available 24/7
- Medical advice
- Local service information
- Urgent care guidance
"""

            def show_booking_info():
                return """📅 GP Booking Options
- Online booking
- Phone booking
- Routine appointments
- Urgent appointments
"""

            emergency_btn.click(show_emergency_info, outputs=[msg])
            nhs_111_btn.click(show_nhs_111_info, outputs=[msg])
            booking_btn.click(show_booking_info, outputs=[msg])

        return demo
    except Exception as e:
        logger.error(f"Error creating demo: {e}")
        raise


if __name__ == "__main__":
    load_dotenv()  # Load environment variables
    demo = create_demo()
    # Launch the Gradio app
    demo.launch(share=True)