File size: 10,213 Bytes
4ef990b
e1fb3b2
70448af
e0bdeef
a548a89
 
25b6c4d
e0bdeef
f0cbaa0
70448af
 
 
 
e0bdeef
 
f0cbaa0
 
70448af
a548a89
70448af
 
f0cbaa0
70448af
 
 
a548a89
 
 
 
e1fb3b2
54c24d5
e1fb3b2
e0bdeef
70448af
 
 
a548a89
25b6c4d
a548a89
f0cbaa0
70448af
 
 
f0cbaa0
70448af
 
 
 
 
25b6c4d
70448af
 
 
f0cbaa0
70448af
f0cbaa0
a548a89
70448af
f0cbaa0
70448af
 
 
 
f0cbaa0
e0bdeef
a548a89
 
70448af
f0cbaa0
70448af
 
 
a548a89
70448af
 
 
 
e0bdeef
70448af
 
f0cbaa0
a548a89
70448af
 
 
 
 
 
 
 
 
 
 
 
e1fb3b2
70448af
 
 
 
 
a548a89
 
70448af
e1fb3b2
a548a89
70448af
a548a89
54c24d5
70448af
a548a89
e0bdeef
70448af
 
f0cbaa0
70448af
 
 
a548a89
70448af
f0cbaa0
70448af
f0cbaa0
 
a548a89
70448af
 
 
 
 
 
 
fe937b3
a548a89
e1fb3b2
fe937b3
 
e1fb3b2
e0bdeef
a548a89
 
 
 
e0bdeef
fe937b3
 
e1fb3b2
fe937b3
e1fb3b2
a548a89
e1fb3b2
 
fe937b3
a548a89
e1fb3b2
a548a89
54c24d5
a548a89
 
 
 
 
 
 
 
 
e0bdeef
a548a89
 
 
 
 
 
 
 
 
54c24d5
a548a89
 
fe937b3
a548a89
fe937b3
70448af
 
f0cbaa0
70448af
 
 
 
 
 
fe937b3
a548a89
 
fe937b3
 
a548a89
 
fe937b3
 
a548a89
25b6c4d
a548a89
fe937b3
a548a89
 
 
fe937b3
a548a89
fe937b3
a548a89
fe937b3
a548a89
 
 
fe937b3
 
f0cbaa0
25b6c4d
fe937b3
 
f0cbaa0
fe937b3
6ed2a87
f0cbaa0
 
6ed2a87
fe937b3
 
a548a89
fe937b3
 
 
 
e1fb3b2
fe937b3
6ed2a87
f0cbaa0
fe937b3
 
a548a89
 
 
 
 
6ed2a87
f0cbaa0
fe937b3
 
a548a89
 
 
 
 
 
e1fb3b2
f0cbaa0
a548a89
 
6ed2a87
e1fb3b2
 
a548a89
e1fb3b2
a548a89
e1fb3b2
 
 
a548a89
e1fb3b2
a548a89
e1fb3b2
4ef990b
f0cbaa0
 
fe937b3
 
 
e0bdeef
70448af
a548a89
 
70448af
e0bdeef
25b6c4d
 
f0cbaa0
e0bdeef
 
a548a89
70448af
4ef990b
fe937b3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import List, Dict
import gc
import os

# Configure root logging once at import time: timestamped, INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)  # module-level logger shared by the classes below

# Cap PyTorch's intra-op CPU thread pool to keep CPU/memory usage modest.
# NOTE(review): this does not by itself force CPU execution; the model is
# explicitly moved to CPU in HealthAssistant.initialize_model().
torch.set_num_threads(4)

