# ChatPaper / chatbot.py

import json
import os
from typing import Iterator, Optional

import openai
import requests
import tiktoken

from base_class import ChatbotEngine
from config import MAX_TOKEN_MODEL_MAP
from utils import get_filtered_keys_from_object


class ChatbotWrapper:
    """
    Wrapper around the official ChatGPT API.

    Based on https://github.com/ChatGPT-Hackers/revChatGPT
    """

    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
        proxy: Optional[str] = None,
        max_tokens: int = 3000,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        overhead_token: int = 96,
    ) -> None:
"""
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
"""
self.engine = engine
self.session = requests.Session()
self.api_key = api_key
self.system_prompt = system_prompt
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.reply_count = reply_count
self.max_limit = MAX_TOKEN_MODEL_MAP[self.engine]
self.overhead_token = overhead_token
if proxy:
self.session.proxies = {
"http": proxy,
"https": proxy,
}
self.conversation: dict = {
"default": [
{
"role": "system",
"content": system_prompt,
},
],
}
        if max_tokens > self.max_limit - self.overhead_token:
            raise Exception(
                f"Max tokens cannot be greater than {self.max_limit - self.overhead_token}")
if self.get_token_count("default") > self.max_tokens:
raise Exception("System prompt is too long")

    def add_to_conversation(
self,
message: str,
role: str,
convo_id: str = "default",
) -> None:
"""
Add a message to the conversation
"""
self.conversation[convo_id].append({"role": role, "content": message})

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation until it fits the token budget
        """
        # TODO: context condense with soft prompt tuning
        while (
            self.get_token_count(convo_id) > self.max_tokens
            and len(self.conversation[convo_id]) > 2
        ):
            # Keep the system message; drop the oldest QA pair.
            # TODO: optimal pop out based on similarity distance
            self.conversation[convo_id].pop(1)
            self.conversation[convo_id].pop(1)
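
    # Illustration (not executed): with messages [system, q1, a1, q2, a2] over
    # budget, one pass pops q1 and a1, leaving [system, q2, a2]; the loop
    # repeats until the count fits or only the system message and the latest
    # prompt remain.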

    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Get the token count of a conversation
        """
        if self.engine not in ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"]:
            raise NotImplementedError(f"Unsupported engine {self.engine}")
        encoding = tiktoken.encoding_for_model(self.engine)
        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 4
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens
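
    # Worked example (content length illustrative): a single message
    # {"role": "system", "content": <20-token prompt>} counts as
    # 4 (framing) + 1 ("system") + 20 (content) + 2 (reply primer) = 27.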

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Get the remaining token budget for the completion
        """
        return self.max_tokens - self.get_token_count(convo_id)

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        dynamic_system_prompt: Optional[str] = None,
        **kwargs,
    ) -> Iterator[str]:
        """
        Ask a question and stream the answer chunk by chunk
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=dynamic_system_prompt)
        # Adjust the system prompt if a new one was supplied
        if dynamic_system_prompt is not None:
            self.conversation[convo_id][0]["content"] = dynamic_system_prompt
        self.add_to_conversation(prompt, role, convo_id=convo_id)
        print(f"total tokens: {self.get_token_count(convo_id)}")
self.__truncate_conversation(convo_id=convo_id)
# Get response
response = self.session.post(
os.environ.get(
"API_URL") or "https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
json={
"model": self.engine,
"messages": self.conversation[convo_id],
"stream": True,
# kwargs
"temperature": kwargs.get("temperature", self.temperature),
"top_p": kwargs.get("top_p", self.top_p),
"presence_penalty": kwargs.get(
"presence_penalty",
self.presence_penalty,
),
"frequency_penalty": kwargs.get(
"frequency_penalty",
self.frequency_penalty,
),
"n": kwargs.get("n", self.reply_count),
"user": role,
"max_tokens": self. get_max_tokens(convo_id=convo_id),
},
stream=True,
)
if response.status_code != 200:
raise Exception(
f"Error: {response.status_code} {response.reason} {response.text}",
)
        response_role: Optional[str] = None
full_response: str = ""
for line in response.iter_lines():
if not line:
continue
# Remove "data: "
line = line.decode("utf-8")[6:]
if line == "[DONE]":
break
resp: dict = json.loads(line)
choices = resp.get("choices")
if not choices:
continue
delta = choices[0].get("delta")
if not delta:
continue
if "role" in delta:
response_role = delta["role"]
if "content" in delta:
content = delta["content"]
full_response += content
yield content
self.add_to_conversation(
full_response, response_role, convo_id=convo_id)
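
    # Usage sketch (hypothetical instance and prompt strings): ask_stream is a
    # generator, so chunks can be printed as they arrive:
    #
    #   for chunk in bot.ask_stream("Summarize section 3",
    #                               dynamic_system_prompt="You review papers."):
    #       print(chunk, end="", flush=True)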

    def ask(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        dynamic_system_prompt: Optional[str] = None,
        **kwargs,
    ) -> str:
        """
        Non-streaming ask: drain ask_stream and return the full reply
        """
response = self.ask_stream(
prompt=prompt,
role=role,
convo_id=convo_id,
dynamic_system_prompt=dynamic_system_prompt,
**kwargs,
)
full_response: str = "".join(response)
return full_response

    def rollback(self, n: int = 1, convo_id: str = "default") -> None:
"""
Rollback the conversation
"""
for _ in range(n):
self.conversation[convo_id].pop()

    def reset(self, convo_id: str = "default", system_prompt: Optional[str] = None) -> None:
"""
Reset the conversation
"""
self.conversation[convo_id] = [
{"role": "system", "content": system_prompt or self.system_prompt},
]

    def save(self, file: str, *keys: str) -> None:
"""
Save the Chatbot configuration to a JSON file
"""
with open(file, "w", encoding="utf-8") as f:
json.dump(
{
key: self.__dict__[key]
for key in get_filtered_keys_from_object(self, *keys)
},
f,
indent=2,
# saves session.proxies dict as session
default=lambda o: o.__dict__["proxies"],
)

    def load(self, file: str, *keys: str) -> None:
"""
Load the Chatbot configuration from a JSON file
"""
with open(file, encoding="utf-8") as f:
# load json, if session is in keys, load proxies
loaded_config = json.load(f)
keys = get_filtered_keys_from_object(self, *keys)
if "session" in keys and loaded_config["session"]:
self.session.proxies = loaded_config["session"]
keys = keys - {"session"}
self.__dict__.update({key: loaded_config[key] for key in keys})
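
# Round-trip sketch (hypothetical instance and file name): save persists the
# selected attributes and load restores them; the requests.Session is saved as
# its proxies dict and re-applied on load.
#
#   bot.save("bot_config.json", "engine", "temperature", "session")
#   bot.load("bot_config.json", "engine", "temperature", "session")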


class OpenAIChatbot(ChatbotEngine):
    def __init__(self, api_key: str,
                 engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
                 proxy: Optional[str] = None,
                 max_tokens: int = 3000,
                 temperature: float = 0.5,
                 top_p: float = 1.0,
                 presence_penalty: float = 0.0,
                 frequency_penalty: float = 0.0,
                 reply_count: int = 1,
                 system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
                 overhead_token: int = 96) -> None:
openai.api_key = api_key
self.api_key = api_key
self.engine = engine
self.proxy = proxy
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.reply_count = reply_count
self.system_prompt = system_prompt
self.bot = ChatbotWrapper(
api_key=self.api_key,
engine=self.engine,
proxy=self.proxy,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
reply_count=self.reply_count,
system_prompt=self.system_prompt,
overhead_token=overhead_token
)
self.overhead_token = overhead_token
        self.encoding = tiktoken.encoding_for_model(self.engine)

    def encode_length(self, text: str) -> int:
        """
        Token length of a text under the engine's encoding
        """
        return len(self.encoding.encode(text))

    def query(self, questions: str,
              role: str = "user",
              convo_id: str = "default",
              context: Optional[str] = None,
              **kwargs):
        """
        Query the chatbot; context, if given, becomes the system prompt
        """
        return self.bot.ask(prompt=questions, role=role, convo_id=convo_id,
                            dynamic_system_prompt=context, **kwargs)
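

# Minimal usage sketch, assuming a valid key in the OPENAI_API_KEY environment
# variable; the prompt and context strings below are placeholders, not part of
# the module.
if __name__ == "__main__":
    chatbot = OpenAIChatbot(api_key=os.environ.get("OPENAI_API_KEY", ""))
    answer = chatbot.query(
        "What are the key contributions of this paper?",
        context="You are a research assistant summarizing academic papers.",
    )
    print(answer)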