PearlIsa commited on
Commit
ce580ca
·
verified ·
1 Parent(s): da25227

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +452 -240
app.py CHANGED
@@ -1,251 +1,472 @@
1
  # app.py
2
  import os
3
- import logging
 
 
 
 
4
  import torch
5
- from typing import Dict, List, Any
6
- import gradio as gr
7
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
8
  from sentence_transformers import SentenceTransformer
9
- from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
10
  import faiss
11
  import numpy as np
12
- from datasets import load_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from datetime import datetime
14
- import json
15
- from huggingface_hub import login
16
- from dotenv import load_dotenv
17
-
18
- # Quick connectivity test to Hugging Face
 
 
 
 
 
 
 
 
 
 
 
19
  try:
20
- import huggingface_hub
21
- whoami = huggingface_hub.whoami()
22
- print("Hugging Face Login successful:", whoami)
 
23
  except Exception as e:
24
  print("Hugging Face Login failed:", e)
25
 
 
 
 
 
 
 
26
 
27
- # Load environment variables
28
- load_dotenv()
29
-
30
- # Suppress TensorFlow INFO and WARNING logs
31
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
32
- os.environ.pop("HF_HUB_OFFLINE", None)
33
-
34
- # Configure logging
35
- logging.basicConfig(
36
- level=logging.INFO,
37
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
38
- handlers=[
39
- logging.StreamHandler(),
40
- logging.FileHandler('app.log')
41
- ]
42
- )
43
-
44
- # Create a logger for your application
45
- logger = logging.getLogger('medical_bot')
46
-
47
- # Retrieve secrets securely from environment variables
48
- hf_token = os.getenv("HF_TOKEN")
49
- if hf_token:
50
- login(token=hf_token)
51
-
52
-
53
-
54
- class AdaptiveMedicalBot:
55
- def __init__(self):
56
- self.config = self.AdaptiveBotConfig()
57
- self.setup_models()
58
- self.load_datasets()
59
- self.setup_adaptive_learning()
60
- self.conversation_history = [] # Store conversation history
61
- self.symptom_tracker = {} # Track symptoms and severity across conversation
62
-
63
- class AdaptiveBotConfig:
64
- MODEL_NAME = "google/gemma-7b"
65
- EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
66
- LORA_R = 8
67
- LORA_ALPHA = 16
68
- LORA_DROPOUT = 0.1
69
- LORA_TARGET_MODULES = ["q_proj", "v_proj"]
70
- MAX_LENGTH = 512
71
- BATCH_SIZE = 1
72
- LEARNING_RATE = 1e-4
73
-
74
- def setup_adaptive_learning(self):
75
- """Initialize adaptive learning components"""
76
- self.feedback_history = []
77
-
78
- def setup_models(self):
79
- """Initialize models with LoRA and quantization"""
80
- try:
81
- # Add explicit device check
82
- device = "cuda" if torch.cuda.is_available() else "cpu"
83
- logger.info(f"Using device: {device}")
 
 
 
 
 
 
 
 
 
84
 
85
- bnb_config = BitsAndBytesConfig(
86
- load_in_4bit=True,
87
- bnb_4bit_quant_type="nf4",
88
- bnb_4bit_compute_dtype=torch.float16,
89
- bnb_4bit_use_double_quant=True
90
- )
91
 
92
- # Add loading message
93
- logger.info("Loading tokenizer...")
94
- self.tokenizer = AutoTokenizer.from_pretrained(
95
- self.config.MODEL_NAME,
96
- use_auth_token=True,
97
- trust_remote_code=True
98
- )
99
 
100
- logger.info("Loading base model...")
101
- base_model = AutoModelForCausalLM.from_pretrained(
102
- self.config.MODEL_NAME,
103
- quantization_config=bnb_config,
104
- device_map="auto",
105
- trust_remote_code=True,
106
- use_auth_token=True
107
- )
108
 
