import colorama
from colorama import Fore, Style
import openai
from tenacity import retry, stop_after_attempt, wait_fixed
import json
import os
import tiktoken
import functools as ft
import time

JSON_TEMPLATE = """
{question}
The required key(s) are: {keys}.
Respond with only the key(s) and value(s) mentioned above.
Your answer in valid JSON format:\n
"""

# Cost in USD per 1,000 tokens, split into prompt ("input") and completion ("output") pricing.
MODEL_COST_DICT = {
    "gpt-3.5-turbo": {
        "input": 0.0015,
        "output": 0.002,
    },
    "gpt-4": {
        "input": 0.03,
        "output": 0.06,
    },
}


def set_api_key(key=None):
    """Sets the OpenAI API key, falling back to the OPENAI_API_KEY environment variable."""
    if key is None:
        key = os.environ.get("OPENAI_API_KEY")
    openai.api_key = key


def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
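
# Illustrative sketch: counting tokens of a plain string with the "cl100k_base"
# encoding (the tiktoken encoding used by the gpt-3.5-turbo / gpt-4 chat models);
# the sample text is a placeholder:
#
#   print(num_tokens_from_string("How many tokens am I?", "cl100k_base"))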


def num_tokens_from_messages(messages: list[dict], model="gpt-3.5-turbo-0613"):
    """Returns the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo-0613":  # note: future models may deviate from this
        num_tokens = 0
        for message in messages:
            num_tokens += (
                4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            )
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
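
# Illustrative sketch of counting tokens for a short exchange; the roles and
# contents below are placeholders, not part of the app:
#
#   example = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ]
#   print(num_tokens_from_messages(example))  # total tokens under the 0613 scheme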


def chat(messages: list[dict], model="gpt-3.5-turbo", temperature=0.0):
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return response["choices"][0]["message"]["content"]


def make_message(role: str, content: str) -> dict:
    return {
        "role": role,
        "content": content,
    }


def make_prompt(template: str, **kwargs):
    return template.format(**kwargs)


def unravel_messages(messages: list[dict]) -> list[str]:
    """Returns a list of "role: content" strings, one per message."""
    return [f"{message['role']}: {message['content']}" for message in messages]


class LLM:
    """Thin wrapper around the chat completion API that tracks total tokens and accumulated cost."""

    def __init__(self, model="gpt-3.5-turbo", temperature=0.0):
        self.model = model
        self.temperature = temperature
        self.token_counter = 0
        self.cost = 0.0

    def chat(self, messages: list[dict]):
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        self.token_counter += int(response["usage"]["total_tokens"])
        self.cost += (
            response["usage"]["prompt_tokens"]
            / 1000
            * MODEL_COST_DICT[self.model]["input"]
            + response["usage"]["completion_tokens"]
            / 1000
            * MODEL_COST_DICT[self.model]["output"]
        )
        return response["choices"][0]["message"]["content"]

    def reset(self):
        self.token_counter = 0
        self.cost = 0.0

    def __call__(self, messages: list[dict]):
        return self.chat(messages)
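
# Illustrative sketch (assumes a valid API key has been set via set_api_key;
# the prompt text is a placeholder):
#
#   llm = LLM(model="gpt-3.5-turbo")
#   reply = llm([make_message("user", "Say hello.")])
#   print(reply, llm.token_counter, llm.cost)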


class SummaryMemory:
    """
    A class that manages a memory of messages and automatically summarizes them when the maximum token limit is reached.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the memory before summarization occurs.
        messages (list[dict]): A list of messages in the memory.
        model (str): The name of the GPT model to use for chat completion.
        ai_role (str): The role of the AI in the conversation.
        human_role (str): The role of the human in the conversation.
        auto_summarize (bool): Whether to automatically summarize the messages when the maximum token limit is reached.
    """

    # ...
    summary_template = "Summarize the following messages into a paragraph and replace '{user}' with '{human_role}', and '{assistant}' with '{ai_role}':\n{messages}"

    def __init__(
        self,
        system_prompt="",
        max_token_limit=4000,
        model="gpt-3.5-turbo",
        ai_role="answer",
        human_role="question/exam",
        auto_summarize=False,
    ):
        self.max_token_limit = max_token_limit
        self.messages: list[dict] = []
        self.model = model
        self.ai_role = ai_role
        self.human_role = human_role
        self.auto_summarize = auto_summarize
        self.system_prompt = system_prompt
        self.reset()

    def reset(self):
        self.messages = [self.system_prompt]

    def remove_last(self):
        if len(self.messages) > 1:  # don't remove the system prompt
            self.messages.pop()

    def remove(
        self, index: int
    ):  # don't remove the system prompt and start counting from 1
        if 0 < index < len(self.messages):
            self.messages.pop(index)

    def replace(self, index: int, message: dict):
        if 0 < index < len(self.messages):
            self.messages[index] = message

    def change_system_prompt(self, new_prompt: str):
        self.system_prompt = new_prompt
        self.messages[0] = new_prompt

    def remove_first(self):
        # don't remove the system prompt
        if len(self.messages) > 1:
            self.messages.pop(1)  # remove the first message after the system prompt

    def append(self, message: dict):
        total_tokens = num_tokens_from_messages(self.messages + [message])
        while (
            self.auto_summarize and total_tokens > self.max_token_limit
        ):  # keep summarizing until we're under the limit
            self.summarize()
            total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)

    def summarize(self):
        prompt = make_prompt(
            self.summary_template,
            user="user",
            human_role=self.human_role,
            assistant="assistant",
            ai_role=self.ai_role,
            messages="\n".join(
                unravel_messages(self.messages[1:])
            ),  # don't include the system prompt
        )
        summary = chat(
            messages=[make_message("user", prompt)],
            model=self.model,
        )
        self.reset()
        self.append(make_message("user", summary))

    def get_messages(self):
        return self.messages[1:]  # don't include the system prompt

    def get_unraveled_messages(self):
        return unravel_messages(self.messages[1:])
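
# Illustrative sketch; the memory stores the system prompt as its first entry,
# so it is passed here as a message dict, and the contents are placeholders:
#
#   memory = SummaryMemory(
#       system_prompt=make_message("system", "You are a clinical examiner."),
#       auto_summarize=True,
#   )
#   memory.append(make_message("user", "Describe the presenting complaint."))
#   print(memory.get_unraveled_messages())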


class MemoryBuffer:
    """
    A class that manages a buffer of messages and clips them to a maximum token limit.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the buffer.
        messages (list[dict]): A list of messages in the buffer.
    """

    def __init__(
        self,
        system_prompt,
        max_token_limit=1000,
    ):
        """
        Initializes a new instance of the MemoryBuffer class.

        Args:
            system_prompt: The system prompt, stored as the first message in the buffer.
            max_token_limit (int, optional): The maximum number of tokens allowed in the buffer. Defaults to 1000.
        """
        self.max_token_limit = max_token_limit
        self.messages = []
        self.system_prompt = system_prompt
        self.reset()

    def reset(self):
        """
        Resets the buffer back to just the system prompt.
        """
        self.messages = [self.system_prompt]

    def add(self, message: dict):
        """
        Adds a message to the buffer and clips the buffer to the maximum token limit.

        Args:
            message (dict): The message to add to the buffer.
        """
        total_tokens = num_tokens_from_messages(self.messages + [message])
        if total_tokens > self.max_token_limit:
            # drop messages from the beginning of the list until the total
            # number of tokens falls below the max token limit
            # (note: this can also drop the system prompt at index 0)
            while total_tokens > self.max_token_limit:
                self.messages = self.messages[1:]
                total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)

    def remove(self, message: dict):
        """
        Removes a message from the buffer.

        Args:
            message (dict): The message to remove from the buffer.
        """
        if message in self.messages:
            self.messages.remove(message)

    def remove_last(self):
        """
        Removes the last message from the buffer.
        """
        if len(self.messages) > 0:
            self.messages.pop()

    def remove_first(self):
        """
        Removes the first message from the buffer.
        """
        if len(self.messages) > 0:
            self.messages.pop(0)
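
# Illustrative sketch; the message contents are placeholders:
#
#   buffer = MemoryBuffer(
#       system_prompt=make_message("system", "You are a helpful assistant."),
#       max_token_limit=1000,
#   )
#   buffer.add(make_message("user", "First question"))
#   buffer.add(make_message("assistant", "First answer"))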


def json2dict(string: str) -> dict:
    """Returns a dictionary of variables from a string containing JSON."""
    try:
        return json.loads(string)
    except json.decoder.JSONDecodeError:
        print("Error: JSONDecodeError")
        return {}


def print_help(num_nodes, color):
    """
    Prints the help message for the AI assistant.
    """
    colorama.init()
    print(color + "The AI assistant presents a clinical case and asks for a diagnosis.")
    print(
        color + "You need to explore the case by asking questions to the AI assistant."
    )
    print(
        color
        + "You have to ask questions in a logical order, conforming to the clinical guidelines."
    )
    print(
        color
        + "You need to minimize the number of jumps between subjects, while covering as many subjects as possible."
    )
    print(color + f"There are a total of {num_nodes} visitable nodes in the tree.")
    print(
        color
        + "You have to explore the tree as much as possible while avoiding jumps and travelling excessively."
    )
    print(Style.RESET_ALL)


def make_question(template=JSON_TEMPLATE, role="user", **kwargs) -> dict:
    prompt = make_prompt(template=template, **kwargs)
    message = make_message(role, prompt)
    return message
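
# Illustrative sketch of how JSON_TEMPLATE gets filled in; the question text and
# key name are placeholders, and the chat call assumes an API key has been set:
#
#   question = make_question(
#       question="What is the most likely diagnosis?",
#       keys='"diagnosis"',
#   )
#   answer = json2dict(chat([question]))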


# A debugging decorator; functools.wraps preserves the decorated function's
# name and docstring. The decorator takes DEBUG as an argument to turn
# debugging output on or off.
def debug(DEBUG, print_func, measure_time=True):
    def decorator(func):
        @ft.wraps(func)
        def wrapper(*args, **kwargs):
            if DEBUG:
                print_func(f"\nCalling {func.__name__}")
            if measure_time and DEBUG:
                start = time.time()
            result = func(*args, **kwargs)
            if measure_time and DEBUG:
                end = time.time()
                print_func(f"Elapsed time: {end - start:.2f}s")
            if DEBUG:
                print_func(f"Returning {func.__name__}")
            return result

        return wrapper

    return decorator


# To use the decorator, add @debug(DEBUG, print_func) above the function definition.
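
# Illustrative sketch; DEBUG and the decorated function are placeholders:
#
#   DEBUG = True
#
#   @debug(DEBUG, print)
#   def slow_add(a, b):
#       time.sleep(0.1)
#       return a + b
#
#   slow_add(1, 2)  # prints the call, elapsed time, and return marker when DEBUG is True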