# OpenAI chat helpers: prompt templates, token/cost accounting, and conversation memory.
# NOTE(review): removed paste-site artifacts ("Spaces:", "Runtime error") that were not part of the module.
import functools as ft
import json
import os
import time

import colorama
from colorama import Fore, Style
import openai
import tiktoken
from tenacity import retry, stop_after_attempt, wait_fixed
# Prompt template that instructs the model to reply with ONLY a JSON object
# containing the requested keys. Placeholders: {question}, {keys}.
JSON_TEMPLATE = """
{question}
The required key(s) are: {keys}.
Only and only respond with the key(s) and value(s) mentioned above.
Your answer in valid JSON format:\n
"""
# USD price per 1K tokens, split into prompt ("input") and completion ("output").
# NOTE(review): prices drift over time — confirm against OpenAI's current pricing page.
MODEL_COST_DICT = {
    "gpt-3.5-turbo": {
        "input": 0.0015,
        "output": 0.002,
    },
    "gpt-4": {
        "input": 0.03,
        "output": 0.06,
    },
}
def set_api_key(key=None):
    """Set the OpenAI API key, falling back to the OPENAI_API_KEY env var when omitted."""
    openai.api_key = key if key is not None else os.environ.get("OPENAI_API_KEY")
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return how many tokens *string* encodes to under the named tiktoken encoding."""
    encoder = tiktoken.get_encoding(encoding_name)
    return len(encoder.encode(string))
def num_tokens_from_messages(messages: list[dict], model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of chat messages.

    Implements OpenAI's cookbook recipe for counting tokens consumed by the
    chat format (message framing plus content), not just the raw text.

    Args:
        messages: list of message dicts ({"role": ..., "content": ..., ["name": ...]}).
        model: model name used to pick the tokenizer and framing constants.

    Raises:
        NotImplementedError: for models whose framing constants are unknown.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the encoding used by the gpt-3.5/gpt-4 family.
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo-0613":  # note: future models may deviate from this
        # BUG FIX: the previous constants (4 per message, -1 per name, +2 priming)
        # were the gpt-3.5-turbo-0301 values; per the OpenAI cookbook, -0613 uses
        # 3 tokens per message, +1 per name, and +3 for the assistant reply priming.
        tokens_per_message = 3  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = 1
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
def chat(messages: list[dict], model="gpt-3.5-turbo", temperature=0.0):
    """Send *messages* to the chat completion API and return the reply text."""
    completion = openai.ChatCompletion().create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    # Only the text of the first (and only requested) choice is returned.
    return completion["choices"][0]["message"]["content"]
def make_message(role: str, content: str) -> dict:
    """Build a chat message dict from a role and its content."""
    return dict(role=role, content=content)
def make_prompt(template: str, **kwargs):
    """Fill the placeholders of *template* from keyword arguments."""
    return template.format_map(kwargs)
def unravel_messages(messages: list[dict]) -> list[str]:
    """Render each message as a "role: content" line."""
    rendered = []
    for msg in messages:
        rendered.append("{}: {}".format(msg["role"], msg["content"]))
    return rendered
class LLM:
    """Thin wrapper around the chat completion API that tracks usage.

    Accumulates the total token count and an estimated USD cost (using the
    per-1K-token rates in MODEL_COST_DICT) across calls until reset().
    """

    def __init__(self, model="gpt-3.5-turbo", temperature=0.0):
        self.model = model
        self.temperature = temperature
        self.token_counter = 0  # total tokens consumed so far
        self.cost = 0.0  # estimated spend in USD so far

    def chat(self, messages: list[dict]):
        """Send *messages* to the API, update usage counters, and return the reply text."""
        response = openai.ChatCompletion().create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        usage = response["usage"]
        self.token_counter += int(usage["total_tokens"])
        rates = MODEL_COST_DICT[self.model]
        # Prompt and completion tokens are billed at different per-1K rates.
        self.cost += (
            usage["prompt_tokens"] / 1000 * rates["input"]
            + usage["completion_tokens"] / 1000 * rates["output"]
        )
        return response["choices"][0]["message"]["content"]

    def reset(self):
        """Zero the token and cost counters."""
        self.token_counter = 0
        self.cost = 0.0

    def __call__(self, messages: list[dict]):
        """Calling the instance is shorthand for chat()."""
        return self.chat(messages)
class SummaryMemory:
    """
    A class that manages a memory of messages and automatically summarizes them when the maximum token limit is reached.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the memory before summarization occurs.
        messages (list[dict]): A list of messages in the memory; messages[0] is always the system prompt.
        model (str): The name of the GPT model to use for chat completion.
        ai_role (str): The role of the AI in the conversation.
        human_role (str): The role of the human in the conversation.
        auto_summarize (bool): Whether to automatically summarize the messages when the maximum token limit is reached.
    """

    summary_template = "Summarize the following messages into a paragraph and replace '{user}' with '{human_role}', and '{assistant}' with '{ai_role}':\n{messages}"

    def __init__(
        self,
        system_prompt="",
        max_token_limit=4000,
        model="gpt-3.5-turbo",
        ai_role="answer",
        human_role="question/exam",
        auto_summarize=False,
    ):
        self.max_token_limit = max_token_limit
        self.messages: list[dict] = []
        self.model = model
        self.ai_role = ai_role
        self.human_role = human_role
        self.auto_summarize = auto_summarize
        self.system_prompt = system_prompt
        self.reset()

    @staticmethod
    def _as_system_message(prompt):
        # Accept either a raw string or a pre-built message dict, so that
        # self.messages only ever contains message dicts.
        if isinstance(prompt, dict):
            return prompt
        return {"role": "system", "content": prompt}

    def reset(self):
        """Clear the memory, keeping only the system prompt."""
        # BUG FIX: the system prompt used to be stored as a bare string, which
        # made num_tokens_from_messages() crash in append() (it calls .items()
        # on every entry). Store it as a proper "system" message dict instead.
        self.messages = [self._as_system_message(self.system_prompt)]

    def remove_last(self):
        """Remove the most recent message (the system prompt is never removed)."""
        if len(self.messages) > 1:  # don't remove the system prompt
            self.messages.pop()

    def remove(
        self, index: int
    ):  # don't remove the system prompt and start counting from 1
        """Remove the message at *index* (index 0, the system prompt, is protected)."""
        if index > 0 and index < len(self.messages):
            self.messages.pop(index)

    def replace(self, index: int, message: dict):
        """Replace the message at *index* (index 0, the system prompt, is protected)."""
        if index > 0 and index < len(self.messages):
            self.messages[index] = message

    def change_system_prompt(self, new_prompt: str):
        """Swap in a new system prompt without touching the rest of the memory."""
        self.system_prompt = new_prompt
        self.messages[0] = self._as_system_message(new_prompt)

    def remove_first(self):
        """Remove the oldest non-system message."""
        # don't remove the system prompt
        if len(self.messages) > 1:
            self.messages.pop(1)  # remove the first message after the system prompt

    def append(self, message: dict):
        """Append *message*, summarizing first if auto_summarize is on and the token budget is exceeded."""
        total_tokens = num_tokens_from_messages(self.messages + [message])
        while (
            self.auto_summarize and total_tokens > self.max_token_limit
        ):  # keep summarizing until we're under the limit
            self.summarize()
            total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)

    def summarize(self):
        """Replace all non-system messages with a single LLM-generated summary message."""
        prompt = make_prompt(
            self.summary_template,
            user="user",
            human_role=self.human_role,
            assistant="assistant",
            ai_role=self.ai_role,
            messages="\n".join(
                unravel_messages(self.messages[1:])
            ),  # don't include the system prompt
        )
        summary = chat(
            messages=[make_message("user", prompt)],
            model=self.model,
        )
        self.reset()
        self.append(make_message("user", summary))

    def get_messages(self):
        """Return the conversation without the system prompt."""
        return self.messages[1:]  # don't include the system prompt

    def get_unraveled_messages(self):
        """Return the conversation (minus system prompt) as "role: content" strings."""
        return unravel_messages(self.messages[1:])
class MemoryBuffer:
    """
    A class that manages a buffer of messages and clips them to a maximum token limit.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the buffer.
        messages (list[dict]): A list of messages in the buffer; messages[0] is the system prompt.
    """

    def __init__(
        self,
        system_prompt,
        max_token_limit=1000,
    ):
        """
        Initializes a new instance of the MemoryBuffer class.

        Args:
            system_prompt: the system prompt, either a raw string or a pre-built message dict.
            max_token_limit (int, optional): The maximum number of tokens allowed in the buffer. Defaults to 1000.
        """
        self.max_token_limit = max_token_limit
        self.messages = []
        self.system_prompt = system_prompt
        self.reset()

    def reset(self):
        """
        Resets the buffer, keeping only the system prompt.
        """
        # BUG FIX: a string system prompt used to be stored as-is, which made
        # num_tokens_from_messages() crash in add() (it calls .items() on every
        # entry). Wrap strings into a proper "system" message dict.
        if isinstance(self.system_prompt, dict):
            self.messages = [self.system_prompt]
        else:
            self.messages = [{"role": "system", "content": self.system_prompt}]

    def add(self, message: dict):
        """
        Adds a message to the buffer and clips the buffer to the maximum token limit.

        Args:
            message (dict): The message to add to the buffer.
        """
        total_tokens = num_tokens_from_messages(self.messages + [message])
        # Evict the oldest non-system messages until the buffer plus the new
        # message fits the token budget.
        # BUG FIX: the old loop dropped self.messages[0] (the system prompt)
        # and spun forever when the new message alone exceeded the limit; now
        # it evicts index 1 and stops once only the system prompt remains.
        while total_tokens > self.max_token_limit and len(self.messages) > 1:
            self.messages.pop(1)
            total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)

    def remove(self, message: dict):
        """
        Removes a message from the buffer.

        Args:
            message (dict): The message to remove from the buffer.
        """
        if message in self.messages:
            self.messages.remove(message)

    def remove_last(self):
        """
        Removes the last message from the buffer.
        """
        if len(self.messages) > 0:
            self.messages.pop()

    def remove_first(self):
        """
        Removes the first message from the buffer.

        NOTE(review): unlike SummaryMemory.remove_first, this pops index 0 —
        the system prompt — which reset() later restores; confirm this is the
        intended behavior with callers.
        """
        if len(self.messages) > 0:
            self.messages.pop(0)
def json2dict(string: str) -> dict:
    """Parse *string* as JSON; on a decode error, print a notice and return {}."""
    try:
        parsed = json.loads(string)
    except json.decoder.JSONDecodeError:
        print("Error: JSONDecodeError")
        return {}
    return parsed
def print_help(num_nodes, color):
    """
    Prints the help message for the AI assistant, each line tinted with *color*.
    """
    colorama.init()
    lines = (
        "The AI assistant presents a clinical case and asks for a diagnosis.",
        "You need to explore the case by asking questions to the AI assistant.",
        "You have to ask questions in a logical order, conforming to the clinical guidelines.",
        "You need to minimize the number of jump between subjects, while covering as many subjects as possible.",
        f"there are a total of {num_nodes} visitable nodes in the tree",
        "you have to explore the tree as much as possible while avoiding jumps and travelling excessively.",
    )
    for line in lines:
        print(color + line)
    print(Style.RESET_ALL)
def make_question(template=JSON_TEMPLATE, role="user", **kwargs) -> dict:
    """Format *template* with kwargs and wrap the result in a message with *role*."""
    return make_message(role, make_prompt(template=template, **kwargs))
# a debugging decorator and use functools to preserve the function name and docstring
# the decorator gets DEBUG as an argument to turn on or off debugging
def debug(DEBUG, print_func, measure_time=True):
    """Decorator factory that traces calls to the wrapped function.

    Args:
        DEBUG: when falsy, the wrapper passes the call straight through silently.
        print_func: callable used to emit trace lines (e.g. print or logger.debug).
        measure_time: when True (and DEBUG is on), also report wall-clock duration.
    """
    def decorator(func):
        # BUG FIX: the comment above promised functools, but ft.wraps was never
        # applied, so decorated functions lost their __name__ and __doc__.
        @ft.wraps(func)
        def wrapper(*args, **kwargs):
            if DEBUG:
                print_func(f"\nCalling {func.__name__}")
            if measure_time and DEBUG:
                start = time.time()
            result = func(*args, **kwargs)
            if measure_time and DEBUG:
                end = time.time()
                print_func(f"Elapsed time: {end - start:.2f}s")
            if DEBUG:
                print_func(f"Returning {func.__name__}")
            return result

        return wrapper

    return decorator
# to use the decorator, add @debug(DEBUG) above the function definition