|
import logging
import os
from typing import List, Dict

from dotenv import load_dotenv
from openai import OpenAI

from literal_thread_manager import LiteralThreadManager
|
|
|
# Populate os.environ from a local .env file (if one exists) at import time, so the
# os.getenv("LITERAL_API_KEY") / os.getenv("OPENAI_API_KEY") lookups below see its values.
load_dotenv()
|
|
|
|
|
class ResolutionLogic: |
|
""" |
|
The ResolutionLogic class is designed to handle the resolution of conflicts within conversation threads. |
|
It integrates with the LiteralThreadManager to manage chat history and utilizes the OpenAI API to generate |
|
thoughtful and empathetic responses to help resolve the conflict between partners. The class also provides |
|
methods for summarizing conflict topics and ensuring both parties have participated sufficiently in the conversation. |
|
""" |
|
|
|
def __init__(self): |
|
self.thread_manager = LiteralThreadManager(api_key=os.getenv("LITERAL_API_KEY")) |
|
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) |
|
self.model = "gpt-3.5-turbo" |
|
|
|
def resolve_conflict(self, thread_id): |
|
""" |
|
Resolves a conflict by analyzing the conversation threads of both partners and generating a response. |
|
|
|
Args: |
|
thread_id (str): The thread ID of one of the partners. |
|
|
|
Returns: |
|
str: The generated resolution advice from the OpenAI model. |
|
""" |
|
other_partner_thread_id = self.thread_manager.get_other_partner_thread_id(thread_id) |
|
thread_content = self.thread_manager.filter_thread_by_id(thread_id) |
|
other_partner_thread_content = self.thread_manager.filter_thread_by_id(other_partner_thread_id) |
|
|
|
chat_history_partner1 = self.thread_manager.extract_chat_history_from_thread(thread_content) |
|
chat_history_partner2 = self.thread_manager.extract_chat_history_from_thread(other_partner_thread_content) |
|
|
|
combined_chat = chat_history_partner1 + chat_history_partner2 |
|
|
|
print(f"Combined chat: {combined_chat}") |
|
|
|
completion = self.client.chat.completions.create( |
|
model=self.model, |
|
messages=[ |
|
{"role": "system", "content": """ |
|
You are an expert relationship counselor. |
|
The user is providing various perspectives on a conflict within a relationship. |
|
Your task is to analyze these perspectives and offer thoughtful, empathetic, and actionable advice |
|
to help resolve the conflict. Consider the emotions and viewpoints of all parties involved and |
|
suggest a resolution that promotes understanding, communication, and mutual respect. |
|
|
|
Please address them both with their names, to make it more personal. |
|
Try to avoid leaking information you got in the prompt. |
|
"""}, |
|
*combined_chat |
|
] |
|
) |
|
return completion.choices[0].message.content |
|
|
|
def intervention(self, thread_id) -> bool | str: |
|
""" |
|
Checks if both partners have answered enough questions to resolve the conflict. |
|
If they have, resolves the conflict; otherwise, returns False. |
|
|
|
Args: |
|
thread_id (str): The thread ID of one of the partners. |
|
|
|
Returns: |
|
bool | str: False if the conflict is not ready to be resolved, otherwise the resolution advice. |
|
""" |
|
if self.thread_manager.is_conflict_resolved(thread_id): |
|
return False |
|
|
|
partner2_thread_id = self.thread_manager.get_other_partner_thread_id(thread_id) |
|
num_partner1_messages = self.thread_manager.count_llm_messages(thread_id) |
|
num_partner2_messages = self.thread_manager.count_llm_messages(partner2_thread_id) |
|
|
|
if num_partner1_messages >= 3 and num_partner2_messages >= 3: |
|
resolution = self.resolve_conflict(thread_id) |
|
self.thread_manager.send_message(partner2_thread_id, resolution) |
|
self.thread_manager.set_conflict_resolved(thread_id, True) |
|
self.thread_manager.set_conflict_resolved(partner2_thread_id, True) |
|
return resolution |
|
return False |
|
|
|
def summarize_conflict_topic(self, my_name: str, partner_name: str, topic: str): |
|
""" |
|
Summarizes the conflict topic in a proper way for the other partner. |
|
|
|
Args: |
|
my_name (str): The name of the user. |
|
partner_name (str): The name of the partner. |
|
topic (str): The topic of the conflict. |
|
|
|
Returns: |
|
str: The generated summary of the conflict topic. |
|
""" |
|
summary = self.client.chat.completions.create( |
|
model=self.model, |
|
messages=[ |
|
{"role": "system", "content": f""" |
|
As an expert relationship counselor specialized in couple therapy, |
|
you are helping the couple, {my_name} and {partner_name}, to resolve a conflict or communicate better. |
|
{partner_name} has shared a topic with you about {topic}. |
|
Now, your task is to talk to {my_name} and know his perspective on the topic, |
|
ensuring privacy and not revealing explicit details about what the other person thought. |
|
Ask {my_name} a question to get the conversation started. |
|
Keep in mind, this conversation is ONLY between you and {my_name}. |
|
"""}, |
|
{"role": "user", "content": f""" |
|
{topic} |
|
"""} |
|
] |
|
) |
|
return summary.choices[0].message.content |
|
|
|
|
|
def get_summary(
    chat_history: List[Dict[str, str]],
    perspective: str,
    topic: str = "household chores",
    client: "OpenAI | None" = None,
    model: str = "gpt-3.5-turbo",
) -> str:
    """
    Generate a summary of a chat conversation from one participant's perspective.

    Args:
        chat_history (List[Dict[str, str]]): Messages to summarize; each needs 'role' and
            'content', and may carry 'name'.
        perspective (str): Name of the participant whose perspective the summary takes.
        topic (str): What the conflict is about; used in the system prompt.
        client (OpenAI | None): Chat-completion client. When omitted, one is built from
            the OPENAI_API_KEY environment variable.
        model (str): Model name used for the completion request.

    Returns:
        str: The generated summary.
    """
    if client is None:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    lines = []
    for msg in chat_history:
        # A message may lack a speaker name; fall back to the role alone instead of
        # raising KeyError.
        name = msg.get("name")
        if name:
            lines.append(f"{msg['role']} - {name}: {msg['content']}")
        else:
            lines.append(f"{msg['role']}: {msg['content']}")
    chat_history_str = "\n".join(lines)

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system",
             "content": f"""
            You are an assistant summarizing a chat conversation to help resolve a conflict about {topic}.
            Please summarize the chat from {perspective}'s perspective,
            ensuring privacy and not revealing explicit details about what the other person thought.
            Focus on the key points discussed and any constructive advice given."""},
            {"role": "user", "content": chat_history_str},
        ],
    )
    return response.choices[0].message.content
