Muhammadidrees committed on
Commit 754f306 · verified · 1 Parent(s): 94d2f35

Update app.py

Files changed (1): app.py (+268 -89)

app.py CHANGED
@@ -1,8 +1,12 @@
 import os
 import gc
+import re
+import time
 import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+from collections import defaultdict
+from datetime import datetime, timedelta

 # =============================
 # Configuration
@@ -12,22 +16,46 @@ MAX_NEW_TOKENS = 200
 TEMPERATURE = 0.5
 TOP_K = 50
 REPETITION_PENALTY = 1.1
+MAX_HISTORY_TURNS = 5  # Limit conversation history

 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"🚀 Loading model from {MODEL_PATH} on {device}...")

+# =============================
+# Rate Limiting (Simple IP-based)
+# =============================
+rate_limit_store = defaultdict(list)
+MAX_REQUESTS_PER_MINUTE = 10
+
+def check_rate_limit(session_id):
+    """Simple rate limiting to prevent abuse"""
+    now = datetime.now()
+    rate_limit_store[session_id] = [
+        timestamp for timestamp in rate_limit_store[session_id]
+        if now - timestamp < timedelta(minutes=1)
+    ]
+
+    if len(rate_limit_store[session_id]) >= MAX_REQUESTS_PER_MINUTE:
+        return False
+
+    rate_limit_store[session_id].append(now)
+    return True
+
 # ==========================
 # Load Model & Tokenizer
 # =============================
-tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_PATH,
-    device_map="auto",
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    low_cpu_mem_usage=True
-)
-
-print("✅ ChatDoctor model loaded successfully!\n")
+try:
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        device_map="auto",
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        low_cpu_mem_usage=True
+    )
+    print("✅ ChatDoctor model loaded successfully!\n")
+except Exception as e:
+    print(f"❌ Error loading model: {e}")
+    raise

 # =============================
 # Stop Criteria
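Note: the new check_rate_limit is a sliding-window counter: on each call it prunes a session's timestamps older than one minute, then either rejects the request or records it. A self-contained sanity check of that logic (the "demo" session id is invented for illustration):

# Standalone check of the sliding-window limiter added above.
from collections import defaultdict
from datetime import datetime, timedelta

rate_limit_store = defaultdict(list)
MAX_REQUESTS_PER_MINUTE = 10

def check_rate_limit(session_id):
    now = datetime.now()
    # Drop timestamps older than one minute, then test the remaining budget.
    rate_limit_store[session_id] = [
        t for t in rate_limit_store[session_id]
        if now - t < timedelta(minutes=1)
    ]
    if len(rate_limit_store[session_id]) >= MAX_REQUESTS_PER_MINUTE:
        return False
    rate_limit_store[session_id].append(now)
    return True

print([check_rate_limit("demo") for _ in range(12)])
# -> ten True values, then False, False: calls 11 and 12 within the minute are rejected.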
@@ -57,43 +85,113 @@ MEDICAL_KEYWORDS = [
     "stomach", "head", "chest", "throat", "heart", "lung", "liver", "kidney", "brain",
     "doctor", "hospital", "medicine", "treatment", "therapy", "surgery", "disease",
     "illness", "blood", "test", "scan", "health", "diet", "nutrition", "stress", "sleep",
-    "weight", "vitamin", "fatigue", "anxiety", "depression"
+    "weight", "vitamin", "fatigue", "anxiety", "depression", "nausea", "dizziness",
+    "rash", "swelling", "injury", "bruise", "cold", "sneeze", "tired", "weak"
 ]

-CASUAL_ONLY_PATTERNS = [
-    "hey", "hi", "hello", "sup", "yo", "good morning", "good evening",
-    "how are you", "wassup", "hiya"
+# Emergency keywords that should trigger immediate medical attention warning
+EMERGENCY_KEYWORDS = [
+    "suicide", "kill myself", "end my life", "chest pain", "can't breathe",
+    "severe bleeding", "overdose", "poisoning", "unconscious", "seizure",
+    "stroke", "heart attack", "choking"
 ]

+CASUAL_PATTERNS = [
+    r"^(hey|hi|hello|sup|yo|wassup|hiya)\s*[\?\!\.]*$",
+    r"^good\s+(morning|evening|afternoon|night)\s*[\?\!\.]*$",
+    r"^how\s+are\s+you\s*[\?\!\.]*$",
+    r"^what'?s\s+up\s*[\?\!\.]*$",
+]
+
+
+def is_emergency_query(message):
+    """Detect if query contains emergency keywords"""
+    message_lower = message.lower()
+    return any(keyword in message_lower for keyword in EMERGENCY_KEYWORDS)
+

 def is_medical_query(message):
+    """Enhanced medical query detection"""
     message_lower = message.lower()
+
+    # Check for medical keywords
     for keyword in MEDICAL_KEYWORDS:
         if keyword in message_lower:
             return True
-    question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does"]
-    has_question = any(q in message_lower.split()[:3] for q in question_words)
-    if has_question and len(message.split()) > 5:
+
+    # Check for question patterns with sufficient length
+    question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does", "could", "would"]
+    words = message_lower.split()
+    has_question = any(q in words[:4] for q in question_words)
+
+    if has_question and len(words) > 5:
         return True
+
     return False


 def is_only_greeting(message):
-    message_lower = message.lower().strip().replace("!", "").replace("?", "").replace(".", "")
-    if len(message_lower.split()) <= 3:
-        for pattern in CASUAL_ONLY_PATTERNS:
-            if message_lower == pattern or message_lower.startswith(pattern):
-                return True
+    """Improved greeting detection using regex"""
+    message_clean = message.lower().strip()
+
+    # Remove punctuation for matching
+    message_clean = re.sub(r'[!?.]+$', '', message_clean)
+
+    # Check if it matches any casual pattern
+    for pattern in CASUAL_PATTERNS:
+        if re.match(pattern, message_clean):
+            return True
+
+    return False
+
+
+# =============================
+# Safety Filter
+# =============================
+DANGEROUS_PATTERNS = [
+    r"take\s+\d+\s+(pills|tablets|capsules)",
+    r"inject\s+(yourself|myself)",
+    r"(don't|do not)\s+go\s+to\s+(hospital|doctor|emergency)",
+    r"ignore\s+(doctor|medical|professional)",
+]
+
+def contains_dangerous_advice(response):
+    """Check if response contains potentially dangerous medical advice"""
+    response_lower = response.lower()
+
+    for pattern in DANGEROUS_PATTERNS:
+        if re.search(pattern, response_lower):
+            return True
+
     return False


 # =============================
 # Get Response
 # =============================
-def get_response(user_input, history_context):
+def get_response(user_input, history_context, session_id="default"):
+    """Generate response with enhanced safety and quality checks"""
+
+    # Rate limiting check
+    if not check_rate_limit(session_id):
+        return "⏰ You've made too many requests. Please wait a minute before trying again."
+
+    # Emergency detection
+    if is_emergency_query(user_input):
+        return (
+            "🚨 **EMERGENCY DETECTED** 🚨\n\n"
+            "If you are experiencing a medical emergency, please:\n"
+            "• Call emergency services immediately (911 in US, 999 in UK, 112 in EU)\n"
+            "• Go to the nearest emergency room\n"
+            "• Contact your local emergency hotline\n\n"
+            "This AI cannot provide emergency medical care. Please seek immediate professional help."
+        )
+
+    # Greeting detection
     if is_only_greeting(user_input):
         return "👋 Hello! I'm ChatDoctor — your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."

+    # Non-medical query handling
     if not is_medical_query(user_input):
         return (
             "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
@@ -104,17 +202,21 @@ def get_response(user_input, history_context):
             "Please describe your health concern in detail to get started."
         )

+    # Build prompt with limited history
     human_prefix = "Patient:"
     doctor_prefix = "ChatDoctor:"
     system_instruction = (
         "You are ChatDoctor, a professional medical AI assistant. "
-        "You provide accurate, concise, and empathetic responses to health-related questions only.\n\n"
-        "If the question is non-medical, politely redirect back to medical topics.\n"
+        "You provide accurate, concise, and empathetic responses to health-related questions only.\n"
+        "Always recommend consulting a healthcare professional for serious conditions.\n"
+        "Never provide dosage instructions or tell patients to avoid seeking professional help.\n\n"
     )

-    # Build history
+    # Limit history to prevent token overflow
+    limited_history = history_context[-MAX_HISTORY_TURNS:] if len(history_context) > MAX_HISTORY_TURNS else history_context
+
     history_text = [system_instruction]
-    for human, assistant in history_context:
+    for human, assistant in limited_history:
         if human:
             history_text.append(f"{human_prefix} {human}")
         if assistant:
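One simplification the commit could take further: Python's negative slicing already clamps to the list length, so the length guard around limited_history is redundant:

# history_context[-MAX_HISTORY_TURNS:] is safe even when the history is shorter.
MAX_HISTORY_TURNS = 5
history_context = [("q1", "a1"), ("q2", "a2")]
print(history_context[-MAX_HISTORY_TURNS:])  # [('q1', 'a1'), ('q2', 'a2')], the whole list, no error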
@@ -122,45 +224,68 @@ def get_response(user_input, history_context):
     history_text.append(f"{human_prefix} {user_input}")

     prompt = "\n".join(history_text) + f"\n{doctor_prefix} "
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
-    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
-    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
-    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
-
-    with torch.no_grad():
-        output_ids = model.generate(
-            input_ids,
-            max_new_tokens=MAX_NEW_TOKENS,
-            do_sample=True,
-            temperature=TEMPERATURE,
-            top_k=TOP_K,
-            repetition_penalty=REPETITION_PENALTY,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id
-        )
-
-    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
-
-    for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
-        if stop_word in response:
-            response = response.split(stop_word)[0].strip()
-            break
-
-    response = response.strip()
-    if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud"]):
-        response = (
-            "I apologize for the confusion — I'm ChatDoctor, trained to assist with medical and health-related topics only. "
-            "Please tell me about your symptoms or health concerns."
-        )
-
-    del input_ids, output_ids
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-    return response
+
+    try:
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+        # Stop words for cleaner output
+        stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
+        stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
+        stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids,
+                max_new_tokens=MAX_NEW_TOKENS,
+                do_sample=True,
+                temperature=TEMPERATURE,
+                top_k=TOP_K,
+                repetition_penalty=REPETITION_PENALTY,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id
+            )
+
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
+
+        # Clean up response
+        for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
+            if stop_word in response:
+                response = response.split(stop_word)[0].strip()
+                break
+
+        response = response.strip()
+
+        # Safety filter
+        if contains_dangerous_advice(response):
+            response = (
+                "I apologize, but I cannot provide that specific medical advice. "
+                "Please consult with a qualified healthcare professional who can properly evaluate your situation."
+            )
+
+        # Filter out inappropriate content
+        if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud", "sorry, i don't have"]):
+            response = (
+                "I apologize for the confusion. I'm ChatDoctor, trained to assist with medical and health-related topics. "
+                "Please tell me more about your symptoms or health concerns so I can help you better."
+            )
+
+        # Add disclaimer for serious conditions
+        serious_conditions = ["cancer", "tumor", "heart disease", "stroke", "diabetes complications"]
+        if any(condition in response.lower() for condition in serious_conditions):
+            response += "\n\n⚠️ **Important:** Please consult a healthcare professional for proper diagnosis and treatment."
+
+        # Clean up memory
+        del input_ids, output_ids
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return response
+
+    except Exception as e:
+        print(f"Error generating response: {e}")
+        return "I apologize, but I encountered an error processing your request. Please try rephrasing your question or try again later."


 # =============================
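Note: the generate() call above relies on a StopOnTokens class whose definition lives in an unchanged region of app.py and therefore does not appear in this diff. A minimal sketch of such a criterion, assuming it stops once the generated sequence ends with any encoded stop word (the real definition may differ):

import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    """Hypothetical sketch of the stop criterion referenced above."""
    def __init__(self, stop_ids_list):
        self.stop_ids_list = stop_ids_list  # one token-id list per stop word

    def __call__(self, input_ids, scores, **kwargs):
        # Stop as soon as the tail of the sequence matches any stop-word encoding.
        for stop_ids in self.stop_ids_list:
            if stop_ids and input_ids[0][-len(stop_ids):].tolist() == stop_ids:
                return True
        return False

Token-level stopping is approximate ("Patient:" can tokenize differently mid-text than in isolation), which is presumably why the code also string-splits the decoded response on the same markers. Relatedly, slicing the decoded text with [len(prompt):] assumes decoding reproduces the prompt byte-for-byte; slicing the token ids is sturdier, e.g. tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True).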
@@ -171,46 +296,82 @@ custom_css = """
     text-align: center;
     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
     color: white;
-    padding: 20px;
-    border-radius: 10px;
+    padding: 25px;
+    border-radius: 12px;
     margin-bottom: 20px;
+    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
 }
-#header h1 { margin: 0; font-size: 2.3em; }
-#header p { margin: 5px 0 0; font-size: 1em; opacity: 0.9; }
+#header h1 { margin: 0; font-size: 2.5em; font-weight: 700; }
+#header p { margin: 5px 0 0; font-size: 1.1em; opacity: 0.95; }
 .disclaimer {
     background-color: #fff3cd;
-    border: 1px solid #ffc107;
+    border-left: 4px solid #ffc107;
     border-radius: 8px;
-    padding: 15px;
+    padding: 18px;
     margin: 20px 0;
     color: #856404;
 }
+.disclaimer h3 { margin-top: 0; color: #d39e00; }
+.emergency-warning {
+    background-color: #f8d7da;
+    border-left: 4px solid #dc3545;
+    border-radius: 8px;
+    padding: 15px;
+    margin: 15px 0;
+    color: #721c24;
+}
+footer {
+    margin-top: 30px;
+    padding: 15px;
+    text-align: center;
+    color: #6c757d;
+    font-size: 0.9em;
+}
 """

 with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+    session_state = gr.State(value=str(time.time()))  # Unique session ID
+
     gr.HTML("""
         <div id="header">
             <h1>🩺 ChatDoctor AI Assistant</h1>
             <p>Your AI-powered medical consultation partner</p>
         </div>
     """)
+
     gr.HTML("""
         <div class="disclaimer">
             <h3>⚠️ Medical Disclaimer</h3>
-            <p>This AI assistant is for informational purposes only.
-            It is NOT a substitute for professional medical advice, diagnosis, or treatment.</p>
+            <p><strong>This AI assistant is for informational purposes only.</strong>
+            It is NOT a substitute for professional medical advice, diagnosis, or treatment.
+            Always seek the advice of your physician or qualified health provider with any questions
+            you may have regarding a medical condition.</p>
+        </div>
+    """)
+
+    gr.HTML("""
+        <div class="emergency-warning">
+            <h4>🚨 In Case of Emergency</h4>
+            <p>If you are experiencing a medical emergency, call emergency services immediately
+            (911 in US, 999 in UK, 112 in EU) or go to the nearest emergency room.</p>
         </div>
     """)

     chatbot = gr.Chatbot(
-        height=480,
-        placeholder="<div style='text-align:center;padding:40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>Describe your symptoms or ask a health-related question to begin.</p></div>",
+        height=500,
+        placeholder="<div style='text-align:center;padding:50px;'><h3>👋 Welcome to ChatDoctor!</h3><p style='color:#6c757d;'>Describe your symptoms or ask a health-related question to begin.</p><p style='color:#dc3545;margin-top:15px;'><strong>Remember:</strong> This is not a replacement for professional medical care.</p></div>",
         show_label=False,
         avatar_images=(None, "🤖"),
     )

     with gr.Row():
-        msg = gr.Textbox(placeholder="Type your medical concern here...", show_label=False, scale=9, container=False)
+        msg = gr.Textbox(
+            placeholder="Type your medical concern here... (e.g., 'I have a headache for 3 days')",
+            show_label=False,
+            scale=9,
+            container=False,
+            lines=1
+        )
         send_btn = gr.Button("Send 📤", scale=1, variant="primary")

     with gr.Row():
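A caveat on session_state: gr.State's initial value is computed once, when the Blocks graph is built, so str(time.time()) yields one id shared by every visitor rather than a per-user id, and all users then share a single rate-limit bucket. A per-session id could instead be assigned in a load event; a sketch, not part of this commit:

import uuid
import gradio as gr

with gr.Blocks() as demo:
    session_state = gr.State(value="")
    # demo.load runs once per browser session, so each visitor gets a fresh id.
    demo.load(lambda: str(uuid.uuid4()), None, session_state)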
@@ -218,44 +379,62 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
         retry_btn = gr.Button("🔄 Retry", scale=1)

     with gr.Accordion("⚙️ Advanced Settings", open=False):
-        temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature")
+        temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature (Lower = More Focused)")
         max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
-        top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K")
+        top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K Sampling")

     def user_message(user_msg, history):
+        if not user_msg.strip():
+            return "", history
         return "", history + [[user_msg, None]]

-    def bot_response(history, temp, max_tok, topk):
+    def bot_response(history, temp, max_tok, topk, session_id):
+        if not history or history[-1][1] is not None:
+            return history
+
         global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
         TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
+
         user_msg = history[-1][0]
-        bot_msg = get_response(user_msg, history[:-1])
+        bot_msg = get_response(user_msg, history[:-1], session_id)
         history[-1][1] = bot_msg
         return history

-    def retry_last(history, temp, max_tok, topk):
+    def retry_last(history, temp, max_tok, topk, session_id):
         if not history:
             return history
         user_msg = history[-1][0]
-        bot_msg = get_response(user_msg, history[:-1])
+        bot_msg = get_response(user_msg, history[:-1], session_id)
         history[-1][1] = bot_msg
         return history

     msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot
+        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
     )
     send_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot
+        bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
     )
     clear_btn.click(lambda: None, None, chatbot, queue=False)
-    retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot)
-
-    gr.HTML(f"<footer><center><p>🧠 Powered by LLaMA-based ChatDoctor | Device: {device.upper()}</p></center></footer>")
+    retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot)
+
+    gr.HTML(f"""
+    <footer>
+        <p><strong>🧠 Powered by LLaMA-based ChatDoctor</strong></p>
+        <p>Device: {device.upper()} | Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute</p>
+        <p style='font-size:0.85em;margin-top:10px;'>
+            This AI provides general health information only. Always consult healthcare professionals for medical advice.
+        </p>
+    </footer>
+    """)

 # =============================
 # Launch App
 # =============================
 if __name__ == "__main__":
-    print("\n💡 Launching ChatDoctor Gradio Interface...")
+    print("\n💡 Launching Enhanced ChatDoctor Gradio Interface...")
+    print(f"📊 Configuration:")
+    print(f"   - Max History Turns: {MAX_HISTORY_TURNS}")
+    print(f"   - Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute")
+    print(f"   - Device: {device.upper()}")
     demo.queue()
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
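Finally, bot_response still funnels the slider values through module globals (TEMPERATURE and friends), which can race when two sessions generate concurrently. A thread-safe variant would pass them straight into get_response; a sketch, assuming get_response were extended with matching keyword parameters:

def bot_response(history, temp, max_tok, topk, session_id):
    if not history or history[-1][1] is not None:
        return history
    user_msg = history[-1][0]
    # Pass sampling settings explicitly instead of mutating globals
    # (assumes get_response is extended to accept these kwargs).
    history[-1][1] = get_response(
        user_msg, history[:-1], session_id,
        temperature=temp, max_new_tokens=int(max_tok), top_k=int(topk),
    )
    return history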
 