class HealthAssistant:
    """Chat-style health assistant backed by a causal language model.

    Keeps user-entered health metrics and medications in memory and folds
    them into the prompt as context for every generated response.
    """

    def __init__(self, use_smaller_model: bool = True):
        """Select a model name and load model + tokenizer immediately.

        Args:
            use_smaller_model: When True (default), use the lightweight
                OPT-125M model suitable for CPU; otherwise a much larger
                Qwen2-VL model.

        Raises:
            Exception: Propagated from initialize_model() on load failure.
        """
        if use_smaller_model:
            self.model_name = "facebook/opt-125m"
        else:
            # NOTE(review): Qwen2-VL is a vision-language model; loading it
            # through AutoModelForCausalLM may not behave as intended — confirm.
            self.model_name = "Qwen/Qwen2-VL-7B-Instruct"

        self.model = None
        self.tokenizer = None
        self.metrics: List[Dict] = []      # appended to by add_metrics()
        self.medications: List[Dict] = []  # appended to by add_medication()
        self.initialize_model()

    def initialize_model(self) -> bool:
        """Load the tokenizer and model onto CPU.

        Returns:
            True on success.

        Raises:
            Exception: Re-raised after logging if loading fails.
        """
        try:
            logger.info(f"Starting model initialization: {self.model_name}")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            logger.info("Tokenizer loaded")

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )

            self.model = self.model.to("cpu")

            # Some tokenizers (e.g. GPT/OPT-style) ship without a pad token;
            # reuse EOS so padded batches can be built.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Error in model initialization: {str(e)}")
            raise

    def is_initialized(self) -> bool:
        """Return True once both model and tokenizer are ready for use."""
        return (self.model is not None and
                self.tokenizer is not None and
                hasattr(self.model, 'generate'))

    def generate_response(self, message: str, history: List = None) -> str:
        """Generate an assistant reply to *message* given chat *history*.

        Args:
            message: The user's current question.
            history: Optional list of [user, assistant] message pairs.

        Returns:
            The generated reply, or a friendly fallback string on error.
        """
        try:
            if not self.is_initialized():
                return "System is still initializing. Please try again in a moment."

            # Prepare prompt
            prompt = self._prepare_prompt(message, history)

            # Tokenize
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )

            # Generate. Fix: pass attention_mask explicitly — since pad_token
            # may equal eos_token, omitting it can make generate() treat real
            # tokens as padding and degrade output.
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=128,
                    num_beams=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens (skip the echoed prompt).
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )

            # Free tensors promptly to keep CPU memory low.
            del outputs, inputs
            gc.collect()

            return response.strip()

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return "I apologize, but I encountered an error. Please try again."

    def _prepare_prompt(self, message: str, history: List = None) -> str:
        """Build the full LM prompt: system line, health context, recent
        history (last 3 turns), then the current user message."""
        parts = [
            "You are a helpful healthcare assistant providing accurate and helpful medical information.",
            self._get_health_context() or "No health data available yet."
        ]

        if history:
            parts.append("Previous conversation:")
            for h in history[-3:]:
                parts.extend([
                    f"User: {h[0]}",
                    f"Assistant: {h[1]}"
                ])

        parts.extend([
            f"User: {message}",
            "Assistant:"
        ])

        return "\n\n".join(parts)

    def _get_health_context(self) -> str:
        """Render the latest metrics and all medications as prompt text.

        Returns:
            A newline-joined summary, or "" when no data has been recorded.
        """
        context_parts = []

        if self.metrics:
            latest = self.metrics[-1]  # only the most recent entry is surfaced
            context_parts.extend([
                "Recent Health Metrics:",
                f"- Weight: {latest.get('Weight', 'N/A')} kg",
                f"- Steps: {latest.get('Steps', 'N/A')}",
                f"- Sleep: {latest.get('Sleep', 'N/A')} hours"
            ])

        if self.medications:
            context_parts.append("\nCurrent Medications:")
            for med in self.medications:
                med_info = f"- {med['Medication']} ({med['Dosage']}) at {med['Time']}"
                if med.get('Notes'):
                    med_info += f" | Note: {med['Notes']}"
                context_parts.append(med_info)

        return "\n".join(context_parts) if context_parts else ""

    def add_metrics(self, weight: float, steps: int, sleep: float) -> bool:
        """Record one day's metrics; returns True on success."""
        try:
            self.metrics.append({
                'Weight': weight,
                'Steps': steps,
                'Sleep': sleep
            })
            return True
        except Exception as e:
            logger.error(f"Error adding metrics: {e}")
            return False

    def add_medication(self, name: str, dosage: str, time: str, notes: str = "") -> bool:
        """Record a medication entry; returns True on success."""
        try:
            self.medications.append({
                'Medication': name,
                'Dosage': dosage,
                'Time': time,
                'Notes': notes
            })
            return True
        except Exception as e:
            logger.error(f"Error adding medication: {e}")
            return False

