lukiod commited on
Commit
a548a89
Β·
verified Β·
1 Parent(s): b88326f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -222
app.py CHANGED
@@ -1,238 +1,211 @@
1
  import gradio as gr
2
- import pandas as pd
3
- from datetime import datetime
4
  import torch
5
- from transformers import T5Tokenizer, T5ForConditionalGeneration
6
- import gc
7
- from typing import List, Dict
8
- import os
9
- import time
10
  import logging
 
 
11
 
12
  # Setup logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- # Disable gradient computation and set memory efficient settings
17
- torch.set_grad_enabled(False)
18
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
19
-
20
- # Create cache directory
21
- os.makedirs("model_cache", exist_ok=True)
22
-
23
- class ModelHandler:
24
  def __init__(self):
25
- self.model_name = "google/flan-t5-small" # Small model for Spaces
26
- self.device = "cpu"
27
- self.initialized = False
28
- self.load_attempts = 0
29
- self.max_attempts = 3
 
30
  self.initialize_model()
31
 
32
  def initialize_model(self):
33
- while not self.initialized and self.load_attempts < self.max_attempts:
34
- try:
35
- logger.info(f"Loading model attempt {self.load_attempts + 1}")
36
- self.tokenizer = T5Tokenizer.from_pretrained(
37
- self.model_name,
38
- model_max_length=512,
39
- cache_dir="model_cache"
40
- )
41
- self.model = T5ForConditionalGeneration.from_pretrained(
42
- self.model_name,
43
- low_cpu_mem_usage=True,
44
- cache_dir="model_cache"
45
- )
46
- self.initialized = True
47
- logger.info("Model loaded successfully")
48
- return True
49
- except Exception as e:
50
- logger.error(f"Loading attempt failed: {str(e)}")
51
- self.load_attempts += 1
52
- time.sleep(1)
53
- return False
54
-
55
- def generate_response(self, prompt: str, max_length: int = 256) -> str:
56
- if not self.initialized:
57
- return "Model initialization failed. Using basic responses."
58
-
59
  try:
60
- clean_prompt = prompt.strip()
61
- if len(clean_prompt) == 0:
62
- return "Please provide a valid question."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- inputs = self.tokenizer(
65
- clean_prompt,
66
- max_length=512,
67
- truncation=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  padding=True,
69
  return_tensors="pt"
70
  )
 
 
 
 
 
 
 
 
 
 
71
 
72
- with torch.no_grad():
73
- outputs = self.model.generate(
74
- input_ids=inputs["input_ids"],
75
- max_length=max_length,
76
- min_length=10,
77
- num_beams=1,
78
- do_sample=True,
79
- temperature=0.7,
80
- top_k=50,
81
- top_p=0.95,
82
- )
83
-
84
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
85
-
86
- del outputs, inputs
87
  gc.collect()
88
-
89
- return response if response else "Could not generate a response."
90
-
 
91
  except Exception as e:
92
- logger.error(f"Generation error: {str(e)}")
93
- return self.get_fallback_response(prompt)
94
 
95
- def get_fallback_response(self, query: str) -> str:
96
- responses = {
97
- "hello": "Hello! I'm your health assistant.",
98
- "help": "I can help with health information and tracking.",
99
- "health": "I provide general health information.",
100
- "sleep": "Aim for 7-9 hours of sleep daily.",
101
- "exercise": "Regular exercise is important for health.",
102
- "diet": "Eat a balanced diet with plenty of vegetables.",
103
- "medication": "Always follow prescribed medication schedules.",
104
- "water": "Stay hydrated by drinking plenty of water daily.",
105
- "stress": "Managing stress is important for overall health."
106
- }
107
 
108
- query = query.lower()
109
- for key, response in responses.items():
110
- if key in query:
111
- return response
112
- return "I understand you have a health question. Please try rephrasing it simply."
 
 
113
 
114
- class HealthData:
115
- def __init__(self):
116
- self.metrics = []
117
- self.medications = []
 
 
 
118
 
119
- def add_metrics(self, metrics: Dict) -> bool:
120
- try:
121
- self.metrics.append({
122
- 'Date': datetime.now().strftime('%Y-%m-%d'),
123
- **metrics
124
- })
125
- return True
126
- except Exception as e:
127
- logger.error(f"Error adding metrics: {str(e)}")
128
- return False
129
 
