|
import gradio as gr |
|
import logging |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from abc import ABC, abstractmethod |
|
from typing import Dict, Any |
|
from datetime import datetime |
|
import json |
|
import os |
|
from huggingface_hub import login |
|
|
|
|
|
# Application-wide logging: records at INFO and above go to both a log file
# and the stream handler (stderr).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII text in log records (e.g. user messages
        # logged at DEBUG) can be written regardless of the platform's default
        # file encoding.
        logging.FileHandler('wellness_assistant.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger("WellnessAssistant")
|
|
|
|
|
# Optional Hugging Face Hub authentication. HF_TOKEN stays module-level so the
# agents can pass it to from_pretrained when loading their models.
HF_TOKEN = os.environ.get('HF_TOKEN')

if not HF_TOKEN:
    logger.warning("HF_TOKEN not found in environment variables")
else:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        logger.error(f"Failed to login to Hugging Face Hub: {str(e)}")
    else:
        logger.info("Successfully logged in to Hugging Face Hub")
|
|
|
class BaseAgent(ABC):
    """Common base for agents backed by a Hugging Face causal language model.

    On construction it loads (or reuses) the model and tokenizer for
    ``model_id``. It exposes :meth:`generate_response` for sampled text
    generation. Subclasses implement :meth:`process`.
    """

    # (model, tokenizer) pairs keyed by model id, shared by all agents. Several
    # agents may use the same checkpoint; reusing one loaded copy avoids
    # holding duplicate weights in memory.
    _model_cache: Dict[str, Any] = {}

    def __init__(self, name: str, model_id: str):
        """Initialize base agent with common properties.

        Args:
            name: Agent name; also used in the logger name ``Agent.<name>``.
            model_id: Hugging Face Hub model id to load.

        Raises:
            Exception: Re-raises any error raised while loading the model or
                tokenizer.
        """
        self.name = name
        self.model_id = model_id
        self.logger = logging.getLogger(f"Agent.{name}")
        self.logger.info(f"Initializing {name} with model {model_id}")

        try:
            self.model, self.tokenizer = self._load_model()
            self.logger.info(f"Successfully loaded model and tokenizer for {name}")
        except Exception as e:
            self.logger.error(f"Failed to load model for {name}: {str(e)}")
            raise

    def _load_model(self):
        """Load the model and tokenizer for ``self.model_id``, reusing a cached copy.

        Returns:
            A ``(model, tokenizer)`` tuple.

        Raises:
            Exception: Re-raises any error from ``from_pretrained``.
        """
        cached = BaseAgent._model_cache.get(self.model_id)
        if cached is not None:
            self.logger.debug(f"Reusing already-loaded model {self.model_id}")
            return cached

        self.logger.debug(f"Loading model {self.model_id}")
        try:
            # NOTE(review): trust_remote_code=True executes Python code shipped
            # in the model repository. If the configured models are supported
            # natively by transformers, consider disabling it.
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                token=HF_TOKEN,
                trust_remote_code=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                token=HF_TOKEN,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        except Exception as e:
            self.logger.error(f"Error loading model {self.model_id}: {str(e)}")
            raise

        BaseAgent._model_cache[self.model_id] = (model, tokenizer)
        return model, tokenizer

    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        """Generate a sampled completion for ``prompt``.

        Args:
            prompt: Full prompt text. If it already starts with the tokenizer's
                BOS token, the tokenizer is told not to add another one.
            max_length: Maximum total sequence length in tokens (prompt plus
                completion), passed to ``model.generate``.

        Returns:
            The completion only (prompt tokens excluded), with surrounding
            whitespace stripped. If anything fails, a fixed apology message
            is returned instead.
        """
        self.logger.debug(f"Generating response for prompt: {prompt[:100]}...")
        try:
            # A prompt may embed the BOS marker itself (e.g. "<s>[INST] ...").
            # Letting the tokenizer prepend one too would give a double BOS.
            bos = getattr(self.tokenizer, "bos_token", None)
            add_special_tokens = not (bos and prompt.startswith(bos))
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            ).to(self.model.device)
            self.logger.debug("Input tokens created successfully")

            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            self.logger.debug("Model generation completed")

            # generate() returns prompt + completion ids. Decode only the new
            # ids. Slicing the decoded text by len(prompt) is unreliable:
            # skip_special_tokens drops markers such as "<s>", and
            # detokenization need not reproduce the prompt text exactly.
            prompt_length = inputs["input_ids"].shape[-1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            self.logger.debug(f"Generated response: {response[:100]}...")
            return response

        except Exception as e:
            self.logger.exception(f"Error generating response: {str(e)}")
            return "I apologize, but I'm having trouble generating a response right now."

    @abstractmethod
    def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Process input and return response"""
        pass
|
|
|
class TherapeuticAgent(BaseAgent):
    """Agent that gives supportive, empathetic replies to wellness messages."""

    def __init__(self):
        """Load the agent's model and start with an empty conversation history."""
        super().__init__(
            name="therapeutic_agent",
            model_id="mistralai/Mistral-7B-Instruct-v0.2"
        )
        # One entry per processed message: timestamp, user text, agent reply.
        # NOTE(review): this list is never trimmed and is shared by every caller
        # of this agent instance; consider bounding it or keeping it per session.
        self.conversation_history = []
        self.logger.info("Therapeutic agent initialized")

    def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a therapeutic reply and record the exchange.

        The recorded history is not included in the prompt, so each reply is
        based only on the current message.

        Args:
            input_data: Must contain the user's message under ``"text"``.

        Returns:
            ``{"response": <reply text>, "conversation_history": <history list>}``.
            The history is this agent's live list, not a copy.

        Raises:
            KeyError: If ``input_data`` has no ``"text"`` key.
        """
        self.logger.info("Processing therapeutic input")
        self.logger.debug(f"Input data: {input_data}")

        prompt = self._construct_therapeutic_prompt(input_data["text"])
        response = self.generate_response(prompt)

        self.conversation_history.append({
            "timestamp": datetime.now().isoformat(),
            "user": input_data["text"],
            "agent": response
        })

        self.logger.info("Successfully processed therapeutic input")
        self.logger.debug(f"Response: {response[:100]}...")

        return {
            "response": response,
            "conversation_history": self.conversation_history
        }

    def _construct_therapeutic_prompt(self, user_input: str) -> str:
        """Wrap ``user_input`` in an [INST] instruction block with the therapist persona."""
        # No literal "<s>" here: the tokenizer adds the BOS token itself when
        # encoding (Mistral tokenizer default), so embedding it would produce a
        # duplicated BOS at the start of the sequence.
        return f"""[INST] You are a supportive and empathetic mental wellness assistant.
Your role is to provide caring, thoughtful responses while maintaining appropriate boundaries.
Always encourage professional help when needed.

User message: {user_input}

Provide a helpful and empathetic response: [/INST]"""
|
|
|
class MindfulnessAgent(BaseAgent):
    """Agent that provides meditation instructions and mindfulness guidance."""

    def __init__(self):
        """Load the agent's model and start with an empty session history."""
        super().__init__(
            name="mindfulness_agent",
            model_id="mistralai/Mistral-7B-Instruct-v0.2"
        )
        # One entry per processed message: timestamp, user text, agent reply.
        # NOTE(review): this list is never trimmed and is shared by every caller
        # of this agent instance; consider bounding it or keeping it per session.
        self.session_history = []
        self.logger.info("Mindfulness agent initialized")

    def process(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate mindfulness guidance and record the exchange.

        The recorded history is not included in the prompt, so each reply is
        based only on the current request.

        Args:
            input_data: Must contain the user's request under ``"text"``.

        Returns:
            ``{"response": <reply text>, "session_history": <history list>}``.
            The history is this agent's live list, not a copy.

        Raises:
            KeyError: If ``input_data`` has no ``"text"`` key.
        """
        self.logger.info("Processing mindfulness input")
        self.logger.debug(f"Input data: {input_data}")

        prompt = self._construct_mindfulness_prompt(input_data["text"])
        response = self.generate_response(prompt)

        self.session_history.append({
            "timestamp": datetime.now().isoformat(),
            "user": input_data["text"],
            "agent": response
        })

        self.logger.info("Successfully processed mindfulness input")
        self.logger.debug(f"Response: {response[:100]}...")

        return {
            "response": response,
            "session_history": self.session_history
        }

    def _construct_mindfulness_prompt(self, user_input: str) -> str:
        """Wrap ``user_input`` in an [INST] instruction block with the guide persona."""
        # No literal "<s>" here: the tokenizer adds the BOS token itself when
        # encoding (Mistral tokenizer default), so embedding it would produce a
        # duplicated BOS at the start of the sequence.
        return f"""[INST] You are a mindfulness and meditation guide.
Your role is to provide calming guidance, meditation instructions, and mindfulness exercises.
Focus on present-moment awareness and gentle guidance.

User request: {user_input}

Provide mindfulness guidance: [/INST]"""
|
|
|
class WellnessApp:
    """Gradio chat front end that routes messages to a therapeutic or a
    mindfulness agent, depending on the currently selected mode."""

    # Modes accepted by switch_agent.
    _AGENT_TYPES = ("therapeutic", "mindfulness")

    def __init__(self):
        """Create both agents and start in therapeutic mode.

        Raises:
            Exception: Re-raises any error raised while constructing an agent.
        """
        self.logger = logging.getLogger("WellnessApp")
        self.logger.info("Initializing Wellness App")

        try:
            self.therapeutic_agent = TherapeuticAgent()
            self.mindfulness_agent = MindfulnessAgent()
            self.logger.info("Successfully initialized all agents")
        except Exception as e:
            self.logger.error(f"Failed to initialize agents: {str(e)}")
            raise

        # NOTE(review): the mode lives on this single app instance. In a
        # multi-user deployment, one visitor's mode switch therefore changes
        # the mode for every session. Consider per-session state (e.g. gr.State).
        self.current_agent = "therapeutic"

    def switch_agent(self, agent_type: str) -> str:
        """Switch between therapeutic and mindfulness agents.

        Args:
            agent_type: ``"therapeutic"`` or ``"mindfulness"``.

        Returns:
            A status message for display. An unknown ``agent_type`` leaves the
            current mode unchanged, and the message says so.
        """
        if agent_type not in self._AGENT_TYPES:
            self.logger.warning(f"Ignoring switch to unknown agent type: {agent_type}")
            return f"Unknown mode '{agent_type}'; staying in {self.current_agent} mode"

        self.logger.info(f"Switching to {agent_type} agent")
        self.current_agent = agent_type
        return f"Switched to {agent_type} mode"

    def respond(self, message: str, history: list) -> str:
        """Route ``message`` to the active agent and return its reply.

        Args:
            message: The user's chat message.
            history: Chat history supplied by gr.ChatInterface. It is not used;
                the agent sees only the current message.

        Returns:
            The agent's reply text, or a fixed apology message on error.
        """
        self.logger.info(f"Processing message with {self.current_agent} agent")
        self.logger.debug(f"Message: {message}")

        try:
            if self.current_agent == "therapeutic":
                agent = self.therapeutic_agent
            else:
                agent = self.mindfulness_agent
            response = agent.process({"text": message})

            self.logger.info("Successfully generated response")
            return response["response"]

        except Exception as e:
            self.logger.exception(f"Error processing message: {str(e)}")
            return "I apologize, but I'm having trouble processing your message right now."

    def create_interface(self):
        """Build the Gradio UI: mode buttons, chat, status box and disclaimer.

        Returns:
            The ``gr.Blocks`` demo. It is not launched here.
        """
        self.logger.info("Creating Gradio interface")

        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown("# Mental Wellness Assistant")

            with gr.Row():
                therapeutic_btn = gr.Button("Therapeutic Mode")
                mindfulness_btn = gr.Button("Mindfulness Mode")

            gr.ChatInterface(
                fn=self.respond,
                examples=[
                    "I've been feeling anxious lately",
                    "Guide me through a breathing exercise",
                    "I need help managing stress",
                    "Can you teach me meditation?"
                ],
                title="",
            )

            # One status box shared by both buttons. Constructing a new
            # gr.Textbox inline as each click()'s `outputs` would render a
            # separate box per button.
            status = gr.Textbox(label="Status")

            therapeutic_btn.click(
                fn=lambda: self.switch_agent("therapeutic"),
                outputs=status,
            )
            mindfulness_btn.click(
                fn=lambda: self.switch_agent("mindfulness"),
                outputs=status,
            )

            gr.Markdown("""
### Important Notice
This is a demo AI assistant and not a substitute for professional mental health care.
If you're experiencing a mental health crisis, please contact emergency services or a mental health professional.
""")

        self.logger.info("Gradio interface created successfully")
        return demo
|
|
|
|
|
def main():
    """Build the Wellness Assistant app and, when run as a script, launch it.

    Returns:
        The constructed Gradio Blocks demo.

    Raises:
        Exception: Re-raises any error raised while building or launching.
    """
    logger.info("Starting Wellness Assistant application")

    try:
        app = WellnessApp()
        demo = app.create_interface()
        logger.info("Application initialized successfully")

        # Launch only when this module is the entry script; a caller importing
        # main() gets the built demo back without it being launched.
        if __name__ == "__main__":
            logger.info("Launching Gradio interface")
            demo.launch()

    except Exception as e:
        logger.exception(f"Failed to start application: {str(e)}")
        raise

    return demo


# Guarded so that merely importing this module does not construct WellnessApp,
# which loads the models as a side effect.
if __name__ == "__main__":
    main()