# NOTE: the next line is web-page residue captured with this file, not program text.
# Spaces: Sleeping
import streamlit as st

# 1. Set page config FIRST.
# The page configuration call is deliberately kept ahead of the remaining
# imports so that it runs before any other Streamlit command in this script.
st.set_page_config(
    page_title="TARS: Therapist Assistance and Response System",
    page_icon="🧠",
)

import os

import torch
from transformers import pipeline
# Canadian crisis resources (example list).
# Maps organization name -> phone number. Insertion order is the order in
# which main() renders the entries in the crisis warning and in the sidebar.
# NOTE(review): phone numbers change over time; re-verify every entry against
# the organizations' official pages before each release.
CRISIS_RESOURCES = {
    # National three-digit suicide crisis line (call or text), listed first
    # because it is the easiest number to remember in an emergency.
    "9-8-8 Suicide Crisis Helpline": "9-8-8",
    "Canada Suicide Prevention Service": "1-833-456-4566",
    "Crisis Services Canada": "1-833-456-4566",
    "Kids Help Phone": "1-800-668-6868",
    "First Nations and Inuit Hope for Wellness Help Line": "1-855-242-3310",
}
# Keywords for detecting potential self-harm or suicidal language.
# Entries are lowercase because AITherapistAssistant.detect_crisis lowercases
# the message and then does a plain substring test against each entry.
SUICIDE_KEYWORDS = [
    "suicide",
    "kill myself",
    "end my life",
    "want to die",
    "hopeless",
    "no way out",
    "better off dead",
    "pain is too much",
]
class AITherapistAssistant:
    """Wrap a text-generation model and a summarization model for the chat UI.

    Attributes:
        conversation_model: ``text-generation`` pipeline, or ``None`` when
            loading failed (the failure is reported through ``st.error``).
        summary_model: ``summarization`` pipeline, or ``None`` when loading
            failed.
    """

    def __init__(self, conversation_model_name="microsoft/phi-1_5", summary_model_name="facebook/bart-large-cnn"):
        """
        Initialize the conversation (LLM) model and the summarization model.

        If you truly have a different 'phi2' from Microsoft, replace
        'microsoft/phi-1_5' with your private or custom Hugging Face repo name.

        Args:
            conversation_model_name: Model id passed to the text-generation
                pipeline.
            summary_model_name: Model id passed to the summarization pipeline
                (BART Large CNN by default).
        """
        self.conversation_model = self._load_pipeline(
            "text-generation", conversation_model_name, "conversation"
        )
        self.summary_model = self._load_pipeline(
            "summarization", summary_model_name, "summary"
        )

    @staticmethod
    def _load_pipeline(task, model_name, label):
        """Build a pipeline for *task*; on failure show an error and return None."""
        try:
            return pipeline(
                task,
                model=model_name,
                # First CUDA device when available, otherwise CPU (-1).
                device=0 if torch.cuda.is_available() else -1,
            )
        except Exception as e:
            # Deliberately broad: a model that cannot load must degrade the
            # app to its fallback messages rather than crash the page.
            st.error(f"Error loading {label} model: {e}")
            return None

    def detect_crisis(self, message: str) -> bool:
        """Check if message contains suicidal or distress-related keywords.

        The match is a case-insensitive substring test against
        ``SUICIDE_KEYWORDS``. Empty or ``None`` input is never a match.
        """
        if not message:
            return False
        message_lower = message.lower()
        return any(keyword in message_lower for keyword in SUICIDE_KEYWORDS)

    def generate_response(self, message: str) -> str:
        """Generate a supportive AI response from the conversation model.

        Args:
            message: The user's latest chat message.

        Returns:
            The first non-empty line produced by the model, or a canned
            supportive message when the model is unavailable, raises, or
            produces fewer than 20 characters.
        """
        if not self.conversation_model:
            return (
                "I'm having some technical difficulties right now. "
                "I'm truly sorry I can't provide support at the moment. "
                "Would you consider reaching out to a human counselor or support helpline?"
            )

        # Structured, empathetic prompt; the model continues after the final colon.
        prompt = (
            "You are a compassionate, empathetic AI therapist trained to provide supportive, "
            "non-judgmental responses. Your goal is to validate feelings, offer gentle support, "
            "and help the user feel heard and understood.\n\n"
            f"User's message: {message}\n\n"
            "Your compassionate response:"
        )

        try:
            outputs = self.conversation_model(
                prompt,
                # Budget only the generated tokens. A total-length cap would
                # also count the prompt, so a long user message could leave
                # no room for any reply.
                max_new_tokens=200,
                # Ask for the continuation only, so the instructions above
                # can never be echoed back to the user as the "response".
                return_full_text=False,
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
            )
            response_text = outputs[0]["generated_text"]

            # Defensive: strip the prompt if the backend echoed it anyway.
            if response_text.startswith(prompt):
                response_text = response_text[len(prompt):]

            # Keep only the first line of the continuation; later lines tend
            # to be the model inventing further dialogue turns.
            response_text = response_text.strip().split("\n")[0].strip()

            # Fallback if response is too short or nonsensical.
            if len(response_text) < 20:
                response_text = (
                    "I hear you. Your feelings are valid, and it takes courage to share what you're experiencing. "
                    "Would you like to tell me a bit more about what's on your mind?"
                )
            return response_text
        except Exception as e:
            # Broad on purpose: any inference failure becomes a gentle
            # message instead of a stack trace in a support conversation.
            st.error(f"Error generating response: {e}")
            return (
                "I'm experiencing some difficulties right now. "
                "Your feelings are important, and I want to be fully present for you. "
                "Would you be willing to try sharing again?"
            )

    def generate_summary(self, conversation_text: str) -> str:
        """Generate a summary of the entire conversation.

        Args:
            conversation_text: All conversation messages joined into one string.

        Returns:
            The model's summary text, or a short notice when the summary
            model is unavailable or summarization fails.
        """
        if not self.summary_model:
            return "Summary model is unavailable at the moment."
        try:
            summary_output = self.summary_model(
                conversation_text,
                max_length=130,
                min_length=30,
                do_sample=False,
                # Long sessions can exceed the summarizer's input limit;
                # truncate the input instead of failing the request.
                truncation=True,
            )
            return summary_output[0]["summary_text"]
        except Exception as e:
            st.error(f"Error generating summary: {e}")
            return "Sorry, I couldn't generate a summary."
