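# Perspective: a staged LLM pipeline for lawsuit analysis (facts -> case fable
# -> recommendations), with optional Comet prompt tracing.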
import os
import re
import time
from datetime import date

import comet_llm
from dotenv import load_dotenv
from llama_index.llms.groq import Groq
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from openai import OpenAI, AzureOpenAI

import prompt_tracing
import utils
from mock_constants_dir import pers_mock_constants
from prompts import perspective_stage_prompts

load_dotenv()

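# Module-level tracing flag; note that the Comet logging in llm_function_call
# below is actually gated on prompt_tracing.tracking, not this variable.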
tracking = True


class Perspective:
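    '''
    Three-stage lawsuit analysis pipeline.

    A minimal usage sketch (assumes the surrounding app advances self.stage
    between calls):

        p = Perspective()
        facts = p.stage_1_call(lawsuit_text)  # extract facts
        fable = p.stage_2_call()              # develop the case fable
        recs = p.stage_3_call()               # generate recommendations
    '''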
    def __init__(self):
        self.stage = 1
        self.lawsuit_extracted_text = ""
        self.lawsuit_facts = ""
        self.additional_message = ""
        self.case_fable = ""
        self.recommendations = ""
        utils.clear_or_create_empty_file("data/pers_logs.txt")  # Clear log file
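        # Start a Comet LLM project so each prompt/response pair can be logged.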
        comet_llm.init(
            api_key=os.environ.get("COMET_API_KEY"),
            project="judgeai-v0",
        )

    def reset(self):
        self.stage = 1
        self.lawsuit_extracted_text = ""
        self.lawsuit_facts = ""
        self.additional_message = ""
        self.case_fable = ""
        self.recommendations = ""
        utils.clear_or_create_empty_file("data/pers_logs.txt")  # Clear log file

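    # Route a prompt to one of several hosted models. The model string only
    # selects a branch; each branch hardcodes its own deployment settings.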
    def llm_function_call(self, prompt, model):
        if "gpt" in model:
            # self.client = OpenAI(
            #     # This is the default and can be omitted
            #     api_key=os.environ.get("OPENAI_API_KEY"),
            # )
            # chat_completion = self.client.chat.completions.create(
            #     messages=[
            #         {
            #             "role": "user",
            #             "content": prompt,
            #         }
            #     ],
            #     model=model,
            #     # model="gpt-4o-2024-05-13",
            #     temperature=0,
            # )
            client = AzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                api_version="2024-02-01",
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            )
            chat_completion = client.chat.completions.create(
                model="gpt-4",
                temperature=0,
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
            )
            # print(prompt_tracing.user + ":\n" + prompt)
            if prompt_tracing.tracking:
                # prompt_tracing.send_prompt_over_discord(prompt)
                # prompt_tracing.send_response_over_discord(chat_completion.choices[0].message.content)
                comet_llm.log_prompt(
                    prompt=prompt_tracing.user + ":\n" + prompt,
                    output=chat_completion.choices[0].message.content,
                )
            return chat_completion.choices[0].message.content
elif model == "llama3-70b-8192": | |
llm = Groq(model="llama3-70b-8192", api_key=os.environ.get("GROQ_API_KEY"), temperature=0) | |
response = llm.complete(prompt) | |
return response.text | |
elif model == "mistral-large-azure": | |
client = MistralClient( | |
endpoint=os.environ.get("AZURE_AI_MISTRAL_LARGE_ENDPOINT"), api_key=os.environ.get("AZURE_AI_MISTRAL_LARGE_KEY") | |
) | |
response = client.chat( | |
model="azureai", | |
messages=[ | |
ChatMessage( | |
role="user", | |
content=prompt, | |
) | |
], | |
max_tokens=4096, | |
temperature=0, | |
) | |
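            # Print token usage for debugging / cost visibility.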
            print(response.usage)
            return response.choices[0].message.content
        else:
            return "no model available"

    def stage_1_call(self, input_text):
        '''
        Execute stage 1: extract the facts from the lawsuit text.
        '''
        gpt_prompt = perspective_stage_prompts.stage_1_prompt
        prompt = input_text + "\n\nPrompt:" + gpt_prompt
        response = self.llm_function_call(prompt, "gpt-4-0125-preview")
        # response = self.llm_function_call(prompt, "mistral-large-azure")
        # response = self.llm_function_call(prompt, "llama3-70b-8192")
        # response = pers_mock_constants.stage_1_output
        self.additional_message = "JudgeAI will now develop case fabula and give recommendations"
        self.lawsuit_facts = "\nLawsuit Facts: \n" + response
        utils.write_string_to_file("data/pers_logs.txt", response)
        return response

    def stage_2_call(self):
        '''
        Execute stage 2: develop the case fable from the extracted facts.
        '''
        gpt_prompt = perspective_stage_prompts.stage_2_prompt
        prompt = self.lawsuit_facts + "\n\nPrompt:" + gpt_prompt
        response = self.llm_function_call(prompt, "gpt-4-0125-preview")
        # response = self.llm_function_call(prompt, "mistral-large-azure")
        # response = self.llm_function_call(prompt, "llama3-70b-8192")
        # response = pers_mock_constants.stage_2_output
        self.case_fable = "\nCase Fable: \n" + response
        utils.write_string_to_file("data/pers_logs.txt", response)
        return response

    def stage_3_call(self):
        '''
        Execute stage 3: generate recommendations from the facts and the fable.
        '''
        gpt_prompt = perspective_stage_prompts.stage_3_prompt
        prompt = self.lawsuit_facts + self.case_fable + "\n\nPrompt:" + gpt_prompt
        response = self.llm_function_call(prompt, "gpt-4-0125-preview")
        # response = self.llm_function_call(prompt, "mistral-large-azure")
        # response = self.llm_function_call(prompt, "llama3-70b-8192")
        # response = pers_mock_constants.stage_2_output
        self.recommendations = "\nRecommendations: \n" + response
        utils.write_string_to_file("data/pers_logs.txt", response)
        return response

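    # Entry point for user input. A prompt ending in ".png" presumably marks an
    # uploaded image, so the previously extracted lawsuit text is used instead.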
    def generate_response(self, prompt):
        if prompt.lower().endswith(".png"):
            input_text = self.lawsuit_extracted_text
        else:
            input_text = prompt
        # print(self.stage)
        if self.stage == 1:
            response = self.stage_1_call(input_text)
            return response