import colorama
from colorama import Fore, Style
import openai
from tenacity import retry, stop_after_attempt, wait_fixed
import json
import os
import tiktoken
import functools as ft
import time
JSON_TEMPLATE = """
{question}
The required key(s) are: {keys}.
Respond only with the key(s) and value(s) mentioned above.
Your answer in valid JSON format:\n
"""
# USD per 1,000 tokens (mid-2023 OpenAI pricing)
MODEL_COST_DICT = {
"gpt-3.5-turbo": {
"input": 0.0015,
"output": 0.002,
},
"gpt-4": {
"input": 0.03,
"output": 0.06,
},
}
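# Worked example (sketch): a 1,000-token prompt plus a 500-token reply on gpt-4
# costs 1000/1000 * 0.03 + 500/1000 * 0.06 = $0.06.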
def set_api_key(key=None):
"""Sets the OpenAI API key."""
if key is None:
key = os.environ.get("OPENAI_API_KEY")
openai.api_key = key
def num_tokens_from_string(string: str, encoding_name: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens
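# Example (sketch): counting tokens with the cl100k_base encoding used by
# the gpt-3.5-turbo / gpt-4 family:
# num_tokens_from_string("Hello, world!", "cl100k_base")  # -> 4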
def num_tokens_from_messages(messages: list[dict], model="gpt-3.5-turbo-0613"):
"""Returns the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo-0613": # note: future models may deviate from this
num_tokens = 0
for message in messages:
num_tokens += (
4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
)
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name": # if there's a name, the role is omitted
                    num_tokens -= 1  # role is always required and always 1 token
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens
else:
raise NotImplementedError(
f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
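# Example (sketch; the contents are hypothetical): the message format this
# function expects, per the OpenAI chat API:
# msgs = [
#     {"role": "system", "content": "You are a radiology tutor."},
#     {"role": "user", "content": "What is the next best step?"},
# ]
# num_tokens_from_messages(msgs)  # -> total prompt tokens as an int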
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def chat(messages: list[dict], model="gpt-3.5-turbo", temperature=0.0):
    response = openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=temperature,
)
return response["choices"][0]["message"]["content"]
def make_message(role: str, content: str) -> dict:
return {
"role": role,
"content": content,
}
def make_prompt(template: str, **kwargs):
return template.format(**kwargs)
def unravel_messages(messages: list[dict]) -> list[str]:
"""Returns a string representation of a list of messages."""
return [f"{message['role']}: {message['content']}" for message in messages]
class LLM:
    """A chat-completion wrapper that tracks total tokens used and accumulated cost in USD."""
def __init__(self, model="gpt-3.5-turbo", temperature=0.0):
self.model = model
self.temperature = temperature
self.token_counter = 0
self.cost = 0.0
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def chat(self, messages: list[dict]):
        response = openai.ChatCompletion.create(
model=self.model,
messages=messages,
temperature=self.temperature,
)
self.token_counter += int(response["usage"]["total_tokens"])
self.cost += (
response["usage"]["prompt_tokens"]
/ 1000
* MODEL_COST_DICT[self.model]["input"]
+ response["usage"]["completion_tokens"]
/ 1000
* MODEL_COST_DICT[self.model]["output"]
)
return response["choices"][0]["message"]["content"]
def reset(self):
self.token_counter = 0
self.cost = 0.0
def __call__(self, messages: list[dict]):
return self.chat(messages)
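# Usage sketch (assumes OPENAI_API_KEY is set in the environment):
# set_api_key()
# llm = LLM(model="gpt-3.5-turbo")
# reply = llm([make_message("user", "Summarize the case in one sentence.")])
# print(reply, llm.token_counter, f"${llm.cost:.4f}")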
class SummaryMemory:
"""
A class that manages a memory of messages and automatically summarizes them when the maximum token limit is reached.
Attributes:
max_token_limit (int): The maximum number of tokens allowed in the memory before summarization occurs.
messages (list[dict]): A list of messages in the memory.
model (str): The name of the GPT model to use for chat completion.
ai_role (str): The role of the AI in the conversation.
human_role (str): The role of the human in the conversation.
auto_summarize (bool): Whether to automatically summarize the messages when the maximum token limit is reached.
"""
# ...
summary_template = "Summarize the following messages into a paragraph and replace '{user}' with '{human_role}', and '{assistant}' with '{ai_role}':\n{messages}"
def __init__(
self,
system_prompt="",
max_token_limit=4000,
model="gpt-3.5-turbo",
ai_role="answer",
human_role="question/exam",
auto_summarize=False,
):
self.max_token_limit = max_token_limit
self.messages: list[dict] = []
self.model = model
self.ai_role = ai_role
self.human_role = human_role
self.auto_summarize = auto_summarize
self.system_prompt = system_prompt
self.reset()
    def reset(self):
        # wrap the plain-string system prompt into a message dict so token counting works
        self.messages = [make_message("system", self.system_prompt)]
def remove_last(self):
if len(self.messages) > 1: # don't remove the system prompt
self.messages.pop()
def remove(
self, index: int
): # don't remove the system prompt and start counting from 1
if index > 0 and index < len(self.messages):
self.messages.pop(index)
def replace(self, index: int, message: dict):
if index > 0 and index < len(self.messages):
self.messages[index] = message
    def change_system_prompt(self, new_prompt: str):
        self.system_prompt = new_prompt
        self.messages[0] = make_message("system", new_prompt)
def remove_first(self):
        # don't remove the system prompt
if len(self.messages) > 1:
self.messages.pop(1) # remove the first message after the system prompt
def append(self, message: dict):
total_tokens = num_tokens_from_messages(self.messages + [message])
while (
self.auto_summarize and total_tokens > self.max_token_limit
): # keep summarizing until we're under the limit
self.summarize()
total_tokens = num_tokens_from_messages(self.messages + [message])
self.messages.append(message)
def summarize(self):
prompt = make_prompt(
self.summary_template,
user="user",
human_role=self.human_role,
assistant="assistant",
ai_role=self.ai_role,
messages="\n".join(
unravel_messages(self.messages[1:])
), # don't include the system prompt
)
summary = chat(
messages=[make_message("user", prompt)],
model=self.model,
)
self.reset()
self.append(make_message("user", summary))
def get_messages(self):
return self.messages[1:] # don't include the system prompt
def get_unraveled_messages(self):
return unravel_messages(self.messages[1:])
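# Usage sketch (the prompts are hypothetical): a memory that summarizes
# itself once it grows past the token limit:
# memory = SummaryMemory(
#     system_prompt="You are a radiology examiner.",
#     max_token_limit=500,
#     auto_summarize=True,
# )
# memory.append(make_message("user", "Describe the chest CT findings."))
# memory.get_messages()  # summarized history, system prompt excluded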
class MemoryBuffer:
"""
A class that manages a buffer of messages and clips them to a maximum token limit.
Attributes:
max_token_limit (int): The maximum number of tokens allowed in the buffer.
messages (list[dict]): A list of messages in the buffer.
"""
def __init__(
self,
system_prompt,
max_token_limit=1000,
):
"""
Initializes a new instance of the MemoryBuffer class.
Args:
max_token_limit (int, optional): The maximum number of tokens allowed in the buffer. Defaults to 1000.
"""
self.max_token_limit = max_token_limit
self.messages = []
self.system_prompt = system_prompt
self.reset()
    def reset(self):
        """
        Resets the buffer, keeping only the system prompt.
        """
        # wrap the plain-string system prompt into a message dict so token counting works
        self.messages = [make_message("system", self.system_prompt)]
def add(self, message: dict):
"""
Adds a message to the buffer and clips the buffer to the maximum token limit.
Args:
message (dict): The message to add to the buffer.
"""
        total_tokens = num_tokens_from_messages(self.messages + [message])
        # clip from the front of the buffer (index 1, so the system prompt
        # is kept) until the total token count fits under the limit
        while total_tokens > self.max_token_limit and len(self.messages) > 1:
            self.messages.pop(1)
            total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)
def remove(self, message: dict):
"""
Removes a message from the buffer.
Args:
message (dict): The message to remove from the buffer.
"""
if message in self.messages:
self.messages.remove(message)
    def remove_last(self):
        """
        Removes the last message from the buffer (the system prompt is kept).
        """
        if len(self.messages) > 1:
            self.messages.pop()
    def remove_first(self):
        """
        Removes the first message after the system prompt.
        """
        if len(self.messages) > 1:
            self.messages.pop(1)
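# Usage sketch (the prompts are hypothetical): a sliding-window buffer that
# drops the oldest messages (keeping the system prompt) once over the limit:
# buffer = MemoryBuffer(system_prompt="You are a clinical examiner.", max_token_limit=200)
# buffer.add(make_message("user", "The patient presents with RLQ pain."))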
def json2dict(string: str) -> dict:
    """Returns a dictionary parsed from a string of JSON; an empty dict on failure."""
    try:
        return json.loads(string)
    except json.decoder.JSONDecodeError as e:
        print(f"Error: JSONDecodeError: {e}")
        return {}
def print_help(num_nodes, color):
"""
Prints the help message for the AI assistant.
"""
colorama.init()
print(color + "The AI assistant presents a clinical case and asks for a diagnosis.")
print(
color + "You need to explore the case by asking questions to the AI assistant."
)
print(
color
+ "You have to ask questions in a logical order, conforming to the clinical guidelines."
)
    print(
        color
        + "You need to minimize the number of jumps between subjects, while covering as many subjects as possible."
    )
    print(color + f"There are a total of {num_nodes} visitable nodes in the tree.")
    print(
        color
        + "You have to explore the tree as much as possible while avoiding jumps and excessive travelling."
    )
print(Style.RESET_ALL)
def make_question(template=JSON_TEMPLATE, role="user", **kwargs) -> dict:
prompt = make_prompt(template=template, **kwargs)
message = make_message(role, prompt)
return message
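# Example (sketch; the question and key are hypothetical): asking for a
# structured answer via JSON_TEMPLATE and parsing it back:
# q = make_question(question="What is the most likely diagnosis?", keys='"diagnosis"')
# answer = json2dict(chat([q]))  # -> {"diagnosis": "..."} on success, {} on failure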
# A debugging decorator; functools.wraps preserves the decorated function's
# name and docstring. The decorator takes DEBUG as an argument to turn
# debugging on or off.
def debug(DEBUG, print_func, measure_time=True):
def decorator(func):
@ft.wraps(func)
def wrapper(*args, **kwargs):
if DEBUG:
print_func(f"\nCalling {func.__name__}")
if measure_time and DEBUG:
start = time.time()
result = func(*args, **kwargs)
if measure_time and DEBUG:
end = time.time()
print_func(f"Elapsed time: {end - start:.2f}s")
if DEBUG:
print_func(f"Returning {func.__name__}")
return result
return wrapper
return decorator
# to use the decorator, add @debug(DEBUG, print_func) above the function definition
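# Example (sketch; DEBUG and slow_add are hypothetical names):
# DEBUG = True
# @debug(DEBUG, print_func=print)
# def slow_add(a, b):
#     time.sleep(0.5)
#     return a + b
# slow_add(1, 2)  # prints the call marker, the elapsed time, and the return marker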