def main():
    """Render the TARS chat page.

    Streamlit re-runs this function on every interaction. Each run:
    draws the header and disclaimer, creates the assistant and the
    conversation list in ``st.session_state`` if they are missing, replays
    the stored conversation, handles one new chat message (crisis check,
    model reply), offers an on-demand session summary, and fills the
    sidebar with crisis contacts.
    """
    st.title("🧠 TARS: Therapist Assistance and Response System")
    st.write(
        "A supportive space to share your feelings safely.\n\n"
        "**Disclaimer**: I am not a licensed therapist. If you're in crisis, "
        "please reach out to professional help immediately."
    )

    # Note if running on Hugging Face Spaces.
    if os.environ.get("SPACE_ID"):
        st.info("Running on Hugging Face Spaces.")

    # Instantiate the assistant once per browser session; the models are
    # loaded in its constructor, so this must not repeat on every rerun.
    if "assistant" not in st.session_state:
        st.session_state.assistant = AITherapistAssistant(
            conversation_model_name="microsoft/phi-1_5",  # replace if needed
            summary_model_name="facebook/bart-large-cnn",
        )

    # Conversation history: list of {"sender": ..., "text": ...} dicts.
    if "conversation" not in st.session_state:
        st.session_state.conversation = []

    # Replay the existing conversation; any sender other than "user" is
    # drawn as the assistant.
    for message in st.session_state.conversation:
        role = "user" if message["sender"] == "user" else "assistant"
        st.chat_message(role).write(message["text"])

    # Collect user input.
    if prompt := st.chat_input("How are you feeling today?"):
        # Crisis detection: surface hotline numbers before anything else.
        # NOTE(review): the model reply is still generated after a crisis
        # match; consider whether a fixed, vetted response is safer here.
        if st.session_state.assistant.detect_crisis(prompt):
            st.warning("⚠️ Potential crisis detected.")
            st.markdown("**Immediate Support Resources (Canada):**")
            for org, phone in CRISIS_RESOURCES.items():
                st.markdown(f"- {org}: `{phone}`")

        # Record and display the user message.
        st.session_state.conversation.append({"sender": "user", "text": prompt})
        st.chat_message("user").write(prompt)

        # Generate and display the AI response.
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                ai_response = st.session_state.assistant.generate_response(prompt)
            st.write(ai_response)
        st.session_state.conversation.append({"sender": "assistant", "text": ai_response})

    # Summarize the conversation on demand.
    if st.button("Generate Session Summary"):
        if st.session_state.conversation:
            conversation_text = " ".join(msg["text"] for msg in st.session_state.conversation)
            summary = st.session_state.assistant.generate_summary(conversation_text)
            st.subheader("Session Summary")
            st.write(summary)
        else:
            st.info("No conversation to summarize yet.")

    # Crisis support info in the sidebar, visible on every run.
    st.sidebar.title("🆘 Crisis Support")
    st.sidebar.markdown("If you're in crisis, please contact:")
    for org, phone in CRISIS_RESOURCES.items():
        st.sidebar.markdown(f"- **{org}**: `{phone}`")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()