invincible-jha commited on
Commit
560e803
1 Parent(s): 5d670c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -0
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import logging
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from abc import ABC, abstractmethod
6
+ from typing import Dict, Any
7
+ from datetime import datetime
8
+ import json
9
+ import os
10
+ from huggingface_hub import login
11
+
12
+ # Configure logging
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
16
+ handlers=[
17
+ logging.FileHandler('wellness_assistant.log'),
18
+ logging.StreamHandler()
19
+ ]
20
+ )
21
+
22
+ logger = logging.getLogger("WellnessAssistant")
23
+
24
+ # Login to Hugging Face Hub
25
+ try:
26
+ HF_TOKEN = os.getenv('HF_TOKEN')
27
+ if HF_TOKEN:
28
+ login(token=HF_TOKEN)
29
+ logger.info("Successfully logged in to Hugging Face Hub")
30
+ else:
31
+ logger.warning("HF_TOKEN not found in environment variables")
32
+ except Exception as e:
33
+ logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
34
+
35
+ class BaseAgent(ABC):
36
+ def __init__(self, name: str, model_id: str):
37
+ """Initialize base agent with common properties"""
38
+ self.name = name
39
+ self.model_id = model_id
40
+ self.logger = logging.getLogger(f"Agent.{name}")
41
+ self.logger.info(f"Initializing {name} with model {model_id}")
42
+
43
+ try:
44
+ self.model, self.tokenizer = self._load_model()
45
+ self.logger.info(f"Successfully loaded model and tokenizer for {name}")
46
+ except Exception as e:
47
+ self.logger.error(f"Failed to load model for {name}: {str(e)}")
48
+ raise
49
+
50
+ def _load_model(self):
51
+ """Load the specified model from Hugging Face"""
52
+ self.logger.debug(f"Loading model {self.model_id}")
53
+ try:
54
+ tokenizer = AutoTokenizer.from_pretrained(
55
+ self.model_id,
56
+ token=HF_TOKEN,
57
+ trust_remote_code=True
58
+ )
59
+ model = AutoModelForCausalLM.from_pretrained(
60
+ self.model_id,
61
+ token=HF_TOKEN,
62
+ torch_dtype=torch.float16,
63
+ device_map="auto",
64
+ trust_remote_code=True
65
+ )
66
+ return model, tokenizer
67
+ except Exception as e:
68
+ self.logger.error(f"Error loading model {self.model_id}: {str(e)}")
69
+ raise
70
+
71
+ def generate_response(self, prompt: str, max_length: int = 512) -> str:
72
+ """Generate response using the model"""
73
+ self.logger.debug(f"Generating response for prompt: {prompt[:100]}...")
74
+ try:
75
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
76
+ self.logger.debug("Input tokens created successfully")
77
+
78
+ outputs = self.model.generate(
79
+ **inputs,
80
+ max_length=max_length,
81
+ num_return_sequences=1,
82
+ temperature=0.7,
83
+ top_p=0.9,
84
+ do_sample=True,
85
+ pad_token_id=self.tokenizer.eos_token_id
86
+ )
87
+ self.logger.debug("Model generation completed")
88
+
89
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
90
+ response = response[len(prompt):].strip()
91
+ self.logger.debug(f"Generated response: {response[:100]}...")
92
+ return response
93
+
94
+ except Exception as e:
95
+ self.logger.error(f"Error generating response: {str(e)}")
96
+ return "I apologize, but I'm having trouble generating a response right now."
97
+
98
+ @abstractmethod
99
+ def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
100
+ """Process input and return response"""
101
+ pass
102
+
103
+ class TherapeuticAgent(BaseAgent):
104
+ def __init__(self):
105
+ super().__init__(
106
+ name="therapeutic_agent",
107
+ model_id="mistralai/Mistral-7B-Instruct-v0.2" # Using Mistral model
108
+ )
109
+ self.conversation_history = []
110
+ self.logger.info("Therapeutic agent initialized")
111
+
112
+ def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
113
+ """Process therapeutic conversations"""
114
+ self.logger.info("Processing therapeutic input")
115
+ self.logger.debug(f"Input data: {input_data}")
116
+
117
+ prompt = self._construct_therapeutic_prompt(input_data["text"])
118
+ response = self.generate_response(prompt)
119
+
120
+ # Update conversation history
121
+ self.conversation_history.append({
122
+ "timestamp": datetime.now().isoformat(),
123
+ "user": input_data["text"],
124
+ "agent": response
125
+ })
126
+
127
+ self.logger.info("Successfully processed therapeutic input")
128
+ self.logger.debug(f"Response: {response[:100]}...")
129
+
130
+ return {
131
+ "response": response,
132
+ "conversation_history": self.conversation_history
133
+ }
134
+
135
+ def _construct_therapeutic_prompt(self, user_input: str) -> str:
136
+ return f"""<s>[INST] You are a supportive and empathetic mental wellness assistant.
137
+ Your role is to provide caring, thoughtful responses while maintaining appropriate boundaries.
138
+ Always encourage professional help when needed.
139
+
140
+ User message: {user_input}
141
+
142
+ Provide a helpful and empathetic response: [/INST]"""
143
+
144
+ class MindfulnessAgent(BaseAgent):
145
+ def __init__(self):
146
+ super().__init__(
147
+ name="mindfulness_agent",
148
+ model_id="mistralai/Mistral-7B-Instruct-v0.2" # Using Mistral model
149
+ )
150
+ self.session_history = []
151
+ self.logger.info("Mindfulness agent initialized")
152
+
153
+ def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
154
+ """Process mindfulness-related requests"""
155
+ self.logger.info("Processing mindfulness input")
156
+ self.logger.debug(f"Input data: {input_data}")
157
+
158
+ prompt = self._construct_mindfulness_prompt(input_data["text"])
159
+ response = self.generate_response(prompt)
160
+
161
+ # Update session history
162
+ self.session_history.append({
163
+ "timestamp": datetime.now().isoformat(),
164
+ "user": input_data["text"],
165
+ "agent": response
166
+ })
167
+
168
+ self.logger.info("Successfully processed mindfulness input")
169
+ self.logger.debug(f"Response: {response[:100]}...")
170
+
171
+ return {
172
+ "response": response,
173
+ "session_history": self.session_history
174
+ }
175
+
176
+ def _construct_mindfulness_prompt(self, user_input: str) -> str:
177
+ return f"""<s>[INST] You are a mindfulness and meditation guide.
178
+ Your role is to provide calming guidance, meditation instructions, and mindfulness exercises.
179
+ Focus on present-moment awareness and gentle guidance.
180
+
181
+ User request: {user_input}
182
+
183
+ Provide mindfulness guidance: [/INST]"""
184
+
185
+ class WellnessApp:
186
+ def __init__(self):
187
+ self.logger = logging.getLogger("WellnessApp")
188
+ self.logger.info("Initializing Wellness App")
189
+
190
+ try:
191
+ self.therapeutic_agent = TherapeuticAgent()
192
+ self.mindfulness_agent = MindfulnessAgent()
193
+ self.logger.info("Successfully initialized all agents")
194
+ except Exception as e:
195
+ self.logger.error(f"Failed to initialize agents: {str(e)}")
196
+ raise
197
+
198
+ self.current_agent = "therapeutic" # Default agent
199
+
200
+ def switch_agent(self, agent_type: str) -> str:
201
+ """Switch between therapeutic and mindfulness agents"""
202
+ self.logger.info(f"Switching to {agent_type} agent")
203
+ self.current_agent = agent_type
204
+ return f"Switched to {agent_type} mode"
205
+
206
+ def respond(self, message: str, history: list) -> str:
207
+ """Process user message and return agent response"""
208
+ self.logger.info(f"Processing message with {self.current_agent} agent")
209
+ self.logger.debug(f"Message: {message}")
210
+
211
+ try:
212
+ if self.current_agent == "therapeutic":
213
+ response = self.therapeutic_agent.process({"text": message})
214
+ else:
215
+ response = self.mindfulness_agent.process({"text": message})
216
+
217
+ self.logger.info("Successfully generated response")
218
+ return response["response"]
219
+
220
+ except Exception as e:
221
+ self.logger.error(f"Error processing message: {str(e)}")
222
+ return "I apologize, but I'm having trouble processing your message right now."
223
+
224
+ def create_interface(self):
225
+ """Create Gradio interface"""
226
+ self.logger.info("Creating Gradio interface")
227
+
228
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
229
+ gr.Markdown("# Mental Wellness Assistant")
230
+
231
+ with gr.Row():
232
+ therapeutic_btn = gr.Button("Therapeutic Mode")
233
+ mindfulness_btn = gr.Button("Mindfulness Mode")
234
+
235
+ chatbot = gr.ChatInterface(
236
+ fn=self.respond,
237
+ examples=[
238
+ "I've been feeling anxious lately",
239
+ "Guide me through a breathing exercise",
240
+ "I need help managing stress",
241
+ "Can you teach me meditation?"
242
+ ],
243
+ title="",
244
+ )
245
+
246
+ therapeutic_btn.click(
247
+ fn=lambda: self.switch_agent("therapeutic"),
248
+ outputs=gr.Textbox(label="Status")
249
+ )
250
+ mindfulness_btn.click(
251
+ fn=lambda: self.switch_agent("mindfulness"),
252
+ outputs=gr.Textbox(label="Status")
253
+ )
254
+
255
+ gr.Markdown("""
256
+ ### Important Notice
257
+ This is a demo AI assistant and not a substitute for professional mental health care.
258
+ If you're experiencing a mental health crisis, please contact emergency services or a mental health professional.
259
+ """)
260
+
261
+ self.logger.info("Gradio interface created successfully")
262
+ return demo
263
+
264
+ # Create and launch the app
265
+ def main():
266
+ logger.info("Starting Wellness Assistant application")
267
+
268
+ try:
269
+ app = WellnessApp()
270
+ demo = app.create_interface()
271
+ logger.info("Application initialized successfully")
272
+
273
+ if __name__ == "__main__":
274
+ logger.info("Launching Gradio interface")
275
+ demo.launch()
276
+
277
+ except Exception as e:
278
+ logger.error(f"Failed to start application: {str(e)}")
279
+ raise
280
+
281
+ main()