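"""CrewAI tools for the mental health chatbot.

Provides tools for fetching Bhutanese helplines from PostgreSQL, crisis/suicidality
classification, mental condition classification via a Hugging Face Space, user
profile retrieval, and pgvector similarity search over article chunks in Supabase.
"""
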
import os
from typing import ClassVar

import psycopg2
from psycopg2.extras import RealDictCursor
from crewai.tools import BaseTool, tool
from transformers import pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from gradio_client import Client

from backend.crew_ai.data_retriever_util import get_user_profile
from backend.crew_ai.config import get_config

class MentalHealthTools:
    """Tools for mental health chatbot"""
    @tool("Bhutanese Helplines")
    def get_bhutanese_helplines() -> str:
        """

        Retrieves Bhutanese mental health helplines from the PostgreSQL `resources` table.



        """
        try:
            db_uri = os.getenv("SUPABASE_DB_URI")
            if not db_uri:
                raise ValueError("SUPABASE_DB_URI not set in environment")

            conn = psycopg2.connect(db_uri)
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            query = """

            SELECT name, description, phone, website, address, operation_hours

            FROM resources

            """
            cursor.execute(query)
            helplines = cursor.fetchall()

            if not helplines:
                return "No helplines found in the database."

            response = "πŸ“ž Bhutanese Mental Health Helplines:\n"
            for h in helplines:
                response += f"\nπŸ“Œ {h['name']}"
                if h['description']:
                    response += f"\n   Description: {h['description']}"
                if h['phone']:
                    response += f"\n   πŸ“± Phone: {h['phone']}"
                if h['website']:
                    response += f"\n   🌐 Website: {h['website']}"
                if h['address']:
                    response += f"\n   🏠 Address: {h['address']}"
                if h['operation_hours']:
                    response += f"\n   ⏰ Hours: {h['operation_hours']}"
                response += "\n"

            cursor.close()
            conn.close()
            return response.strip()

        except Exception as e:
            return f"⚠️ Failed to fetch helplines from DB: {str(e)}"
        

class CrisisClassifierTool(BaseTool):
    name: str = "Crisis Classifier"
    description: str = (
        "A tool that classifies text into predefined categories. "
        "Input should be the text to classify."
    )
            
    def _run(self, text: str) -> str:
        """

        Classifies the given text using the Hugging Face model.

        Returns the classification label and score.

        """
        try:
            # Initialize the pipeline here (will happen on every tool call)
            classifier = pipeline("sentiment-analysis", model="sentinet/suicidality")
            result = classifier(text)
            if result:
                label = result[0]['label']
                score = result[0]['score']
                return f"Classification: {label} (Score: {score:.4f})"
            return "Could not classify the text."
        except Exception as e:
            return f"Error during text classification: {e}"
        
class MentalConditionClassifierTool(BaseTool):
    name: str = "Mental condition Classifier"
    description: str = (
        "A tool that classifies text into predefined categories. "
        "Input should be the text to classify."
    )

    # Class-level cache for the client
    _client = None

    def _get_client(self):
        # Create the Gradio client once and cache it at class level
        if self._client is None:
            self.__class__._client = Client("ety89/mental_health_text_classifiaction")
        return self._client
            
    def _run(self, text: str) -> str:
        """
        Classifies the given text using the Hugging Face model.

        Returns the classification label and score.
        """
        try:
            # Reuse the cached Gradio client instead of reconnecting on every call
            client = self._get_client()
            result = client.predict(
                input_text=text,
                api_name="/predict"
            )
            if result:
                # Pull the label and score out of the Space's response string
                label = result.split(':')[-2].split('(')[-2].strip()
                score = result.split(':')[-1].strip(')').strip()
                return f"Classification: {label} (Score: {score})"

            return "Could not classify the text."

        except Exception as e:
            return f"Error during text classification: {e}"
        
class DataRetrievalTool(BaseTool):
    name: str = "Data Retrieval"
    description: str = (
        "A tool that fetched the user profile data from the database. "
        "Input should be User Profile ID."
    )

       
    def _run(self, user_profile_id: str) -> str:
        """

        Fetches the user profile data from the database using the user profile ID.

        Returns the user profile information or an error message.

        """
        try:

            config = get_config()   

            if user_profile_id.strip() == "anon_user":
                return config['default_user_profile']

            # Retrieve user profile using the utility function
            user_profile = get_user_profile(user_profile_id)
            if user_profile:
                return f"User Profile: {user_profile}"
            return "User profile not found."
        except Exception as e:
            return f"Error retrieving user profile: {e}"
        
class QueryVectorStoreTool(BaseTool):
    name: str = "Query Vector Store"
    description: str = (
        "Queries the Supabase-hosted PostgreSQL vector database with a user query and classified condition, "
        "and retrieves the top 3 most relevant documents."
    )

    # Shared across all instances
    embedding_model: ClassVar[HuggingFaceEmbeddings] = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    def _run(self, user_query: str, classified_condition: str) -> dict:
        query_text = f"{user_query} Condition: {classified_condition}"
        embedding = self.embedding_model.embed_query(query_text)

        db_uri = os.getenv("SUPABASE_DB_URI")
        if not db_uri:
            raise ValueError("SUPABASE_DB_URI not set in environment")

        conn = psycopg2.connect(db_uri)
        cursor = conn.cursor()

        cursor.execute("""

            SELECT ac.chunk_text, a.title, a.topic, a.source, ac.embedding <-> %s::vector AS score

            FROM article_chunks ac

            JOIN articles a ON ac.doc_id = a.id

            ORDER BY score

            LIMIT 3;

        """, (embedding,))


        rows = cursor.fetchall()
        docs = [
            {
                "text": row[0],
                "title": row[1],
                "topic": row[2],
                "source": row[3],
                "score": row[4]
            }
            for row in rows
        ]
        
        cursor.close()
        conn.close()

        return {"docs": docs}

    def _arun(self, *args, **kwargs):
        raise NotImplementedError("Async version not implemented")
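

# Local smoke test: an illustrative sketch only, not part of the tool definitions.
# It assumes SUPABASE_DB_URI is set and the backend.crew_ai package is importable;
# the example inputs below are made up for demonstration.
if __name__ == "__main__":
    crisis_tool = CrisisClassifierTool()
    print(crisis_tool._run("I have been feeling hopeless lately."))

    condition_tool = MentalConditionClassifierTool()
    print(condition_tool._run("I cannot sleep and I worry about everything."))

    retrieval_tool = QueryVectorStoreTool()
    print(retrieval_tool._run("How can I manage exam stress?", "anxiety"))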