130
- def add_medication(self, medication: Dict) -> bool:
131
- try:
132
- self.medications.append(medication)
133
- return True
134
- except Exception as e:
135
- logger.error(f"Error adding medication: {str(e)}")
136
- return False
137
 
138
- def get_health_context(self) -> str:
 
139
  context_parts = []
140
 
141
  if self.metrics:
142
  latest = self.metrics[-1]
143
  context_parts.extend([
144
- f"Recent Health Metrics (Date: {latest['Date']}):",
145
- f"- Weight: {latest['Weight']} kg",
146
- f"- Steps: {latest['Steps']}",
147
- f"- Sleep: {latest['Sleep']} hours"
148
  ])
149
 
150
  if self.medications:
151
  context_parts.append("\nCurrent Medications:")
152
  for med in self.medications:
153
  med_info = f"- {med['Medication']} ({med['Dosage']}) at {med['Time']}"
154
- if med['Notes']:
155
  med_info += f" | Note: {med['Notes']}"
156
  context_parts.append(med_info)
157
 
158
- return "\n".join(context_parts) if context_parts else "No health data available."
159
-
160
- class HealthAssistant:
161
- def __init__(self):
162
- self.model = ModelHandler()
163
- self.data = HealthData()
164
- self.request_count = 0
165
 
166
- def get_response(self, message: str, history: List = None) -> str:
167
  try:
168
- self.request_count += 1
169
-
170
- # Prepare context
171
- context = self.data.get_health_context()
172
-
173
- # Format prompt
174
- prompt = (
175
- f"Context: {context}\n\n"
176
- f"Question: {message}\n\n"
177
- "Provide a helpful and accurate health-related response."
178
- )
179
 
180
- # Get response
181
- response = self.model.generate_response(prompt)
182
-
183
- # Periodic cleanup
184
- if self.request_count % 5 == 0:
185
- gc.collect()
186
-
187
- return response
188
-
189
  except Exception as e:
190
- logger.error(f"Error in get_response: {str(e)}")
191
- return self.model.get_fallback_response(message)
192
 
193
- class HealthAssistantUI:
194
  def __init__(self):
195
  self.assistant = HealthAssistant()
196
 
197
- def user_chat(self, message: str, history: List) -> tuple:
198
- if message.strip() == "":
199
  return "", history
200
 
201
- bot_message = self.assistant.get_response(message)
202
- history.append([message, bot_message])
203
  return "", history
204
 
205
- def save_metrics(self, weight: float, steps: int, sleep: float) -> tuple:
206
- if not all([weight is not None, steps is not None, sleep is not None]):
207
- return "⚠️ Please fill in all metrics.", None
208
 
209
- metrics = {'Weight': weight, 'Steps': steps, 'Sleep': sleep}
210
- if self.assistant.data.add_metrics(metrics):
211
- df = pd.DataFrame(self.assistant.data.metrics)
212
- return "βœ… Metrics saved successfully!", df
213
- return "❌ Error saving metrics", None
214
 
215
- def save_medication(self, name: str, dosage: str, time: str, notes: str) -> tuple:
216
  if not all([name, dosage, time]):
217
- return "⚠️ Please fill in all required fields.", None
218
 
219
- medication = {
220
- 'Medication': name,
221
- 'Dosage': dosage,
222
- 'Time': time,
223
- 'Notes': notes or ''
224
- }
225
- if self.assistant.data.add_medication(medication):
226
- df = pd.DataFrame(self.assistant.data.medications)
227
- return "βœ… Medication added successfully!", df
228
- return "❌ Error adding medication", None
229
 
230
  def create_interface(self):
231
  with gr.Blocks(title="Health Assistant", theme=gr.themes.Soft()) as demo:
232
  gr.Markdown(
233
  """
234
- # πŸ₯ Health Assistant
235
- Your AI-powered health companion. Track metrics, manage medications, and get health information.
236
  """
237
  )
238
 
@@ -241,11 +214,11 @@ class HealthAssistantUI:
241
  with gr.Tab("πŸ’¬ Health Chat"):
242
  chatbot = gr.Chatbot(
243
  height=450,
244
- show_label=False,
245
  )
246
  with gr.Row():
