# TARS_Demo / app.py
# Uploaded by Divymakesml ("Update app.py", commit a19823a, verified).
# NOTE: this header is Hugging Face Hub page residue, commented out so the file parses.
import streamlit as st
# 1. Set page config FIRST — Streamlit requires st.set_page_config to be the
# first Streamlit command executed, so it runs before any other st.* call.
st.set_page_config(
    page_title="TARS: Therapist Assistance and Response System",
    page_icon="🧠"  # restored from mojibake (UTF-8 emoji bytes mis-decoded as Thai cp874)
)
import os
import torch
from transformers import pipeline
# Canadian Crisis Resources (example)
# Maps organization name -> phone number; rendered in the sidebar and in the
# inline crisis banner in main(). NOTE(review): the first two entries share
# the same number — confirm both hotline listings are still current.
CRISIS_RESOURCES = {
    "Canada Suicide Prevention Service": "1-833-456-4566",
    "Crisis Services Canada": "1-833-456-4566",
    "Kids Help Phone": "1-800-668-6868",
    "First Nations and Inuit Hope for Wellness Help Line": "1-855-242-3310"
}

# Keywords for detecting potential self-harm or suicidal language
# Matched as lowercase substrings by AITherapistAssistant.detect_crisis; a
# match only adds crisis resources to the reply, it never blocks the chat.
SUICIDE_KEYWORDS = [
    "suicide", "kill myself", "end my life",
    "want to die", "hopeless", "no way out",
    "better off dead", "pain is too much"
]
class AITherapistAssistant:
    """Conversational support assistant backed by two Hugging Face pipelines.

    A text-generation model produces empathetic replies and a summarization
    model condenses the session transcript. Model-loading failures are
    surfaced through the Streamlit UI and degrade to canned fallback
    messages rather than raising.
    """

    def __init__(self, conversation_model_name="microsoft/phi-1_5", summary_model_name="facebook/bart-large-cnn"):
        """
        Initialize the conversation (LLM) model and the summarization model.
        If you truly have a different 'phi2' from Microsoft, replace 'microsoft/phi-1_5'
        with your private or custom Hugging Face repo name.
        """
        # Both pipelines follow the same load-or-report pattern; the shared
        # logic lives in _load_pipeline, and a failed load leaves None behind
        # so the public methods can fall back gracefully.
        self.conversation_model = self._load_pipeline(
            "text-generation", conversation_model_name, "conversation"
        )
        self.summary_model = self._load_pipeline(
            "summarization", summary_model_name, "summary"
        )

    @staticmethod
    def _load_pipeline(task, model_name, label):
        """Return a transformers pipeline for *task*, or None on failure.

        *label* is the human-readable name ("conversation"/"summary") used in
        the error message shown to the user when loading fails.
        """
        try:
            return pipeline(
                task,
                model=model_name,
                # First GPU when available, otherwise CPU (-1).
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            st.error(f"Error loading {label} model: {e}")
            return None

    def detect_crisis(self, message: str) -> bool:
        """Check if message contains suicidal or distress-related keywords.

        Case-insensitive substring matching against SUICIDE_KEYWORDS.
        """
        message_lower = message.lower()
        return any(keyword in message_lower for keyword in SUICIDE_KEYWORDS)

    def generate_response(self, message: str) -> str:
        """Generate a supportive AI response from the conversation model.

        Returns a canned apology when the model failed to load or generation
        raises; otherwise returns the model's continuation with the prompt
        stripped off and trimmed to its first line.
        """
        if not self.conversation_model:
            return (
                "I'm having some technical difficulties right now. "
                "I'm truly sorry I can't provide support at the moment. "
                "Would you consider reaching out to a human counselor or support helpline?"
            )
        # Create a more structured and empathetic prompt
        prompt = (
            "You are a compassionate, empathetic AI therapist trained to provide supportive, "
            "non-judgmental responses. Your goal is to validate feelings, offer gentle support, "
            "and help the user feel heard and understood.\n\n"
            "User's message: {}\n\n"
            "Your compassionate response:".format(message)
        )
        try:
            outputs = self.conversation_model(
                prompt,
                # max_new_tokens (not max_length) so the 300-token budget
                # applies to the generated reply alone: max_length counts the
                # prompt's tokens too, which squeezed — or errored on — long
                # user messages.
                max_new_tokens=300,
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9,
                temperature=0.7
            )
            response_text = outputs[0]["generated_text"]
            # text-generation pipelines echo the prompt; strip it when present.
            if response_text.startswith(prompt):
                response_text = response_text[len(prompt):].strip()
            # Keep only the first line — small models tend to start writing
            # additional dialogue turns after a newline.
            response_text = response_text.split('\n')[0].strip()
            # Fallback if response is too short or nonsensical
            if len(response_text) < 20:
                response_text = (
                    "I hear you. Your feelings are valid, and it takes courage to share what you're experiencing. "
                    "Would you like to tell me a bit more about what's on your mind?"
                )
            return response_text
        except Exception as e:
            st.error(f"Error generating response: {e}")
            return (
                "I'm experiencing some difficulties right now. "
                "Your feelings are important, and I want to be fully present for you. "
                "Would you be willing to try sharing again?"
            )

    def generate_summary(self, conversation_text: str) -> str:
        """Generate a summary of the entire conversation for the user's feeling. Cite from their input how they felt."""
        if not self.summary_model:
            return "Summary model is unavailable at the moment."
        try:
            summary_output = self.summary_model(
                conversation_text,
                max_length=130,
                min_length=30,
                do_sample=False,
                # Long sessions can exceed the summarizer's encoder window
                # (~1024 tokens for BART); truncate instead of letting the
                # pipeline raise on oversized input.
                truncation=True
            )
            return summary_output[0]["summary_text"]
        except Exception as e:
            st.error(f"Error generating summary: {e}")
            return "Sorry, I couldn't generate a summary."
def main():
    """Render the TARS Streamlit UI.

    Streamlit renders widgets top-to-bottom as the script executes, so the
    order here is the page layout: title/disclaimer, chat history, chat
    input + response, summary button, then the crisis sidebar.

    The three emoji below are restored from mojibake (UTF-8 emoji bytes that
    had been mis-decoded as Thai cp874).
    """
    st.title("🧠 TARS: Therapist Assistance and Response System")
    st.write(
        "A supportive space to share your feelings safely.\n\n"
        "**Disclaimer**: I am not a licensed therapist. If you're in crisis, "
        "please reach out to professional help immediately."
    )

    # Note if running on Hugging Face Spaces
    if os.environ.get("SPACE_ID"):
        st.info("Running on Hugging Face Spaces.")

    # Instantiate the assistant once per session — Streamlit reruns the whole
    # script on every interaction, so the models must live in session_state.
    if "assistant" not in st.session_state:
        st.session_state.assistant = AITherapistAssistant(
            conversation_model_name="microsoft/phi-1_5",  # replace if needed
            summary_model_name="facebook/bart-large-cnn"
        )

    # Keep track of conversation
    if "conversation" not in st.session_state:
        st.session_state.conversation = []

    # Display existing conversation
    for message in st.session_state.conversation:
        role = "user" if message["sender"] == "user" else "assistant"
        st.chat_message(role).write(message["text"])

    # Collect user input
    if prompt := st.chat_input("How are you feeling today?"):
        # Crisis detection runs first so resources appear above the exchange.
        if st.session_state.assistant.detect_crisis(prompt):
            st.warning("⚠️ Potential crisis detected.")
            st.markdown("**Immediate Support Resources (Canada):**")
            for org, phone in CRISIS_RESOURCES.items():
                st.markdown(f"- {org}: `{phone}`")

        # Display user message
        st.session_state.conversation.append({"sender": "user", "text": prompt})
        st.chat_message("user").write(prompt)

        # Generate AI response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                ai_response = st.session_state.assistant.generate_response(prompt)
                st.write(ai_response)
        st.session_state.conversation.append({"sender": "assistant", "text": ai_response})

    # Summarize conversation
    if st.button("Generate Session Summary"):
        if st.session_state.conversation:
            conversation_text = " ".join(msg["text"] for msg in st.session_state.conversation)
            summary = st.session_state.assistant.generate_summary(conversation_text)
            st.subheader("Session Summary")
            st.write(summary)
        else:
            st.info("No conversation to summarize yet.")

    # Crisis Support Info in Sidebar
    st.sidebar.title("🆘 Crisis Support")
    st.sidebar.markdown("If you're in crisis, please contact:")
    for org, phone in CRISIS_RESOURCES.items():
        st.sidebar.markdown(f"- **{org}**: `{phone}`")
# Standard script entry-point guard.
if __name__ == "__main__":
    main()