import colorama from colorama import Fore, Style import openai from tenacity import retry, stop_after_attempt, wait_fixed import json import os import tiktoken import functools as ft import time JSON_TEMPLATE = """ {question} The required key(s) are: {keys}. Only and only respond with the key(s) and value(s) mentioned above. Your answer in valid JSON format:\n """ MODEL_COST_DICT = { "gpt-3.5-turbo": { "input": 0.0015, "output": 0.002, }, "gpt-4": { "input": 0.03, "output": 0.06, }, } def set_api_key(key=None): """Sets the OpenAI API key.""" if key is None: key = os.environ.get("OPENAI_API_KEY") openai.api_key = key def num_tokens_from_string(string: str, encoding_name: str) -> int: """Returns the number of tokens in a text string.""" encoding = tiktoken.get_encoding(encoding_name) num_tokens = len(encoding.encode(string)) return num_tokens def num_tokens_from_messages(messages: list[dict], model="gpt-3.5-turbo-0613"): """Returns the number of tokens used by a list of messages.""" try: encoding = tiktoken.encoding_for_model(model) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") if model == "gpt-3.5-turbo-0613": # note: future models may deviate from this num_tokens = 0 for message in messages: num_tokens += ( 4 # every message follows {role/name}\n{content}\n ) for key, value in message.items(): num_tokens += len(encoding.encode(value)) if key == "name": # if there's a name, the role is omitted num_tokens += -1 # role is always required and always 1 token num_tokens += 2 # every reply is primed with assistant return num_tokens else: raise NotImplementedError( f"""num_tokens_from_messages() is not presently implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" ) @retry(stop=stop_after_attempt(3), wait=wait_fixed(2)) def chat(messages: list[dict], model="gpt-3.5-turbo", temperature=0.0): response = openai.ChatCompletion().create( model=model, messages=messages, temperature=temperature, ) return response["choices"][0]["message"]["content"] def make_message(role: str, content: str) -> dict: return { "role": role, "content": content, } def make_prompt(template: str, **kwargs): return template.format(**kwargs) def unravel_messages(messages: list[dict]) -> list[str]: """Returns a string representation of a list of messages.""" return [f"{message['role']}: {message['content']}" for message in messages] class LLM: def __init__(self, model="gpt-3.5-turbo", temperature=0.0): self.model = model self.temperature = temperature self.token_counter = 0 self.cost = 0.0 @retry(stop=stop_after_attempt(3), wait=wait_fixed(2)) def chat(self, messages: list[dict]): response = openai.ChatCompletion().create( model=self.model, messages=messages, temperature=self.temperature, ) self.token_counter += int(response["usage"]["total_tokens"]) self.cost += ( response["usage"]["prompt_tokens"] / 1000 * MODEL_COST_DICT[self.model]["input"] + response["usage"]["completion_tokens"] / 1000 * MODEL_COST_DICT[self.model]["output"] ) return response["choices"][0]["message"]["content"] def reset(self): self.token_counter = 0 self.cost = 0.0 def __call__(self, messages: list[dict]): return self.chat(messages) class SummaryMemory: """ A class that manages a memory of messages and automatically summarizes them when the maximum token limit is reached. Attributes: max_token_limit (int): The maximum number of tokens allowed in the memory before summarization occurs. messages (list[dict]): A list of messages in the memory. model (str): The name of the GPT model to use for chat completion. ai_role (str): The role of the AI in the conversation. human_role (str): The role of the human in the conversation. auto_summarize (bool): Whether to automatically summarize the messages when the maximum token limit is reached. """ # ... summary_template = "Summarize the following messages into a paragraph and replace '{user}' with '{human_role}', and '{assistant}' with '{ai_role}':\n{messages}" def __init__( self, system_prompt="", max_token_limit=4000, model="gpt-3.5-turbo", ai_role="answer", human_role="question/exam", auto_summarize=False, ): self.max_token_limit = max_token_limit self.messages: list[dict] = [] self.model = model self.ai_role = ai_role self.human_role = human_role self.auto_summarize = auto_summarize self.system_prompt = system_prompt self.reset() def reset(self): self.messages = [self.system_prompt] def remove_last(self): if len(self.messages) > 1: # don't remove the system prompt self.messages.pop() def remove( self, index: int ): # don't remove the system prompt and start counting from 1 if index > 0 and index < len(self.messages): self.messages.pop(index) def replace(self, index: int, message: dict): if index > 0 and index < len(self.messages): self.messages[index] = message def change_system_prompt(self, new_prompt: str): self.system_prompt = new_prompt self.messages[0] = new_prompt def remove_first(self): # dont remove the system prompt if len(self.messages) > 1: self.messages.pop(1) # remove the first message after the system prompt def append(self, message: dict): total_tokens = num_tokens_from_messages(self.messages + [message]) while ( self.auto_summarize and total_tokens > self.max_token_limit ): # keep summarizing until we're under the limit self.summarize() total_tokens = num_tokens_from_messages(self.messages + [message]) self.messages.append(message) def summarize(self): prompt = make_prompt( self.summary_template, user="user", human_role=self.human_role, assistant="assistant", ai_role=self.ai_role, messages="\n".join( unravel_messages(self.messages[1:]) ), # don't include the system prompt ) summary = chat( messages=[make_message("user", prompt)], model=self.model, ) self.reset() self.append(make_message("user", summary)) def get_messages(self): return self.messages[1:] # don't include the system prompt def get_unraveled_messages(self): return unravel_messages(self.messages[1:]) class MemoryBuffer: """ A class that manages a buffer of messages and clips them to a maximum token limit. Attributes: max_token_limit (int): The maximum number of tokens allowed in the buffer. messages (list[dict]): A list of messages in the buffer. """ def __init__( self, system_prompt, max_token_limit=1000, ): """ Initializes a new instance of the MemoryBuffer class. Args: max_token_limit (int, optional): The maximum number of tokens allowed in the buffer. Defaults to 1000. """ self.max_token_limit = max_token_limit self.messages = [] self.system_prompt = system_prompt self.reset() def reset(self): """ Resets the buffer by clearing all messages. """ self.messages = [self.system_prompt] def add(self, message: dict): """ Adds a message to the buffer and clips the buffer to the maximum token limit. Args: message (dict): The message to add to the buffer. """ total_tokens = num_tokens_from_messages(self.messages + [message]) if total_tokens > self.max_token_limit: # clip the messages to the max token limit # from the end of the list # remove messages from the beginning of the list # until the total number of tokens is less than the max token limit while total_tokens > self.max_token_limit: self.messages = self.messages[1:] total_tokens = num_tokens_from_messages(self.messages + [message]) self.messages.append(message) def remove(self, message: dict): """ Removes a message from the buffer. Args: message (dict): The message to remove from the buffer. """ if message in self.messages: self.messages.remove(message) def remove_last(self): """ Removes the last message from the buffer. """ if len(self.messages) > 0: self.messages.pop() def remove_first(self): """ Removes the first message from the buffer. """ if len(self.messages) > 0: self.messages.pop(0) def json2dict(string: str) -> dict: """Returns a dictionary of variables from a string containing JSON.""" try: return json.loads(string) except json.decoder.JSONDecodeError: print("Error: JSONDecodeError") return {} def print_help(num_nodes, color): """ Prints the help message for the AI assistant. """ colorama.init() print(color + "The AI assistant presents a clinical case and asks for a diagnosis.") print( color + "You need to explore the case by asking questions to the AI assistant." ) print( color + "You have to ask questions in a logical order, conforming to the clinical guidelines." ) print( color + "You need to minimize the number of jump between subjects, while covering as many subjects as possible." ) print(color + f"there are a total of {num_nodes} visitable nodes in the tree") print( color + "you have to explore the tree as much as possible while avoiding jumps and travelling excessively." ) print(Style.RESET_ALL) def make_question(template=JSON_TEMPLATE, role="user", **kwargs) -> dict: prompt = make_prompt(template=template, **kwargs) message = make_message(role, prompt) return message # a debugging decorator and use functools to preserve the function name and docstring # the decorator gets DEBUG as an argument to turn on or off debugging def debug(DEBUG, print_func, measure_time=True): def decorator(func): @ft.wraps(func) def wrapper(*args, **kwargs): if DEBUG: print_func(f"\nCalling {func.__name__}") if measure_time and DEBUG: start = time.time() result = func(*args, **kwargs) if measure_time and DEBUG: end = time.time() print_func(f"Elapsed time: {end - start:.2f}s") if DEBUG: print_func(f"Returning {func.__name__}") return result return wrapper return decorator # to use the decorator, add @debug(DEBUG) above the function definition