Spaces:
Running
Running
import torch | |
import numpy as np | |
import io | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
from transformers import pipeline | |
from datetime import datetime | |
from PIL import Image | |
import os | |
from datetime import datetime | |
from openai import OpenAI | |
from ai71 import AI71 | |
# ---------------------------------------------------------------------------
# Module-level resources shared by every request.
# ---------------------------------------------------------------------------

# Table of reference therapy dialogues; the code below reads its
# 'embeddings', 'Patient' and 'Doctor' columns.
# NOTE(security): read_pickle unpickles arbitrary objects and this file is
# fetched over the network -- only safe while the dataset repo is trusted.
# Consider pinning a revision or moving to a non-executable format.
dials_embeddings = pd.read_pickle('https://huggingface.co/datasets/vsrinivas/CBT_dialogue_embed_ds/resolve/main/kaggle_therapy_embeddings.pkl')

# Candidate labels for zero-shot emotion detection, one label per line.
with open('emotion_group_labels.txt') as file:
    emotion_group_labels = file.read().splitlines()

# Sentence encoder used to embed the incoming user message.
embed_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# Move the encoder to the GPU only after it exists; doing the move on a
# not-yet-defined name raises NameError at import time on CUDA hosts.
if torch.cuda.is_available():
    embed_model = embed_model.to('cuda')

classifier = pipeline("zero-shot-classification", model='facebook/bart-large-mnli')

AI71_BASE_URL = "https://api.ai71.ai/v1/"
# None when the environment variable is not set.
AI71_API_KEY = os.getenv('AI71_API_KEY')
# Detect emotions from patient dialogues
def detect_emotions(text):
    """Classify *text* against the emotion labels and report the first five.

    The first five scores returned by the zero-shot classifier are rescaled
    so that they sum to 1, then formatted as percentage strings.
    (The pipeline is expected to return labels ordered by descending score,
    which would make these the top five -- confirm against its docs.)

    Args:
        text: The patient message to classify.

    Returns:
        dict mapping emotion label -> share formatted like "42.10%".
    """
    result = classifier(text, candidate_labels=emotion_group_labels, batch_size=16)
    leading_scores = result['scores'][:5]
    leading_labels = result['labels'][:5]
    # Normalizer shared by all five entries.
    total = sum(leading_scores)
    emotion_set = {}
    for label, score in zip(leading_labels, leading_scores):
        emotion_set[label] = "{:.2%}".format(score / total)
    return emotion_set