247
  msg = gr.Textbox(
248
- placeholder="Type your health question... (Press Enter)",
249
  lines=2,
250
  show_label=False,
251
  scale=9
@@ -253,86 +226,63 @@ class HealthAssistantUI:
253
  send_btn = gr.Button("Send", scale=1)
254
  clear_btn = gr.Button("Clear Chat")
255
 
256
- # Health Metrics Tab
257
  with gr.Tab("πŸ“Š Health Metrics"):
258
  with gr.Row():
259
- with gr.Column():
260
- weight_input = gr.Number(label="Weight (kg)")
261
- steps_input = gr.Number(label="Steps")
262
- sleep_input = gr.Number(label="Hours Slept")
263
- metrics_btn = gr.Button("Save Metrics")
264
- metrics_status = gr.Markdown()
265
-
266
- with gr.Column():
267
- metrics_display = gr.Dataframe(
268
- headers=["Date", "Weight", "Steps", "Sleep"]
269
- )
270
 
271
- # Medication Manager Tab
272
  with gr.Tab("πŸ’Š Medication Manager"):
273
  with gr.Row():
274
- with gr.Column():
275
- med_name = gr.Textbox(label="Medication Name")
276
- med_dosage = gr.Textbox(label="Dosage")
277
- med_time = gr.Textbox(label="Time")
278
- med_notes = gr.Textbox(label="Notes (optional)")
279
- med_btn = gr.Button("Add Medication")
280
- med_status = gr.Markdown()
281
-
282
- with gr.Column():
283
- meds_display = gr.Dataframe(
284
- headers=["Medication", "Dosage", "Time", "Notes"]
285
- )
286
 
287
  # Event handlers
288
- msg.submit(self.user_chat, [msg, chatbot], [msg, chatbot])
289
- send_btn.click(self.user_chat, [msg, chatbot], [msg, chatbot])
290
  clear_btn.click(lambda: [], None, chatbot)
291
 
292
  metrics_btn.click(
293
- self.save_metrics,
294
  inputs=[weight_input, steps_input, sleep_input],
295
- outputs=[metrics_status, metrics_display]
296
  )
297
 
298
  med_btn.click(
299
- self.save_medication,
300
  inputs=[med_name, med_dosage, med_time, med_notes],
301
- outputs=[med_status, meds_display]
302
  )
303
 
304
  gr.Markdown(
305
  """
306
  ### ⚠️ Important Note
307
- This is an AI assistant for general health information only.
308
  Always consult healthcare professionals for medical advice.
309
  """
310
  )
311
 
312
  return demo
313
 
314
- def cleanup():
315
- """Cleanup function for memory management"""
316
- gc.collect()
317
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
318
-
319
  def main():
320
  try:
321
- logger.info("Starting Health Assistant")
322
- ui = HealthAssistantUI()
323
- demo = ui.create_interface()
324
-
325
- # Register cleanup
326
- demo.load(cleanup)
327
-
328
- # Launch app
329
  demo.launch(
330
  share=False,
331
  enable_queue=True,
332
  max_threads=4
333
  )
334
  except Exception as e:
335
- logger.error(f"Error starting app: {str(e)}")
336
 
337
  if __name__ == "__main__":
338
  main()
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
4
+ from qwen_vl_utils import process_vision_info
 
 
 
5
  import logging
6
+ from typing import List, Dict
7
+ import gc
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ class HealthAssistant:
 
 
 
 
 
 
 
14
  def __init__(self):
15
+ self.model_name = "Qwen/Qwen2-VL-7B-Instruct"
16
+ self.model = None
17
+ self.tokenizer = None
18
+ self.processor = None
19
+ self.metrics = []
20
+ self.medications = []
21
  self.initialize_model()
22
 
23
  def initialize_model(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  try:
25
+ logger.info("Loading Qwen2-VL model...")
26
+ self.model = Qwen2VLForConditionalGeneration.from_pretrained(
27
+ self.model_name,
28
+ torch_dtype=torch.bfloat16,
29
+ attn_implementation="flash_attention_2",
30
+ device_map="auto"
31
+ )
32
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
33
+ self.processor = AutoProcessor.from_pretrained(
34
+ self.model_name,
35
+ min_pixels=256*28*28,
36
+ max_pixels=1280*28*28
37
+ )
38
+ logger.info("Model loaded successfully")
39
+ except Exception as e:
40
+ logger.error(f"Error loading model: {e}")
41
+ raise
42
 
43
+ def generate_response(self, message: str, history: List = None) -> str:
44
+ try:
45
+ # Format conversation with health context
46
+ messages = self._format_messages(message, history)
47
+
48
+ # Prepare for inference
49
+ text = self.processor.apply_chat_template(
50
+ messages,
51
+ tokenize=False,
52
+ add_generation_prompt=True
53
+ )
54
+
55
+ # Since we're not using images in this case
56
+ image_inputs, video_inputs = [], []
57
+
58
+ # Process inputs
59
+ inputs = self.processor(
60
+ text=[text],
61
+ images=image_inputs,
62
+ videos=video_inputs,
63
  padding=True,
64
  return_tensors="pt"
65
  )
66
+ inputs = inputs.to(self.model.device)
67
+
68
+ # Generate response
69
+ generated_ids = self.model.generate(
70
+ **inputs,
71
+ max_new_tokens=256,
72
+ do_sample=True,
73
+ temperature=0.7,
74
+ top_p=0.9
75
+ )
76
 
77
+ # Decode response
78
+ generated_ids_trimmed = [
79
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
80
+ ]
81
+ output_text = self.processor.batch_decode(
82
+ generated_ids_trimmed,
83
+ skip_special_tokens=True,
84
+ clean_up_tokenization_spaces=False
85
+ )[0]
86
+
87
+ # Cleanup
88
+ del inputs, generated_ids, generated_ids_trimmed
 
 
 
89
  gc.collect()
90
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
91
+
92
+ return output_text.strip()
93
+
94
  except Exception as e:
95
+ logger.error(f"Error generating response: {e}")
96
+ return "I apologize, but I encountered an error. Please try again."
97
 
98
+ def _format_messages(self, message: str, history: List = None) -> List[Dict]:
99
+ """Format messages for the Qwen2-VL model"""
100
+ # Add system context
101
+ messages = []
 
 
 
 
 
 
 
 
102
 
103
+ # Add health context
104
+ health_context = self._get_health_context()
105
+ if health_context:
106
+ messages.append({
107
+ "role": "system",
108
+ "content": [{"type": "text", "text": f"Current health information:\n{health_context}"}]
109
+ })
110
 
111
+ # Add conversation history
112
+ if history:
113
+ for user_msg, assistant_msg in history[-3:]: # Last 3 exchanges
114
+ messages.extend([
115
+ {"role": "user", "content": [{"type": "text", "text": user_msg}]},
116
+ {"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]}
117
+ ])
118
 
119
+ # Add current message
120
+ messages.append({
121
+ "role": "user",
122
+ "content": [{"type": "text", "text": message}]
123
+ })
 
 
 
 
 
124
 
125
+ return messages
 
 
 
 
 
 
126
 
127
+ def _get_health_context(self) -> str:
128
+ """Get health metrics and medications context"""
129
  context_parts = []
130
 
131
  if self.metrics:
132
  latest = self.metrics[-1]
133
  context_parts.extend([
134
+ "Recent Health Metrics:",
135
+ f"- Weight: {latest.get('Weight', 'N/A')} kg",
136
+ f"- Steps: {latest.get('Steps', 'N/A')}",
137
+ f"- Sleep: {latest.get('Sleep', 'N/A')} hours"
138
  ])
139
 
140
  if self.medications:
141
  context_parts.append("\nCurrent Medications:")
142
  for med in self.medications:
143
  med_info = f"- {med['Medication']} ({med['Dosage']}) at {med['Time']}"
144
+ if med.get('Notes'):
145
  med_info += f" | Note: {med['Notes']}"
146
  context_parts.append(med_info)
147
 
148
+ return "\n".join(context_parts) if context_parts else ""
 
 
 
 
 
 
149
 
150
+ def add_metrics(self, weight: float, steps: int, sleep: float) -> bool:
151
  try:
152
+ self.metrics.append({
153
+ 'Weight': weight,
154
+ 'Steps': steps,
155
+ 'Sleep': sleep
156
+ })
157
+ return True
158
+ except Exception as e:
159
+ logger.error(f"Error adding metrics: {e}")
160
+ return False
 
 
161
 
162
+ def add_medication(self, name: str, dosage: str, time: str, notes: str = "") -> bool:
163
+ try:
164
+ self.medications.append({
165
+ 'Medication': name,
166
+ 'Dosage': dosage,
167
+ 'Time': time,
168
+ 'Notes': notes
169
+ })
170
+ return True
171
  except Exception as e:
172
+ logger.error(f"Error adding medication: {e}")
173
+ return False
174
 
175
+ class GradioInterface:
176
  def __init__(self):
177
  self.assistant = HealthAssistant()
178
 
179
+ def chat_response(self, message: str, history: List) -> tuple:
180
+ if not message.strip():
181
  return "", history
182
 
183
+ response = self.assistant.generate_response(message, history)
184
+ history.append([message, response])
185
  return "", history
186
 
187
+ def add_health_metrics(self, weight: float, steps: int, sleep: float) -> str:
188
+ if not all([weight, steps, sleep]):
189
+ return "⚠️ Please fill in all metrics."
190
 
191
+ if self.assistant.add_metrics(weight, steps, sleep):
192
+ return "βœ… Health metrics saved successfully!"
193
+ return "❌ Error saving metrics."
 
 
194
 
195
+ def add_medication_info(self, name: str, dosage: str, time: str, notes: str) -> str:
196
  if not all([name, dosage, time]):
197
+ return "⚠️ Please fill in all required fields."
198
 
199
+ if self.assistant.add_medication(name, dosage, time, notes):
200
+ return "βœ… Medication added successfully!"
201
+ return "❌ Error adding medication."
 
 
 
 
 
 
 
202
 
203
  def create_interface(self):
204
  with gr.Blocks(title="Health Assistant", theme=gr.themes.Soft()) as demo:
205
  gr.Markdown(
206
  """
207
+ # πŸ₯ AI Health Assistant
208
+ Powered by Qwen2-VL for intelligent health guidance and monitoring.
209
  """
210
  )
211
 
 
214
  with gr.Tab("πŸ’¬ Health Chat"):
215
  chatbot = gr.Chatbot(
216
  height=450,
217
+ show_label=False
218
  )
219
  with gr.Row():
220
  msg = gr.Textbox(
221
+ placeholder="Ask your health question... (Press Enter)",
222
  lines=2,
223
  show_label=False,
224
  scale=9
 
226
  send_btn = gr.Button("Send", scale=1)
227
  clear_btn = gr.Button("Clear Chat")
228
 
229
+ # Health Metrics
230
  with gr.Tab("πŸ“Š Health Metrics"):
231
  with gr.Row():
232
+ weight_input = gr.Number(label="Weight (kg)")
233
+ steps_input = gr.Number(label="Steps")
234
+ sleep_input = gr.Number(label="Hours Slept")
235
+ metrics_btn = gr.Button("Save Metrics")
236
+ metrics_status = gr.Markdown()
 
 
 
 
 
 
237
 
238
+ # Medication Manager
239
  with gr.Tab("πŸ’Š Medication Manager"):
240
  with gr.Row():
241
+ med_name = gr.Textbox(label="Medication Name")
242
+ med_dosage = gr.Textbox(label="Dosage")
243
+ med_time = gr.Textbox(label="Time (e.g., 9:00 AM)")
244
+ med_notes = gr.Textbox(label="Notes (optional)")
245
+ med_btn = gr.Button("Add Medication")
246
+ med_status = gr.Markdown()
 
 
 
 
 
 
247
 
248
  # Event handlers
249
+ msg.submit(self.chat_response, [msg, chatbot], [msg, chatbot])
250
+ send_btn.click(self.chat_response, [msg, chatbot], [msg, chatbot])
251
  clear_btn.click(lambda: [], None, chatbot)
252
 
253
  metrics_btn.click(
254
+ self.add_health_metrics,
255
  inputs=[weight_input, steps_input, sleep_input],
256
+ outputs=[metrics_status]
257
  )
258
 
259
  med_btn.click(
260
+ self.add_medication_info,
261
  inputs=[med_name, med_dosage, med_time, med_notes],
262
+ outputs=[med_status]
263
  )
264
 
265
  gr.Markdown(
266
  """
267
  ### ⚠️ Important Note
268
+ This AI assistant provides general health information only.
269
  Always consult healthcare professionals for medical advice.
270
  """
271
  )
272
 
273
  return demo
274
 
 
 
 
 
 
275
  def main():
276
  try:
277
+ interface = GradioInterface()
278
+ demo = interface.create_interface()
 
 
 
 
 
 
279
  demo.launch(
280
  share=False,
281
  enable_queue=True,
282
  max_threads=4
283
  )
284
  except Exception as e:
285
+ logger.error(f"Error starting application: {e}")
286
 
287
  if __name__ == "__main__":
288
  main()