import streamlit as st

# 1. Set page config FIRST
st.set_page_config(
    page_title="TARS: Therapist Assistance and Response System", 
    page_icon="๐Ÿง "
)
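# Streamlit requires set_page_config() to be the first Streamlit command in the script,
# so keep it ahead of any other st.* calls.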

import os
import torch
from transformers import pipeline

# Canadian Crisis Resources (example)
CRISIS_RESOURCES = {
    "Canada Suicide Prevention Service": "1-833-456-4566",
    "Crisis Services Canada": "1-833-456-4566",
    "Kids Help Phone": "1-800-668-6868",
    "First Nations and Inuit Hope for Wellness Help Line": "1-855-242-3310"
}

# Keywords for detecting potential self-harm or suicidal language
SUICIDE_KEYWORDS = [
    "suicide", "kill myself", "end my life", 
    "want to die", "hopeless", "no way out", 
    "better off dead", "pain is too much"
]
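
# detect_crisis() below matches these phrases as plain substrings. A word-boundary
# check is one possible refinement (illustrative sketch only, not wired into the app):
#
#   import re
#   def contains_keyword(text: str) -> bool:
#       return any(re.search(rf"\b{re.escape(k)}\b", text.lower()) for k in SUICIDE_KEYWORDS)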

class AITherapistAssistant:
    def __init__(self, conversation_model_name="microsoft/phi-1_5", summary_model_name="facebook/bart-large-cnn"):
        """
        Initialize the conversation (LLM) model and the summarization model.
        
        If you truly have a different 'phi2' from Microsoft, replace 'microsoft/phi-1_5'
        with your private or custom Hugging Face repo name.
        """
        # Load conversation LLM (phi2 / phi-1_5)
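        # device=0 selects the first CUDA GPU for the Hugging Face pipeline; -1 falls back to CPU.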
        try:
            self.conversation_model = pipeline(
                "text-generation",
                model=conversation_model_name,
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            st.error(f"Error loading conversation model: {e}")
            self.conversation_model = None

        # Load summarization model (BART Large CNN as default)
        try:
            self.summary_model = pipeline(
                "summarization", 
                model=summary_model_name, 
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            st.error(f"Error loading summary model: {e}")
            self.summary_model = None
    
    def detect_crisis(self, message: str) -> bool:
        """Check if message contains suicidal or distress-related keywords."""
        message_lower = message.lower()
        return any(keyword in message_lower for keyword in SUICIDE_KEYWORDS)
    
    def generate_response(self, message: str) -> str:
        """Generate a supportive AI response from the conversation model."""
        if not self.conversation_model:
            return (
                "I'm having some technical difficulties right now. "
                "I'm truly sorry I can't provide support at the moment. "
                "Would you consider reaching out to a human counselor or support helpline?"
            )
        
        # Create a more structured and empathetic prompt
        prompt = (
            "You are a compassionate, empathetic AI therapist trained to provide supportive, "
            "non-judgmental responses. Your goal is to validate feelings, offer gentle support, "
            "and help the user feel heard and understood.\n\n"
            "User's message: {}\n\n"
            "Your compassionate response:".format(message)
        )
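        # Note: phi-1_5 is a small base model (not chat- or instruction-tuned), so it may not
        # follow this prompt closely; the post-processing below trims the output and falls back
        # to a canned reply when the continuation is too short.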
        
        try:
            outputs = self.conversation_model(
                prompt,
                max_new_tokens=200,  # cap the generated continuation; prompt tokens are not counted against this
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9,
                temperature=0.7
            )
            
            response_text = outputs[0]["generated_text"]
            
            # More robust prompt stripping
            if response_text.startswith(prompt):
                response_text = response_text[len(prompt):].strip()
            
            # Additional processing to ensure response quality
            response_text = response_text.split('\n')[0].strip()  # Keep only the first line of the continuation
            
            # Fallback if response is too short or nonsensical
            if len(response_text) < 20:
                response_text = (
                    "I hear you. Your feelings are valid, and it takes courage to share what you're experiencing. "
                    "Would you like to tell me a bit more about what's on your mind?"
                )
            
            return response_text
        except Exception as e:
            st.error(f"Error generating response: {e}")
            return (
                "I'm experiencing some difficulties right now. "
                "Your feelings are important, and I want to be fully present for you. "
                "Would you be willing to try sharing again?"
            )
        
    def generate_summary(self, conversation_text: str) -> str:
        """Generate a summary of the entire conversation for the user's feeling. Cite from their input how they felt."""
        if not self.summary_model:
            return "Summary model is unavailable at the moment."
        
        try:
            summary_output = self.summary_model(
                conversation_text,
                max_length=130,
                min_length=30,
                do_sample=False,
                truncation=True  # BART accepts roughly 1024 input tokens; truncate longer conversations
            )
            return summary_output[0]["summary_text"]
        except Exception as e:
            st.error(f"Error generating summary: {e}")
            return "Sorry, I couldn't generate a summary."

def main():
    st.title("🧠 TARS: Therapist Assistance and Response System")
    st.write(
        "A supportive space to share your feelings safely.\n\n"
        "**Disclaimer**: I am not a licensed therapist. If you're in crisis, "
        "please reach out to professional help immediately."
    )
    
    # Note if running on Hugging Face Spaces
    if os.environ.get("SPACE_ID"):
        st.info("Running on Hugging Face Spaces.")
    
    # Instantiate the assistant with phi-1_5 (or your custom 'phi2')
    if "assistant" not in st.session_state:
        st.session_state.assistant = AITherapistAssistant(
            conversation_model_name="microsoft/phi-1_5",  # replace if needed
            summary_model_name="facebook/bart-large-cnn"
        )
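        # Note: st.session_state keeps one model instance per browser session. To share the
        # loaded pipelines across all sessions, one option (not wired in here) is a factory
        # function decorated with @st.cache_resource that returns the assistant.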
    
    # Keep track of conversation
    if "conversation" not in st.session_state:
        st.session_state.conversation = []
    
    # Display existing conversation
    for message in st.session_state.conversation:
        if message["sender"] == "user":
            st.chat_message("user").write(message["text"])
        else:
            st.chat_message("assistant").write(message["text"])
    
    # Collect user input
    if prompt := st.chat_input("How are you feeling today?"):
        # Crisis detection
        if st.session_state.assistant.detect_crisis(prompt):
            st.warning("⚠️ Potential crisis detected.")
            st.markdown("**Immediate Support Resources (Canada):**")
            for org, phone in CRISIS_RESOURCES.items():
                st.markdown(f"- {org}: `{phone}`")
        
        # Display user message
        st.session_state.conversation.append({"sender": "user", "text": prompt})
        st.chat_message("user").write(prompt)
        
        # Generate AI response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                ai_response = st.session_state.assistant.generate_response(prompt)
                st.write(ai_response)
        st.session_state.conversation.append({"sender": "assistant", "text": ai_response})
    
    # Summarize conversation
    if st.button("Generate Session Summary"):
        if st.session_state.conversation:
            conversation_text = " ".join(msg["text"] for msg in st.session_state.conversation)
            summary = st.session_state.assistant.generate_summary(conversation_text)
            st.subheader("Session Summary")
            st.write(summary)
        else:
            st.info("No conversation to summarize yet.")
    
    # Crisis Support Info in Sidebar
    st.sidebar.title("🆘 Crisis Support")
    st.sidebar.markdown("If you're in crisis, please contact:")
    for org, phone in CRISIS_RESOURCES.items():
        st.sidebar.markdown(f"- **{org}**: `{phone}`")

if __name__ == "__main__":
    main()