import colorama
from colorama import Fore, Style
import openai
from tenacity import retry, stop_after_attempt, wait_fixed
import json
import os
import tiktoken
import functools as ft
import time
JSON_TEMPLATE = """
{question}
The required key(s) are: {keys}.
Only and only respond with the key(s) and value(s) mentioned above.
Your answer in valid JSON format:\n
"""
# API pricing in USD per 1,000 tokens.
MODEL_COST_DICT = {
    "gpt-3.5-turbo": {
        "input": 0.0015,
        "output": 0.002,
    },
    "gpt-4": {
        "input": 0.03,
        "output": 0.06,
    },
}
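# Worked example: a gpt-3.5-turbo call with 1,000 prompt tokens and 500
# completion tokens costs 1.0 * 0.0015 + 0.5 * 0.002 = $0.0025.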
def set_api_key(key=None):
    """Sets the OpenAI API key."""
    if key is None:
        key = os.environ.get("OPENAI_API_KEY")
    openai.api_key = key
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
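# Illustrative usage (a sketch, not part of the original module); "cl100k_base"
# is the encoding used by the gpt-3.5-turbo and gpt-4 model families:
#
#   n = num_tokens_from_string("How many tokens is this?", "cl100k_base")
#   print(n)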
def num_tokens_from_messages(messages: list[dict], model="gpt-3.5-turbo-0613"):
    """Returns the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    if model == "gpt-3.5-turbo-0613":  # note: future models may deviate from this
        num_tokens = 0
        for message in messages:
            num_tokens += (
                4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            )
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
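# Illustrative usage (a sketch, not part of the original module): count the
# tokens a message list will consume before sending it to the API.
#
#   msgs = [{"role": "user", "content": "Hello!"}]
#   print(num_tokens_from_messages(msgs))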
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def chat(messages: list[dict], model="gpt-3.5-turbo", temperature=0.0):
    """Sends messages to the chat completion API and returns the reply text."""
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    return response["choices"][0]["message"]["content"]
def make_message(role: str, content: str) -> dict:
    """Builds a chat message dict from a role and a content string."""
    return {
        "role": role,
        "content": content,
    }

def make_prompt(template: str, **kwargs):
    """Fills a prompt template with the given keyword arguments."""
    return template.format(**kwargs)
def unravel_messages(messages: list[dict]) -> list[str]:
    """Returns a list of "role: content" strings, one per message."""
    return [f"{message['role']}: {message['content']}" for message in messages]
class LLM:
    """A chat completion wrapper that tracks cumulative token usage and cost."""

    def __init__(self, model="gpt-3.5-turbo", temperature=0.0):
        self.model = model
        self.temperature = temperature
        self.token_counter = 0
        self.cost = 0.0

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    def chat(self, messages: list[dict]):
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        self.token_counter += int(response["usage"]["total_tokens"])
        # prices in MODEL_COST_DICT are per 1,000 tokens
        self.cost += (
            response["usage"]["prompt_tokens"]
            / 1000
            * MODEL_COST_DICT[self.model]["input"]
            + response["usage"]["completion_tokens"]
            / 1000
            * MODEL_COST_DICT[self.model]["output"]
        )
        return response["choices"][0]["message"]["content"]

    def reset(self):
        self.token_counter = 0
        self.cost = 0.0

    def __call__(self, messages: list[dict]):
        return self.chat(messages)
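# Illustrative usage (a sketch; assumes a valid OPENAI_API_KEY is available):
#
#   set_api_key()
#   llm = LLM(model="gpt-3.5-turbo")
#   reply = llm([make_message("user", "Hello!")])
#   print(llm.token_counter, f"${llm.cost:.4f}")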
class SummaryMemory:
    """
    A class that manages a memory of messages and automatically summarizes them when the maximum token limit is reached.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the memory before summarization occurs.
        messages (list[dict]): A list of messages in the memory.
        model (str): The name of the GPT model to use for chat completion.
        ai_role (str): The role of the AI in the conversation.
        human_role (str): The role of the human in the conversation.
        auto_summarize (bool): Whether to automatically summarize the messages when the maximum token limit is reached.
    """

    # ...
    summary_template = "Summarize the following messages into a paragraph and replace '{user}' with '{human_role}', and '{assistant}' with '{ai_role}':\n{messages}"
    def __init__(
        self,
        system_prompt="",
        max_token_limit=4000,
        model="gpt-3.5-turbo",
        ai_role="answer",
        human_role="question/exam",
        auto_summarize=False,
    ):
        self.max_token_limit = max_token_limit
        self.messages: list[dict] = []
        self.model = model
        self.ai_role = ai_role
        self.human_role = human_role
        self.auto_summarize = auto_summarize
        self.system_prompt = system_prompt
        self.reset()
    def reset(self):
        # wrap the system prompt string in a message dict so that
        # num_tokens_from_messages() can iterate over it like any other message
        self.messages = [make_message("system", self.system_prompt)]

    def remove_last(self):
        if len(self.messages) > 1:  # don't remove the system prompt
            self.messages.pop()

    def remove(
        self, index: int
    ):  # don't remove the system prompt and start counting from 1
        if index > 0 and index < len(self.messages):
            self.messages.pop(index)

    def replace(self, index: int, message: dict):
        if index > 0 and index < len(self.messages):
            self.messages[index] = message

    def change_system_prompt(self, new_prompt: str):
        self.system_prompt = new_prompt
        self.messages[0] = make_message("system", new_prompt)

    def remove_first(self):
        # don't remove the system prompt
        if len(self.messages) > 1:
            self.messages.pop(1)  # remove the first message after the system prompt
    def append(self, message: dict):
        total_tokens = num_tokens_from_messages(self.messages + [message])
        while (
            self.auto_summarize and total_tokens > self.max_token_limit
        ):  # keep summarizing until we're under the limit
            self.summarize()
            total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)
    def summarize(self):
        prompt = make_prompt(
            self.summary_template,
            user="user",
            human_role=self.human_role,
            assistant="assistant",
            ai_role=self.ai_role,
            messages="\n".join(
                unravel_messages(self.messages[1:])
            ),  # don't include the system prompt
        )
        summary = chat(
            messages=[make_message("user", prompt)],
            model=self.model,
        )
        self.reset()
        self.append(make_message("user", summary))
    def get_messages(self):
        return self.messages[1:]  # don't include the system prompt

    def get_unraveled_messages(self):
        return unravel_messages(self.messages[1:])
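# Illustrative usage (a sketch, not part of the original module; summarization
# calls the API, so it assumes a valid OPENAI_API_KEY):
#
#   memory = SummaryMemory(system_prompt="You are a tutor.", auto_summarize=True)
#   memory.append(make_message("user", "Explain osmosis."))
#   print(memory.get_unraveled_messages())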
class MemoryBuffer:
    """
    A class that manages a buffer of messages and clips them to a maximum token limit.

    Attributes:
        max_token_limit (int): The maximum number of tokens allowed in the buffer.
        messages (list[dict]): A list of messages in the buffer.
    """

    def __init__(
        self,
        system_prompt,
        max_token_limit=1000,
    ):
        """
        Initializes a new instance of the MemoryBuffer class.

        Args:
            system_prompt (str): The system prompt kept at the start of the buffer.
            max_token_limit (int, optional): The maximum number of tokens allowed in the buffer. Defaults to 1000.
        """
        self.max_token_limit = max_token_limit
        self.messages = []
        self.system_prompt = system_prompt
        self.reset()
    def reset(self):
        """
        Resets the buffer, keeping only the system prompt.
        """
        # wrap the system prompt string in a message dict so that
        # num_tokens_from_messages() can iterate over it like any other message
        self.messages = [make_message("system", self.system_prompt)]
    def add(self, message: dict):
        """
        Adds a message to the buffer and clips the buffer to the maximum token limit.

        Args:
            message (dict): The message to add to the buffer.
        """
        total_tokens = num_tokens_from_messages(self.messages + [message])
        # drop the oldest messages (but never the system prompt at index 0)
        # until the total number of tokens fits under the max token limit
        while total_tokens > self.max_token_limit and len(self.messages) > 1:
            self.messages.pop(1)
            total_tokens = num_tokens_from_messages(self.messages + [message])
        self.messages.append(message)
    def remove(self, message: dict):
        """
        Removes a message from the buffer.

        Args:
            message (dict): The message to remove from the buffer.
        """
        if message in self.messages:
            self.messages.remove(message)

    def remove_last(self):
        """
        Removes the last message from the buffer (never the system prompt).
        """
        if len(self.messages) > 1:
            self.messages.pop()

    def remove_first(self):
        """
        Removes the first message after the system prompt.
        """
        if len(self.messages) > 1:
            self.messages.pop(1)
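# Illustrative usage (a sketch, not part of the original module):
#
#   buffer = MemoryBuffer(system_prompt="You are a clinical examiner.",
#                         max_token_limit=1000)
#   buffer.add(make_message("user", "Present a case."))
#   buffer.add(make_message("assistant", "A 45-year-old presents with..."))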
def json2dict(string: str) -> dict:
    """Returns a dictionary of variables from a string containing JSON."""
    try:
        return json.loads(string)
    except json.decoder.JSONDecodeError:
        print("Error: JSONDecodeError")
        return {}
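# Illustrative usage (a sketch, not part of the original module):
#
#   json2dict('{"diagnosis": "pneumonia"}')  # {'diagnosis': 'pneumonia'}
#   json2dict("not json")                    # prints an error, returns {}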
def print_help(num_nodes, color):
    """
    Prints the help message for the AI assistant.
    """
    colorama.init()
    print(color + "The AI assistant presents a clinical case and asks for a diagnosis.")
    print(
        color + "You need to explore the case by asking questions to the AI assistant."
    )
    print(
        color
        + "You have to ask questions in a logical order, conforming to the clinical guidelines."
    )
    print(
        color
        + "You need to minimize the number of jumps between subjects, while covering as many subjects as possible."
    )
    print(color + f"There are a total of {num_nodes} visitable nodes in the tree.")
    print(
        color
        + "You have to explore the tree as much as possible while avoiding jumps and excessive travelling."
    )
    print(Style.RESET_ALL)
def make_question(template=JSON_TEMPLATE, role="user", **kwargs) -> dict:
    """Builds a chat message from a prompt template and its keyword arguments."""
    prompt = make_prompt(template=template, **kwargs)
    message = make_message(role, prompt)
    return message
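# Illustrative usage (a sketch, not part of the original module; JSON_TEMPLATE
# expects `question` and `keys` placeholders):
#
#   q = make_question(question="What is the diagnosis?", keys="diagnosis")
#   print(q["content"])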
# A debugging decorator; functools.wraps preserves the decorated function's
# name and docstring. The DEBUG argument turns debug output on or off.
def debug(DEBUG, print_func, measure_time=True):
    def decorator(func):
        @ft.wraps(func)
        def wrapper(*args, **kwargs):
            if DEBUG:
                print_func(f"\nCalling {func.__name__}")
            if measure_time and DEBUG:
                start = time.time()
            result = func(*args, **kwargs)
            if measure_time and DEBUG:
                end = time.time()
                print_func(f"Elapsed time: {end - start:.2f}s")
            if DEBUG:
                print_func(f"Returning {func.__name__}")
            return result

        return wrapper

    return decorator

# To use the decorator, add @debug(DEBUG, print_func) above the function definition.
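# Illustrative usage (a sketch, not part of the original module):
#
#   DEBUG = True
#
#   @debug(DEBUG, print)
#   def slow_add(a, b):
#       time.sleep(1)
#       return a + b
#
#   slow_add(1, 2)  # prints the call, elapsed time, and return announcements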