109
- logger.info("Preparing model for k-bit training...")
110
- base_model = prepare_model_for_kbit_training(base_model)
111
-
112
- logger.info("Configuring LoRA...")
113
- lora_config = LoraConfig(
114
- r=self.config.LORA_R,
115
- lora_alpha=self.config.LORA_ALPHA,
116
- target_modules=self.config.LORA_TARGET_MODULES,
117
- lora_dropout=self.config.LORA_DROPOUT,
118
- bias="none",
119
- task_type=TaskType.CAUSAL_LM
120
- )
121
 
122
- logger.info("Applying PEFT configuration...")
123
- self.model = get_peft_model(base_model, lora_config)
124
-
125
- logger.info("Loading embedding model...")
126
- self.embedding_model = SentenceTransformer(
127
- self.config.EMBEDDING_MODEL,
128
- device=device
129
- )
130
-
131
- logger.info("Models loaded successfully!")
132
-
133
- except Exception as e:
134
- logger.error(f"Error setting up models: {e}")
135
- raise
136
-
137
- def load_datasets(self):
138
- """Load and prepare datasets for RAG"""
139
- try:
140
- datasets = {
141
- "medqa": load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]"),
142
- "diagnosis": load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]"),
143
- "persona": load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
144
- }
145
-
146
- self.documents = []
147
- for dataset_name, dataset in datasets.items():
148
- for item in dataset:
149
- if dataset_name == "persona":
150
- if isinstance(item.get('personality'), list):
151
- self.documents.append({'text': " ".join(item['personality']), 'type': 'persona'})
152
- else:
153
- if 'input' in item and 'output' in item:
154
- self.documents.append({'text': f"{item['input']}\n{item['output']}", 'type': dataset_name})
155
-
156
- self._create_index()
157
- except Exception as e:
158
- logger.error(f"Error loading datasets: {e}")
159
- raise
160
-
161
- def _create_index(self):
162
- """Create FAISS index for RAG"""
163
- try:
164
- sample_embedding = self.embedding_model.encode("sample text")
165
- self.index = faiss.IndexFlatIP(sample_embedding.shape[0])
166
-
167
- embeddings = [self.embedding_model.encode(doc['text']) for doc in self.documents]
168
- self.index.add(np.array(embeddings))
169
- except Exception as e:
170
- logger.error(f"Error creating FAISS index: {e}")
171
- raise
172
-
173
- def generate_follow_up_questions(self, message: str, context: Dict[str, Any]) -> List[str]:
174
- """Generate follow-up questions based on context"""
175
- try:
176
- prompt = f"""Patient message: "{message}"
177
- Generate relevant follow-up questions focusing on timing, severity, associated symptoms, and impact on daily life.
178
- Questions:"""
179
-
180
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.config.MAX_LENGTH).to(self.model.device)
181
- outputs = self.model.generate(inputs['input_ids'], max_new_tokens=50, temperature=0.7, do_sample=True)
182
- questions = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
183
- return questions.split("\n")
184
- except Exception as e:
185
- logger.error(f"Error generating follow-up questions: {e}")
186
- return ["Could you tell me more about when this started?"]
187
-
188
- def assess_symptom_severity(self, message: str) -> str:
189
- """Assess severity based on keywords in the message"""
190
- if "severe" in message.lower() or "emergency" in message.lower():
191
- return "emergency"
192
- elif "persistent" in message.lower() or "moderate" in message.lower():
193
- return "urgent"
194
- return "routine"
195
-
196
- def generate_response(self, message: str) -> Dict[str, Any]:
197
- """Generate a response based on the message"""
198
- try:
199
- severity = self.assess_symptom_severity(message)
200
- response = ""
201
-
202
- # Retrieve relevant documents from FAISS
203
- query_embedding = self.embedding_model.encode([message])
204
- _, indices = self.index.search(query_embedding, k=5)
205
- relevant_docs = [self.documents[idx]['text'] for idx in indices[0]]
206
-
207
- prompt = f"""As a compassionate medical assistant, analyze the patient message: "{message}".
208
- Consider relevant knowledge and the following documents:\n{relevant_docs}.
209
- Respond with empathy, follow-up questions, and care guidance."""
210
-
211
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.config.MAX_LENGTH).to(self.model.device)
212
- outputs = self.model.generate(inputs['input_ids'], max_new_tokens=100, temperature=0.7, do_sample=True)
213
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
214
-
215
- follow_ups = self.generate_follow_up_questions(message, {})
216
- response += f"\n{follow_ups[0]}"
217
 