class GradioInterface:
    """Gradio UI wrapper around HealthAssistant: chat, metrics, medications.

    Fix: several UI/status string literals were mojibake (UTF-8 emoji bytes
    decoded with a legacy codec, e.g. "βœ…"); restored to the intended emoji.
    """

    def __init__(self):
        """Create the assistant eagerly so UI callbacks can rely on it.

        Raises:
            RuntimeError: If the assistant fails to initialize properly.
        """
        try:
            logger.info("Initializing Health Assistant...")
            self.assistant = HealthAssistant(use_smaller_model=True)
            if not self.assistant.is_initialized():
                raise RuntimeError("Health Assistant failed to initialize properly")
            logger.info("Health Assistant initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Health Assistant: {e}")
            raise

    def chat_response(self, message: str, history: List) -> tuple:
        """Chatbot callback: append (message, reply) to history.

        Returns:
            ("", updated_history) — empty string clears the input textbox.
        """
        if not message.strip():
            return "", history

        response = self.assistant.generate_response(message, history)
        history.append([message, response])
        return "", history

    def add_health_metrics(self, weight: float, steps: int, sleep: float) -> str:
        """Metrics-form callback; returns a markdown status string."""
        if not all([weight is not None, steps is not None, sleep is not None]):
            return "⚠️ Please fill in all metrics."

        if self.assistant.add_metrics(weight, steps, sleep):
            return "✅ Health metrics saved successfully!"
        return "❌ Error saving metrics."

    def add_medication_info(self, name: str, dosage: str, time: str, notes: str) -> str:
        """Medication-form callback; name/dosage/time are required."""
        if not all([name, dosage, time]):
            return "⚠️ Please fill in all required fields."

        if self.assistant.add_medication(name, dosage, time, notes):
            return "✅ Medication added successfully!"
        return "❌ Error adding medication."

    def create_interface(self):
        """Build and return the gr.Blocks app (not launched here)."""
        with gr.Blocks(title="Health Assistant") as demo:
            gr.Markdown("# 🏥 AI Health Assistant")

            with gr.Tabs():
                # Chat Interface
                with gr.Tab("💬 Health Chat"):
                    chatbot = gr.Chatbot(
                        value=[],
                        height=450
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            placeholder="Ask your health question... (Press Enter)",
                            lines=2,
                            show_label=False,
                            scale=9
                        )
                        send_btn = gr.Button("Send", scale=1)
                    clear_btn = gr.Button("Clear Chat")

                # Health Metrics
                with gr.Tab("📊 Health Metrics"):
                    with gr.Row():
                        weight_input = gr.Number(label="Weight (kg)")
                        steps_input = gr.Number(label="Steps")
                        sleep_input = gr.Number(label="Hours Slept")
                    metrics_btn = gr.Button("Save Metrics")
                    metrics_status = gr.Markdown()

                # Medication Manager
                with gr.Tab("💊 Medication Manager"):
                    with gr.Row():
                        med_name = gr.Textbox(label="Medication Name")
                        med_dosage = gr.Textbox(label="Dosage")
                        med_time = gr.Textbox(label="Time (e.g., 9:00 AM)")
                        med_notes = gr.Textbox(label="Notes (optional)")
                    med_btn = gr.Button("Add Medication")
                    med_status = gr.Markdown()

            # Event handlers
            msg.submit(self.chat_response, [msg, chatbot], [msg, chatbot])
            send_btn.click(self.chat_response, [msg, chatbot], [msg, chatbot])
            clear_btn.click(lambda: [], None, chatbot)

            metrics_btn.click(
                self.add_health_metrics,
                inputs=[weight_input, steps_input, sleep_input],
                outputs=[metrics_status]
            )

            med_btn.click(
                self.add_medication_info,
                inputs=[med_name, med_dosage, med_time, med_notes],
                outputs=[med_status]
            )

            demo.queue()

        return demo

def main():
    """Application entry point: build the UI and start the Gradio server."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    try:
        logger.info("Starting application...")
        app = GradioInterface()
        ui = app.create_interface()
        logger.info("Launching Gradio interface...")
        ui.launch(**launch_options)
    except Exception as e:
        logger.error(f"Error starting application: {e}")
        raise

if __name__ == "__main__":
    main()