|
|
|
|
|
def get_thread_and_summarize(manager: LiteralThreadManager, thread_id: str):
    """
    Summarize one thread's chat history from the perspective of that thread's user.

    Args:
        manager (LiteralThreadManager): Manager used to fetch and parse the thread.
        thread_id (str): ID of the thread to summarize.

    Returns:
        str: The summary produced by get_summary.
    """
    thread = manager.filter_thread_by_id(thread_id)
    # Keyword arguments are evaluated in the order written, so the user name is still
    # looked up before the chat history is extracted.
    return get_summary(
        perspective=manager.get_user_name_from_thread(thread),
        chat_history=manager.extract_chat_history_from_thread(thread),
    )
|
|
|
|
|
def main():
    """
    Demo entry point for the two hard-coded Tom/Linda threads.

    Prints a per-partner summary of each thread, prints the generated resolution, and
    upserts one resolution thread per partner in Literal.
    """
    manager = LiteralThreadManager(api_key=os.getenv("LITERAL_API_KEY"))

    tom_thread = '83547413-b1bd-4609-9af7-b856ef2108a2'
    linda_thread = '79c9b29a-fa26-4860-ba94-6cddb0581604'

    tom_summary = get_thread_and_summarize(manager, tom_thread)
    linda_summary = get_thread_and_summarize(manager, linda_thread)

    print(f"Summary for Tom: \n{tom_summary}")
    print(f"Summary for Linda: \n{linda_summary}")

    # Constructed without arguments: ResolutionLogic builds its own thread manager and
    # OpenAI client from the environment.
    resolution_logic = ResolutionLogic()
    resolution = resolution_logic.resolve_conflict(tom_thread)
    print(f"Resolution: \n{resolution}")

    # Literal AI's upsert_thread takes tags as a list of strings, not a bare string.
    manager.literal_client.api.upsert_thread(id="resolution_1_tom", name="resolution_tom_linda",
                                             participant_id="tom", tags=["tom_linda"])
    manager.literal_client.api.upsert_thread(id="resolution_1_linda", name="resolution_tom_linda",
                                             participant_id="linda", tags=["tom_linda"])


if __name__ == "__main__":
    main()
|
|