218
- # Append response to conversation history
219
- self.conversation_history.append((message, response))
220
-
221
- # Add care level guidance
222
- if severity == "emergency":
223
- response += "\nThis seems urgent. Please call 999 immediately."
224
- elif severity == "urgent":
225
- response += "\nConsider calling NHS 111 for urgent assistance."
226
-
227
- return {'response': response}
228
- except Exception as e:
229
- logger.error(f"Error generating response: {e}")
230
- return {
231
- 'response': "I'm experiencing technical issues. If this is an emergency, please call 999 immediately.",
232
- }
233
-
234
- def handle_feedback(self, message: str, response: str, feedback: int):
235
- """Update model based on feedback"""
236
- try:
237
- self.feedback_history.append({
238
- 'message': message,
239
- 'response': response,
240
- 'feedback': feedback,
241
- 'timestamp': datetime.now().isoformat()
242
- })
 
 
 
 
 
 
 
 
 
 
243
 
244
- if len(self.feedback_history) >= 10:
245
- # Implement learning updates from feedback
246
- self.feedback_history = [] # Reset history after learning update
247
- except Exception as e:
248
- logger.error(f"Error processing feedback: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  def create_demo():
251
  """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
@@ -607,17 +828,8 @@ def create_demo():
607
  raise
608
 
609
  if __name__ == "__main__":
610
- # Initialize environment
611
- load_dotenv()
612
-
613
- # Set up Hugging Face login if token exists
614
- hf_token = os.getenv("HF_TOKEN")
615
- if hf_token:
616
- login(token=hf_token)
617
-
618
- # Launch demo
619
- os.environ.pop("HF_HUB_OFFLINE", None) # Ensure online mode
620
- demo = create_demo()
621
  demo.launch(share=True)
622
 
623
 
 
1
  # app.py
2
  import os
3
+ import json
4
+ import keras
5
+ from datasets import load_dataset
6
+ import tensorflow as tf
7
+ from huggingface_hub import login
8
  import torch
9
+ from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer)
 
 
10
  from sentence_transformers import SentenceTransformer
11
+ from typing import List, Dict, Union, Tuple
12
  import faiss
13
  import numpy as np
14
+ from datasets import Dataset
15
+ import torch.nn.functional as F
16
+ from torch.cuda.amp import autocast
17
+ import gc
18
+ from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel)
19
+ from tqdm.auto import tqdm
20
+ from torch.utils.data import DataLoader
21
+ import logging
22
+ import wandb
23
+ from pathlib import Path
24
+ from typing import List, Dict, Union, Optional, Any
25
+ import torch.nn as nn
26
+ from dataclasses import dataclass, field
27
+ import time
28
+ import asyncio
29
+ import pytest
30
+ from unittest.mock import Mock, patch
31
+ from sklearn.metrics import classification_report, confusion_matrix
32
+ import gradio as gr
33
+ import matplotlib.pyplot as plt
34
  from datetime import datetime
35
+ import requests
36
+ import pandas as pd
37
+ import seaborn as sns
38
+ import traceback
39
+ from matplotlib.gridspec import GridSpec
40
+ from datasets import load_dataset, concatenate_datasets
41
+ from langchain.vectorstores import FAISS
42
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
43
+ from langchain.embeddings import HuggingFaceEmbeddings
44
+ from langchain.document_loaders import TextLoader
45
+ from google.colab import output
46
+ import IPython.display as display
47
+ from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
48
+
49
+
50
+ # Ensure Hugging Face login
51
  try:
52
+ hf_token = os.getenv("HF_TOKEN")
53
+ if hf_token:
54
+ login(token=hf_token)
55
+ print("Login successful!")
56
  except Exception as e:
57
  print("Hugging Face Login failed:", e)
58
 
59
+ # CUDA and Memory Configurations
60
+ torch.backends.cuda.matmul.allow_tf32 = False
61
+ torch.backends.cudnn.allow_tf32 = False
62
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64,garbage_collection_threshold:0.8,expandable_segments:True'
63
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
64
+
65
 
66
+ def prepare_initial_datasets(batch_size=8):
67
+ print("Loading datasets with memory-optimized batch processing...")
68
+
69
+ def process_medqa_batch(examples):
70
+ results = []
71
+ inputs = examples['input']
72
+ instructions = examples['instruction']
73
+ outputs = examples['output']
74
+
75
+ for inp, inst, out in zip(inputs, instructions, outputs):
76
+ results.append({
77
+ "input": f"{inp} {inst}",
78
+ "output": out
79
+ })
80
+ return results
81
+
82
+ def process_meddia_batch(examples):
83
+ results = []
84
+ inputs = examples['input']
85
+ outputs = examples['output']
86
+
87
+ for inp, out in zip(inputs, outputs):
88
+ results.append({
89
+ "input": inp,
90
+ "output": out
91
+ })
92
+ return results
93
+
94
+ def process_persona_batch(examples):
95
+ results = []
96
+ personalities = examples['personality']
97
+ utterances = examples['utterances']
98
+
99
+ for pers, utts in zip(personalities, utterances):
100
+ try:
101
+ # Process personality list
102
+ personality = ' '.join([
103
+ p for p in pers
104
+ if isinstance(p, str)
105
+ ])
106
+
107
+ # Process utterances
108
+ if utts and len(utts) > 0:
109
+ utterance = utts[0]
110
+ history = []
111
+
112
+ # Process history
113
+ if 'history' in utterance and utterance['history']:
114
+ history = [
115
+ h for h in utterance['history']
116
+ if isinstance(h, str)
117
+ ]
118
+
119
+ history_text = ' '.join(history)
120
+
121
+ # Get candidate response
122
+ candidate = utterance.get('candidates', [''])[0] if utterance.get('candidates') else ''
123
+
124
+ if personality or history_text:
125
+ results.append({
126
+ "input": f"{personality} {history_text}".strip(),
127
+ "output": candidate
128
+ })
129
+ except Exception as e:
130
+ print(f"Error processing persona batch item: {e}")
131
+ continue
132
 
133
+ return results
134
+
135
+ # Load and process each dataset separately
136
+ print("Processing MedQA dataset...")
137
+ medqa = load_dataset("medalpaca/medical_meadow_medqa", split="train[:500]")
138
+ medqa_processed = []
139
 
140
+ for i in tqdm(range(0, len(medqa), batch_size), desc="Processing MedQA"):
141
+ batch = medqa[i:i + batch_size]
142
+ medqa_processed.extend(process_medqa_batch(batch))
143
+ if i % (batch_size * 5) == 0:
144
+ torch.cuda.empty_cache()
 
 
145
 
146
+ print("Processing MedDiagnosis dataset...")
147
+ meddia = load_dataset("wasiqnauman/medical-diagnosis-synthetic", split="train[:500]")
148
+ meddia_processed = []
 
 
 
 
 
149
 
150
+ for i in tqdm(range(0, len(meddia), batch_size), desc="Processing MedDiagnosis"):
151
+ batch = meddia[i:i + batch_size]
152
+ meddia_processed.extend(process_meddia_batch(batch))
153
+ if i % (batch_size * 5) == 0:
154
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
155
 
156
+ print("Processing Persona-Chat dataset...")
157
+ persona = load_dataset("AlekseyKorshuk/persona-chat", split="train[:500]")
158
+ persona_processed = []
159
+
160
+ for i in tqdm(range(0, len(persona), batch_size), desc="Processing Persona-Chat"):
161
+ batch = persona[i:i + batch_size]
162
+ persona_processed.extend(process_persona_batch(batch))
163
+ if i % (batch_size * 5) == 0:
164
+ torch.cuda.empty_cache()
165
+
166
+ torch.cuda.empty_cache()
167
+
168
+ print("Creating final dataset...")
169
+ all_processed = persona_processed + medqa_processed + meddia_processed
170
+
171
+ valid_data = {
172
+ "input": [],
173
+ "output": []
174
+ }
175
+
176
+ for item in all_processed:
177
+ if item["input"].strip() and item["output"].strip():
178
+ valid_data["input"].append(item["input"])
179
+ valid_data["output"].append(item["output"])
180
+
181
+ final_dataset = Dataset.from_dict(valid_data)
182
+
183
+ print(f"Final dataset size: {len(final_dataset)}")
184
+ return final_dataset
185
+
186
+ def prepare_dataset(dataset, tokenizer, max_length=256, batch_size=4):
187
+ def tokenize_batch(examples):
188
+ formatted_texts = []
189
+
190
+ for i in range(0, len(examples['input']), batch_size):
191
+ sub_batch_inputs = examples['input'][i:i + batch_size]
192
+ sub_batch_outputs = examples['output'][i:i + batch_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ for input_text, output_text in zip(sub_batch_inputs, sub_batch_outputs):
195
+ try:
196
+ formatted_text = f"""<start_of_turn>user
197
+ {input_text}
198
+ <end_of_turn>
199
+ <start_of_turn>assistant
200
+ {output_text}
201
+ <end_of_turn>"""
202
+ formatted_texts.append(formatted_text)
203
+ except Exception as e:
204
+ print(f"Error formatting text: {e}")
205
+ continue
206
+
207
+ tokenized = tokenizer(
208
+ formatted_texts,
209
+ padding="max_length",
210
+ truncation=True,
211
+ max_length=max_length,
212
+ return_tensors=None
213
+ )
214
+
215
+ tokenized["labels"] = tokenized["input_ids"].copy()
216
+ return tokenized
217
+
218
+ print(f"Tokenizing dataset in small batches (size={batch_size})...")
219
+ tokenized_dataset = dataset.map(
220
+ tokenize_batch,
221
+ batched=True,
222
+ batch_size=batch_size,
223
+ remove_columns=dataset.column_names,
224
+ desc="Tokenizing dataset",
225
+ load_from_cache_file=False
226
+ )
227
+
228
+ return tokenized_dataset
229
 
230
+ def setup_model_and_tokenizer(model_name="google/gemma-2b"):
231
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
232
+ tokenizer.pad_token = tokenizer.eos_token
233
+
234
+ from transformers import BitsAndBytesConfig
235
+
236
+ bnb_config = BitsAndBytesConfig(
237
+ load_in_8bit=True,
238
+ bnb_8bit_compute_dtype=torch.float16,
239
+ llm_int8_enable_fp32_cpu_offload=True
240
+ )
241
+
242
+ model = AutoModelForCausalLM.from_pretrained(
243
+ model_name,
244
+ device_map="auto",
245
+ quantization_config=bnb_config,
246
+ torch_dtype=torch.float16,
247
+ low_cpu_mem_usage=True
248
+ )
249
+
250
+ model = prepare_model_for_kbit_training(model)
251
+
252
+ lora_config = LoraConfig(
253
+ r=4,
254
+ lora_alpha=16,
255
+ target_modules=["q_proj", "v_proj"],
256
+ lora_dropout=0.05,
257
+ bias="none",
258
+ task_type="CAUSAL_LM"
259
+ )
260
+
261
+ model = get_peft_model(model, lora_config)
262
+ model.print_trainable_parameters()
263
+
264
+ return model, tokenizer
265
+
266
+ def setup_training_arguments(output_dir="./pearly_fine_tuned"):
267
+ return TrainingArguments(
268
+ output_dir=output_dir,
269
+ num_train_epochs=1,
270
+ per_device_train_batch_size=1,
271
+ gradient_accumulation_steps=16,
272
+ warmup_steps=50,
273
+ logging_steps=10,
274
+ save_steps=200,
275
+ learning_rate=2e-4,
276
+ fp16=True,
277
+ gradient_checkpointing=True,
278
+ gradient_checkpointing_kwargs={"use_reentrant": False},
279
+ optim="adamw_8bit",
280
+ max_grad_norm=0.3,
281
+ weight_decay=0.001,
282
+ logging_dir="./logs",
283
+ save_total_limit=2,
284
+ remove_unused_columns=False,
285
+ dataloader_pin_memory=False,
286
+ max_steps=500,
287
+ report_to=["none"],
288
+ )
289
+
290
+ def main():
291
+ torch.backends.cuda.matmul.allow_tf32 = False
292
+ torch.backends.cudnn.allow_tf32 = False
293
+
294
+ torch.cuda.empty_cache()
295
+ if torch.cuda.is_available():
296
+ torch.cuda.reset_peak_memory_stats()
297
+
298
+ print("Preparing initial datasets...")
299
+ combined_dataset = prepare_initial_datasets(batch_size=4)
300
+
301
+ print(f"\nDataset size: {len(combined_dataset)}")
302
+ print(f"Column names: {combined_dataset.column_names}")
303
+
304
+ if len(combined_dataset) > 0:
305
+ print("\nSample input-output pair:")
306
+ print(f"Input: {combined_dataset[0]['input'][:100]}...")
307
+ print(f"Output: {combined_dataset[0]['output'][:100]}...")
308
+
309
+ print("\nSetting up model and tokenizer...")
310
+ model, tokenizer = setup_model_and_tokenizer()
311
+
312
+ print("\nPreparing dataset for training...")
313
+ processed_dataset = prepare_dataset(
314
+ combined_dataset,
315
+ tokenizer,
316
+ max_length=256,
317
+ batch_size=2
318
+ )
319
+
320
+ torch.cuda.empty_cache()
321
+
322
+ training_args = setup_training_arguments()
323
+
324
+ trainer = Trainer(
325
+ model=model,
326
+ args=training_args,
327
+ train_dataset=processed_dataset,
328
+ tokenizer=tokenizer,
329
+ )
330
+
331
+ print("\nStarting training...")
332
+ try:
333
+ trainer.train()
334
+ except Exception as e:
335
+ print(f"Training error: {e}")
336
+ torch.cuda.empty_cache()
337
+ raise e
338
+ finally:
339
+ torch.cuda.empty_cache()
340
+
341
+ print("\nSaving model...")
342
+ trainer.save_model()
343
+ print("Training completed!")
344
+
345
+ DISCLAIMER = """
346
+ IMPORTANT MEDICAL DISCLAIMER:
347
+ Pearly is an AI medical triage assistant designed to help direct you to appropriate medical services.
348
+ Pearly DOES NOT:
349
+ - Make medical diagnoses
350
+ - Prescribe medications
351
+ - Provide specific treatment recommendations
352
+ - Replace professional medical advice
353
+
354
+ Always consult qualified healthcare professionals for medical advice and treatment.
355
+ In case of emergency, call 999 immediately.
356
+ """
357
+
358
+ class PearlyBot:
359
+ def __init__(self, model_path="./pearly_fine_tuned", embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
360
+ print("Loading saved model...")
361
+ print(DISCLAIMER)
362
+
363
+ # Clean memory
364
+ if torch.cuda.is_available():
365
+ torch.cuda.empty_cache()
366
+
367
+ # Load tokenizer and model directly from saved path
368
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
369
+ self.model = AutoModelForCausalLM.from_pretrained(
370
+ model_path,
371
+ torch_dtype=torch.float16,
372
+ low_cpu_mem_usage=True,
373
+ device_map="auto"
374
+ )
375
+
376
+ self.model.eval() # Set to evaluation mode
377
+
378
+ # Initialize RAG components
379
+ self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
380
+ self.vector_store = None
381
+ self.conversation_history = []
382
+
383
+ def initialize_rag(self, documents_path="./knowledge_base"):
384
+ """Initialize RAG system"""
385
+ print("Loading knowledge base...")
386
+
387
+ text_splitter = RecursiveCharacterTextSplitter(
388
+ chunk_size=300,
389
+ chunk_overlap=100,
390
+ separators=["\n\n", "\n", ".", "!", "?", ":"]
391
+ )
392
+
393
+ documents = []
394
+ for filename in os.listdir(documents_path):
395
+ if filename.endswith('.txt'):
396
+ loader = TextLoader(os.path.join(documents_path, filename))
397
+ documents.extend(loader.load())
398
+
399
+ texts = text_splitter.split_documents(documents)
400
+ self.vector_store = FAISS.from_documents(texts, self.embeddings)
401
+ self.retriever = self.vector_store.as_retriever(
402
+ search_type="similarity",
403
+ search_kwargs={"k": 5}
404
+ )
405
+ print("Knowledge base loaded successfully!")
406
+
407
+ def get_relevant_context(self, user_input):
408
+ if not self.retriever:
409
+ return ""
410
+ docs = self.retriever.get_relevant_documents(user_input)
411
+ return "\n\n".join([doc.page_content for doc in docs])
412
+
413
+ def generate_response(self, user_input):
414
+ context = self.get_relevant_context(user_input)
415
+ history = "\n".join([
416
+ f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
417
+ for turn in self.conversation_history[-3:]
418
+ ])
419
+
420
+ prompt = f"""<start_of_turn>system
421
+ As Pearly, I use the following medical guidelines to help triage patients:
422
+
423
+ {context}
424
+
425
+ Previous Conversation:
426
+ {history}
427
+
428
+ Based on these guidelines, I will:
429
+ 1. Assess symptoms and severity
430
+ 2. Ask relevant follow-up questions
431
+ 3. Direct to appropriate care (999, 111, or GP)
432
+ 4. Show empathy and cultural sensitivity
433
+ 5. Never diagnose or recommend treatments
434
+ <end_of_turn>
435
+ <start_of_turn>user
436
+ {user_input}
437
+ <end_of_turn>
438
+ <start_of_turn>assistant"""
439
+
440
+ inputs = self.tokenizer(
441
+ prompt,
442
+ return_tensors="pt",
443
+ truncation=True,
444
+ max_length=512
445
+ ).to(self.model.device)
446
+
447
+ with torch.no_grad():
448
+ outputs = self.model.generate(
449
+ **inputs,
450
+ max_new_tokens=256,
451
+ min_new_tokens=20,
452
+ do_sample=True,
453
+ temperature=0.7,
454
+ top_p=0.9,
455
+ repetition_penalty=1.2,
456
+ pad_token_id=self.tokenizer.pad_token_id
457
+ )
458
+
459
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
460
+ response = response.split("<start_of_turn>assistant")[-1].strip()
461
+ if "<end_of_turn>" in response:
462
+ response = response.split("<end_of_turn>")[0].strip()
463
+
464
+ self.conversation_history.append({
465
+ "user": user_input,
466
+ "assistant": response
467
+ })
468
+
469
+ return response
470
 
471
  def create_demo():
472
  """Set up Gradio interface for the chatbot with enhanced styling and functionality."""
 
828
  raise
829
 
830
  if __name__ == "__main__":
831
+ load_dotenv() # Load environment variables
832
+ demo = create_demo() # Launch the Gradio app
 
 
 
 
 
 
 
 
 
833
  demo.launch(share=True)
834
 
835