# Measure cosine similarity between a pair of vectors
def cosine_distance(vec1, vec2):
    """Return the cosine similarity of two 1-D vectors.

    Despite the name, the value is a *similarity*: 1.0 for vectors pointing
    the same way, 0.0 for orthogonal ones, -1.0 for opposite ones. The name
    is kept so existing callers keep working.

    Args:
        vec1: First vector (anything accepted by ``np.dot``).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        The cosine of the angle between the vectors, or 0.0 when either
        vector has zero norm (the angle is undefined there; returning 0.0
        avoids a NaN that would break ranking with ``max``).
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    return np.dot(vec1, vec2) / norm_product
# Generate an image of trigger emotions
def generate_triggers_img(items):
    """Render emotion shares as a horizontal bar chart and return it as an image.

    The chart is rendered into an in-memory PNG rather than a fixed file on
    disk, so successive or overlapping calls cannot overwrite each other's
    output, and the matplotlib figure is always closed so figures do not
    accumulate across calls.

    Args:
        items: Mapping of label -> percentage string such as "42.10%"
            (the format produced by ``detect_emotions``).

    Returns:
        A PIL image of the chart, fully loaded into memory.
    """
    # Parse "12.34%" -> 12.34 and order the bars by ascending value.
    ranked = sorted(
        ((label, float(share.strip('%'))) for label, share in items.items()),
        key=lambda pair: pair[1],
    )
    labels = [label for label, _ in ranked]
    values = [value for _, value in ranked]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
        bars = ax.barh(labels, values, color=colors)

        # Strip the frame and the x axis; only labels and bars remain.
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(axis='y', labelsize=18)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_ticks_position('none')

        # Write each bar's value just past its right-hand end.
        for bar in bars:
            width = bar.get_width()
            ax.text(width, bar.get_y() + bar.get_height() / 2, f'{width:.2f}%',
                    ha='left', va='center', fontweight='bold', fontsize=18)

        fig.tight_layout()
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)

    buffer.seek(0)
    triggers_img = Image.open(buffer)
    # Image.open is lazy; force the decode now so the result is self-contained.
    triggers_img.load()
    return triggers_img
def get_doc_response_emotions(user_message, therapy_session_conversation):
    """Answer a patient message with the closest stored therapist reply.

    Detects the emotions in the message, renders them as a chart, embeds the
    message, and picks the reference dialogue whose embedding has the highest
    cosine similarity to it.

    Args:
        user_message: The patient's message text.
        therapy_session_conversation: List of conversation turns. It is
            mutated in place: a ``["User: ...", "Therapist: ..."]`` pair is
            appended for this exchange.

    Returns:
        A 3-tuple ``('', therapy_session_conversation, emotions_img)``: an
        empty string (presumably used by the UI to clear the input box --
        confirm against the caller), the updated conversation, and the
        emotion chart image.
    """
    emotion_set = detect_emotions(user_message)
    print(emotion_set)
    emotions_msg = generate_triggers_img(emotion_set)

    user_embedding = embed_model.encode(
        user_message, device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    similarities = [
        cosine_distance(user_embedding, stored)
        for stored in dials_embeddings['embeddings']
    ]
    # First index holding the highest similarity.
    top_match_index = int(np.argmax(similarities))

    # Look the matching row up once and reuse it below.
    matched_row = dials_embeddings.iloc[top_match_index]
    doc_response = matched_row['Doctor']
    therapy_session_conversation.append(["User: " + user_message, "Therapist: " + doc_response])

    print(f"User's message: {user_message}")
    print(f"RAG Matching message: {matched_row['Patient']}")
    print(f"Therapist's response: {doc_response}\n\n")
    return '', therapy_session_conversation, emotions_msg
def _conversation_lines(conversation):
    """Flatten a session transcript into a flat list of text lines.

    Accepts either the raw list of turns or a wrapper object exposing the
    list as ``.value``. Each turn may be a plain string or a sequence of
    strings such as a ``[user, therapist]`` pair; ``None`` parts are skipped.
    """
    turns = getattr(conversation, 'value', conversation) or []
    lines = []
    for turn in turns:
        if isinstance(turn, str):
            lines.append(turn)
        else:
            lines.extend(str(part) for part in turn if part is not None)
    return lines


def _stream_chat_completion(client, system_prompt, user_content):
    """Run one streamed chat completion and return the cleaned, joined text."""
    pieces = []
    for chunk in client.chat.completions.create(
        model="tiiuae/falcon-180b-chat",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        stream=True,
    ):
        delta_text = chunk.choices[0].delta.content
        if delta_text:
            pieces.append(delta_text)
    # Drop any 'User:' markers left in the generated text.
    return ''.join(pieces).replace('User:', '').strip()


def summarize_and_recommend(therapy_session_conversation):
    """Summarize a therapy session and derive an action plan from the summary.

    Two chat completions are issued in sequence: the first summarizes the
    timestamped transcript, the second turns that summary into a bulleted
    list of recommendations.

    Args:
        therapy_session_conversation: The session transcript, either as a
            list of turns or as an object exposing that list via ``.value``.
            Turns may be strings or ``[user, therapist]`` pairs. The argument
            is not modified; clearing the session is left to the caller.

    Returns:
        A 2-tuple ``(full_summary, full_recommendations)`` of strings.
    """
    session_conversation = _conversation_lines(therapy_session_conversation)
    print("Session conversation:", session_conversation)

    session_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Turns are flattened to strings above, so the join cannot hit a nested list.
    session_conversation_processed = '\n'.join(
        ["Session_time: " + session_time, *session_conversation]
    )
    print("session_conversation_processed:", session_conversation_processed)

    client = AI71(AI71_API_KEY)

    full_summary = _stream_chat_completion(
        client,
        """You are an Expert Cognitive Behavioural Therapist and Precis writer.
Summarize 'STRICTLY' the below user content <<<session_conversation_processed>>> 'ONLY' into useful, ethical, relevant and realistic phrases with a format
Session Time:
Summary of the patient messages: #in two to four sentences
Summary of therapist messages: #in two to three sentences:
Summary of the whole session: # in two to three sentences. Ensure the entire session summary strictly does not exceed 100 tokens.""",
        session_conversation_processed,
    )
    print("\n")
    print("Full summary:", full_summary)

    full_recommendations = _stream_chat_completion(
        client,
        """You are an expert Cognitive Behavioural Therapist.
Based on 'STRICTLY' the full summary <<<full_summary>>> 'ONLY' provide clinically valid, useful, appropriate action plan for the Patient as a bullted list.
The list shall contain both medical and non medical prescriptions, dos and donts. The format of response shall be in passive voice with proper tense.
- The patient is referred to........ #in one sentence
- The patient is advised to ........ #in one sentence
- The patient is refrained from........ #in one sentence
- It is suggested that tha patient ........ #in one sentence
- Scheduled a follow-up session with the patient........#in one sentence
*Ensure the list contains NOT MORE THAN 7 points""",
        full_summary,
    )
    print("\n")
    print("Full recommendations:", full_recommendations)

    return full_summary, full_recommendations
