|
import streamlit as st |
|
import os |
|
import requests |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import numpy as np |
|
from datasets import load_dataset |
|
|
|
|
|
# SECURITY: do not hard-code the Groq API key in this file. Any key that has
# ever been committed to source must be treated as leaked and revoked.
# Supply the key through the process environment before launching the app,
# for example:
#     export GROQ_API_KEY="<your key>"
# The request helpers in this file read it with os.getenv("GROQ_API_KEY").

# Base URL to which the request helpers append "/embedding" and "/generate".
# NOTE(review): confirm this base URL and those paths against the current
# Groq API documentation.
GROQ_API_URL = "https://api.groq.com/v1/inference"
|
|
|
|
|
def retrieve_embedding(user_query, timeout=30):
    """Request an embedding for ``user_query`` from the embedding endpoint.

    Posts the text to ``{GROQ_API_URL}/embedding`` and sends the value of the
    ``GROQ_API_KEY`` environment variable as a bearer token.

    NOTE(review): the endpoint path, payload field names and model id are kept
    as they were; confirm them against the API documentation.

    Args:
        user_query: Text to embed.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The value stored under the ``"embedding"`` key of the JSON response,
        or ``None`` if the request fails, the status code is not 200, or the
        response has no ``"embedding"`` key. Every failure is also reported
        through ``st.error``.
    """
    payload = {
        "model": "microsoft/MiniLM-L6-H384-uncased",
        "input_text": user_query,
    }
    headers = {"Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"}

    try:
        # Without a timeout a stalled connection would hang the app forever.
        response = requests.post(
            f"{GROQ_API_URL}/embedding",
            json=payload,
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection errors, timeouts, etc. follow the same "report and
        # return None" contract as a bad status code.
        st.error(f"Failed to retrieve embedding. Request error: {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Failed to retrieve embedding. Status code: {response.status_code}")
        return None

    try:
        json_response = response.json()
    except ValueError:
        # A body that is not valid JSON cannot contain an embedding.
        json_response = None

    if isinstance(json_response, dict) and "embedding" in json_response:
        return json_response["embedding"]

    st.error("The response from the API did not contain an embedding. Please check the API.")
    return None
|
|
|
|
|
def generate_response(context, timeout=30):
    """Ask the generation endpoint for a supportive reply based on ``context``.

    Posts a prompt built from ``context`` to ``{GROQ_API_URL}/generate`` and
    sends the value of the ``GROQ_API_KEY`` environment variable as a bearer
    token.

    NOTE(review): the endpoint path, payload field names and model id are kept
    as they were; confirm them against the API documentation.

    Args:
        context: Text inserted into the prompt after the fixed instruction.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The value stored under the ``"text"`` key of the JSON response, or
        ``None`` if the request fails, the status code is not 200, or the
        response has no ``"text"`` key. Every failure is also reported through
        ``st.error``.
    """
    payload = {
        "model": "google/flan-t5-small",
        "input_text": f"Given the following context, provide a supportive response: {context}",
    }
    headers = {"Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"}

    try:
        # Without a timeout a stalled connection would hang the app forever.
        response = requests.post(
            f"{GROQ_API_URL}/generate",
            json=payload,
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection errors, timeouts, etc. follow the same "report and
        # return None" contract as a bad status code.
        st.error(f"Failed to generate response. Request error: {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Failed to generate response. Status code: {response.status_code}")
        return None

    try:
        json_response = response.json()
    except ValueError:
        # A body that is not valid JSON cannot contain the generated text.
        json_response = None

    if isinstance(json_response, dict) and "text" in json_response:
        return json_response["text"]

    st.error("The response from the API did not contain a 'text' key.")
    return None
|
|
|
|
|
# Training split of the counseling-conversations dataset. The rest of this
# file reads each entry through its "Response" field.
dataset = load_dataset(
    "Amod/mental_health_counseling_conversations",
    split="train",
)
|
|
|
|
|
@st.cache_resource
def embed_dataset(_dataset):
    """Embed the ``"Response"`` field of every entry in ``_dataset``.

    The result is positionally aligned with ``_dataset``: element ``i`` is the
    embedding of ``_dataset[i]["Response"]``. Callers use positions in this
    list to index back into the dataset, so rows whose embedding could not be
    retrieved are filled with a zero vector instead of being dropped, which
    would shift every later row onto the wrong entry.

    Args:
        _dataset: Iterable of mappings that each have a ``"Response"`` key.
            The leading underscore asks ``st.cache_resource`` not to hash this
            argument.

    Returns:
        A list with one embedding per dataset entry, or an empty list when no
        entry could be embedded at all.
    """
    embeddings = [retrieve_embedding(entry["Response"]) for entry in _dataset]

    # Use any successful embedding as the template for the placeholder shape.
    reference = next((vec for vec in embeddings if vec is not None), None)
    if reference is None:
        # Empty dataset, or every single request failed.
        return []

    # NOTE(review): a zero vector is used as a neutral stand-in for failed
    # rows; confirm that ranking treats it as a non-match.
    placeholder = np.zeros_like(np.asarray(reference, dtype=float)).tolist()
    return [placeholder if vec is None else vec for vec in embeddings]


# Computed once per Streamlit session cache; one entry per dataset row.
dataset_embeddings = embed_dataset(dataset)
|
|
|
|
|
def retrieve_response(user_query, dataset, dataset_embeddings, k=5):
    """Return the ``k`` dataset responses most similar to ``user_query``.

    The query is embedded with ``retrieve_embedding`` and compared against
    ``dataset_embeddings`` by cosine similarity. Positions of the best scores
    are used to index ``dataset``.

    Args:
        user_query: Text entered by the user.
        dataset: Indexable collection whose items have a ``"Response"`` key.
        dataset_embeddings: Sequence of embeddings, where position ``i`` is
            expected to correspond to ``dataset[i]``.
        k: Maximum number of responses to return.

    Returns:
        A list of up to ``k`` response strings, best match first. The list is
        empty when the query cannot be embedded, when there are no dataset
        embeddings, or when ``k`` is not positive.
    """
    query_embedding = retrieve_embedding(user_query)
    if query_embedding is None:
        st.error("Could not retrieve an embedding for the query.")
        return []

    # cosine_similarity raises on an empty matrix, and with k <= 0 the [-k:]
    # slice below would select every row instead of none.
    if k <= 0 or len(dataset_embeddings) == 0:
        return []

    cos_scores = cosine_similarity([query_embedding], dataset_embeddings)[0]
    # argsort is ascending: take the last k positions, then reverse them so
    # the highest score comes first.
    top_indices = np.argsort(cos_scores)[-k:][::-1]

    # Cast numpy integers to plain int; not every dataset container accepts
    # numpy integer keys.
    return [dataset[int(idx)]["Response"] for idx in top_indices]
|
|
|
|
|
# Page header and prompt.
st.title("Emotional Support Buddy")
st.write("Enter your thoughts or concerns, and I'll provide some comforting words.")

user_query = st.text_input("How are you feeling today?")

# Nothing below runs until the text box holds a non-empty value.
if user_query:
    retrieved_responses = retrieve_response(user_query, dataset, dataset_embeddings)

    if not retrieved_responses:
        st.write("Sorry, I couldn't find any relevant responses.")
    else:
        # The retrieved responses, joined with spaces, are the only context
        # passed to the generator.
        context = " ".join(retrieved_responses)
        supportive_response = generate_response(context)

        if not supportive_response:
            st.write("Sorry, I couldn't generate a response at the moment.")
        else:
            st.write("Here's some advice or support for you:")
            st.write(supportive_response)
|
|