Judge-AI / perspective.py
darthPanda's picture
updated pre 5, 7
dc6ca0d
import os
from openai import OpenAI, AzureOpenAI
from llama_index.llms.groq import Groq
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import time
import re
from mock_constants_dir import pers_mock_constants
import utils
from datetime import date
from prompts import perspective_stage_prompts
import prompt_tracing
import comet_llm
from dotenv import load_dotenv
# Populate os.environ from a local .env file so the os.environ.get / os.getenv
# lookups in this module (COMET_API_KEY, AZURE_OPENAI_*, GROQ_API_KEY,
# AZURE_AI_MISTRAL_LARGE_*) can find their values.
load_dotenv()
# NOTE(review): this module-level flag is not read anywhere in the visible code;
# Perspective.llm_function_call checks prompt_tracing.tracking instead. Kept in
# case other modules import perspective.tracking — confirm before removing.
tracking = True
class Perspective():
    """Three-stage LLM pipeline over a lawsuit.

    Stage 1 extracts the lawsuit facts from the input text, stage 2 builds a
    "case fable" from those facts, and stage 3 produces recommendations from
    the facts plus the fable. Every stage's raw reply is written to LOG_FILE.
    """

    # File that each stage's raw LLM reply is written to; cleared by reset().
    LOG_FILE = "data//pers_logs.txt"
    # Model selector the stage methods pass to llm_function_call by default.
    # Other selectors understood there: "mistral-large-azure", "llama3-70b-8192".
    # For offline runs, pers_mock_constants provides canned replies
    # (e.g. stage_1_output, stage_2_output) that can stand in for the LLM.
    DEFAULT_MODEL = "gpt-4-0125-preview"
    # Azure OpenAI deployment actually queried for ANY model name containing
    # "gpt"; the caller's exact model string is not forwarded to Azure.
    AZURE_GPT_DEPLOYMENT = "gpt-4"

    def __init__(self):
        """Start a fresh session (stage 1, empty state, cleared log) and init Comet logging."""
        self.reset()
        comet_llm.init(
            api_key=os.environ.get("COMET_API_KEY"),
            project="judgeai-v0",
        )

    def reset(self):
        """Return to stage 1, clear every stored stage output, and clear LOG_FILE."""
        self.stage = 1
        # Presumably filled by the caller with text extracted from an uploaded
        # .png (see generate_response); nothing in this class assigns it.
        self.lawsuit_extracted_text = ""
        self.lawsuit_facts = ""        # stage 1 output, labelled; context for stages 2-3
        self.additional_message = ""   # status text set after stage 1
        self.case_fable = ""           # stage 2 output, labelled; context for stage 3
        self.recommendations = ""      # stage 3 output, labelled
        utils.clear_or_create_empty_file(self.LOG_FILE)

    # ------------------------------------------------------------------ LLM access

    def llm_function_call(self, prompt, model):
        """Send ``prompt`` as a single user message and return the reply text.

        The backend is chosen from ``model``:
          * any name containing "gpt"   -> Azure OpenAI, deployment
            AZURE_GPT_DEPLOYMENT (the given name only selects this branch);
          * "llama3-70b-8192"           -> Groq;
          * "mistral-large-azure"       -> Mistral Large on Azure AI.
        All backends run with temperature 0.

        Returns:
            The model's reply text, or the sentinel string "no model available"
            for an unrecognised ``model`` (no exception is raised).
        """
        if "gpt" in model:
            return self._call_azure_gpt(prompt)
        if model == "llama3-70b-8192":
            return self._call_groq(prompt)
        if model == "mistral-large-azure":
            return self._call_mistral_azure(prompt)
        return "no model available"

    def _call_azure_gpt(self, prompt):
        """Query the Azure OpenAI GPT deployment; log the exchange to Comet when tracing is on."""
        client = AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version="2024-02-01",
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        )
        chat_completion = client.chat.completions.create(
            model=self.AZURE_GPT_DEPLOYMENT,
            temperature=0,
            messages=[{"role": "user", "content": prompt}],
        )
        content = chat_completion.choices[0].message.content
        # Only this backend is traced; Groq and Mistral replies are not logged.
        if prompt_tracing.tracking:
            comet_llm.log_prompt(
                prompt=prompt_tracing.user + ":\n" + prompt,
                output=content,
            )
        return content

    def _call_groq(self, prompt):
        """Query Llama 3 70B through Groq and return the completion text."""
        llm = Groq(model="llama3-70b-8192", api_key=os.environ.get("GROQ_API_KEY"), temperature=0)
        return llm.complete(prompt).text

    def _call_mistral_azure(self, prompt):
        """Query Mistral Large hosted on Azure AI and return the reply text."""
        client = MistralClient(
            endpoint=os.environ.get("AZURE_AI_MISTRAL_LARGE_ENDPOINT"),
            api_key=os.environ.get("AZURE_AI_MISTRAL_LARGE_KEY"),
        )
        response = client.chat(
            model="azureai",
            messages=[ChatMessage(role="user", content=prompt)],
            max_tokens=4096,
            temperature=0,
        )
        # Token usage is printed to stdout (debug aid kept from the original).
        print(response.usage)
        return response.choices[0].message.content

    def _query_stage(self, context, stage_prompt, model):
        """Append a stage instruction to its context and send the result to ``model``."""
        # The instruction follows the context, behind a "Prompt:" marker.
        return self.llm_function_call(context + "\n\nPrompt:" + stage_prompt, model)

    # ------------------------------------------------------------------ stages

    def stage_1_call(self, input_text, model=DEFAULT_MODEL):
        """Stage 1: extract the facts from the lawsuit text.

        Args:
            input_text: The lawsuit text to analyse.
            model: Selector forwarded to llm_function_call (default DEFAULT_MODEL).

        Returns:
            The raw LLM reply. It is also stored, labelled, in
            ``self.lawsuit_facts`` and written to LOG_FILE.
        """
        response = self._query_stage(input_text, perspective_stage_prompts.stage_1_prompt, model)
        # Status text for the caller; presumably displayed before stages 2-3 run.
        self.additional_message = "JudgeAI will now develop case fabula and give recommendations"
        self.lawsuit_facts = "\nLawsuit Facts: \n" + response
        utils.write_string_to_file(self.LOG_FILE, response)
        return response

    def stage_2_call(self, model=DEFAULT_MODEL):
        """Stage 2: develop the case fable from the stage-1 lawsuit facts.

        Args:
            model: Selector forwarded to llm_function_call (default DEFAULT_MODEL).

        Returns:
            The raw LLM reply. It is also stored, labelled, in
            ``self.case_fable`` and written to LOG_FILE.
        """
        response = self._query_stage(self.lawsuit_facts, perspective_stage_prompts.stage_2_prompt, model)
        self.case_fable = "\nCase Fable: \n" + response
        utils.write_string_to_file(self.LOG_FILE, response)
        return response

    def stage_3_call(self, model=DEFAULT_MODEL):
        """Stage 3: produce recommendations from the lawsuit facts and case fable.

        Args:
            model: Selector forwarded to llm_function_call (default DEFAULT_MODEL).

        Returns:
            The raw LLM reply. It is also stored, labelled, in
            ``self.recommendations`` and written to LOG_FILE.
        """
        response = self._query_stage(
            self.lawsuit_facts + self.case_fable,
            perspective_stage_prompts.stage_3_prompt,
            model,
        )
        self.recommendations = "\nRecommendations: \n" + response
        utils.write_string_to_file(self.LOG_FILE, response)
        return response

    def generate_response(self, prompt):
        """Dispatch user input to the current stage.

        A ``prompt`` ending in ".png" (case-insensitive) is replaced by
        ``self.lawsuit_extracted_text``; otherwise the prompt itself is used.

        Returns:
            The stage-1 reply when ``self.stage == 1``, otherwise None.
        """
        if prompt.lower().endswith(".png"):
            input_text = self.lawsuit_extracted_text
        else:
            input_text = prompt
        if self.stage == 1:
            return self.stage_1_call(input_text)
        # NOTE(review): nothing in this class advances self.stage, and stages
        # 2-3 are not dispatched here (callers invoke stage_2_call /
        # stage_3_call directly?), so any other stage yields None.
        return None