# class process_session(): | |
# def __init__(self): | |
# self.session_conversation=[] | |
# def get_doc_response_emotions(self, user_message, therapy_session_conversation): | |
# user_messages = [] | |
# user_messages.append(user_message) | |
# emotion_set = detect_emotions(user_message) | |
# print(emotion_set) | |
# emotions_msg = generate_triggers_img(emotion_set) | |
# user_embedding = embed_model.encode(user_message, device='cuda' if torch.cuda.is_available() else 'cpu') | |
# similarities =[] | |
# for v in dials_embeddings['embeddings']: | |
# similarities.append(cosine_distance(user_embedding,v)) | |
# top_match_index = similarities.index(max(similarities)) | |
# # doc_response = dials_embeddings.iloc[top_match_index+1]['Doctor'] | |
# doc_response = dials_embeddings.iloc[top_match_index]['Doctor'] | |
# therapy_session_conversation.append(["User: "+user_message, "Therapist: "+doc_response]) | |
# self.session_conversation.extend(["User: "+user_message, "Therapist: "+doc_response]) | |
# print(f"User's message: {user_message}") | |
# print(f"RAG Matching message: {dials_embeddings.iloc[top_match_index]['Patient']}") | |
# print(f"Therapist's response: {dials_embeddings.iloc[top_match_index]['Doctor']}\n\n") | |
# return '', therapy_session_conversation, emotions_msg | |
# def summarize_and_recommend(self): | |
# session_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) | |
# session_conversation_processed = self.session_conversation.copy() | |
# session_conversation_processed.insert(0, "Session_time: "+session_time) | |
# session_conversation_processed ='\n'.join(session_conversation_processed) | |
# print("Session conversation:", session_conversation_processed) | |
# full_summary = "" | |
# for chunk in AI71(AI71_API_KEY).chat.completions.create( | |
# model="tiiuae/falcon-180b-chat", | |
# messages=[ | |
# {"role": "system", "content": """You are an Expert Cognitive Behavioural Therapist and Precis writer. | |
# Summarize the below user content <<<session_conversation_processed>>> into useful, ethical, relevant and realistic phrases with a format | |
# Session Time: | |
# Summary of the patient messages: #in two to four sentences | |
# Summary of therapist messages: #in two to three sentences: | |
# Summary of the whole session: # in two to three sentences. Ensure the entire session summary strictly does not exceed 100 tokens."""}, | |
# {"role": "user", "content": session_conversation_processed}, | |
# ], | |
# stream=True, | |
# ): | |
# if chunk.choices[0].delta.content: | |
# summary = chunk.choices[0].delta.content | |
# full_summary += summary | |
# full_summary = full_summary.replace('User:', '').strip() | |
# print("\n") | |
# print("Full summary:", full_summary) | |
# full_recommendations = "" | |
# for chunk in AI71(AI71_API_KEY).chat.completions.create( | |
# model="tiiuae/falcon-180b-chat", | |
# messages=[ | |
# {"role": "system", "content": """You are an expert Cognitive Behavioural Therapist. | |
# Based on the full summary <<<full_summary>>> provide clinically valid, useful, appropriate action plan for the Patient as a bullted list. | |
# The list shall contain both medical and non medical prescriptions, dos and donts. The format of response shall be in passive voice with proper tense. | |
# - The patient is referred to........ #in one sentence | |
# - The patient is advised to ........ #in one sentence | |
# - The patient is refrained from........ #in one sentence | |
# - It is suggested that tha patient ........ #in one sentence | |
# - Scheduled a follow-up session with the patient........#in one sentence | |
# *Ensure the list contains NOT MORE THAN 7 points"""}, | |
# {"role": "user", "content": full_summary}, | |
# ], | |
# stream=True, | |
# ): | |
# if chunk.choices[0].delta.content: | |
# rec = chunk.choices[0].delta.content | |
# full_recommendations += rec | |
# full_recommendations = full_recommendations.replace('User:', '').strip() | |
# print("\n") | |
# print("Full recommendations:", full_recommendations) | |
# self.session_conversation=[] | |
# return full_summary, full_recommendations | |