# wrAIte / model_generate.py
# suhamemon1
# done
# 21d6e86
from constants import model_options, MAX_CONVERSATION_LENGTH, MAX_TOKENS_PER_GENERATION, SAMPLING_TEMPERATURE
import gradio as gr
import openai
from openai import OpenAI
import os
# Settings pulled from the environment when this module is imported.
# NOTE(review): wrapping os.getenv in str() turns an unset variable into the
# literal string "None" instead of failing fast — confirm this is intended.
oai_key = str(os.getenv('OAI_KEY'))
collection_name = str(os.getenv('COLLECTION_NAME'))

# The key is registered twice: on the ``openai`` module itself and on the
# explicit v1 client instance that query_chatbot sends requests through.
openai.api_key = oai_key
client = OpenAI(api_key=oai_key)
def query_a_chat_completion(model, messages):
    """Request one chat completion with the module's generation limits.

    :param model: OpenAI model id; must be "gpt-3.5-turbo" or "gpt-4".
    :param messages: chat messages in OpenAI format, i.e. a list of
        {"role": ..., "content": ...} dicts.
    :return: the content of the first choice's message.
    """
    assert model in ["gpt-3.5-turbo", "gpt-4"]
    # openai>=1.0 (required by the ``OpenAI`` client this module builds)
    # removed ``openai.ChatCompletion.create``; route through the v1 client.
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=MAX_TOKENS_PER_GENERATION,
        temperature=SAMPLING_TEMPERATURE,
    )
    return completion.choices[0].message.content
def query_chatbot(model, messages):
    """Send a chat conversation to ``model`` and return the reply text.

    :param model: OpenAI model id; must be "gpt-3.5-turbo" or "gpt-4".
    :param messages: chat messages in OpenAI format, i.e. a list of
        {"role": ..., "content": ...} dicts.
    :return: the content of the first choice's message.
    """
    assert model in ["gpt-3.5-turbo", "gpt-4"]
    chat_completion = client.chat.completions.create(
        messages=messages,
        # Honour the validated model; a hard-coded id here would silently
        # downgrade a "gpt-4" request to gpt-3.5-turbo.
        model=model,
    )
    return chat_completion.choices[0].message.content
def _history_to_chat_messages(history, system_prompt):
    """Convert prefixed history strings into OpenAI chat messages, led by a system message.

    Entries must start with "User:" or "Writing Assistant:"; any other entry
    raises NotImplementedError.
    """
    role_by_prefix = (("User:", "user"), ("Writing Assistant:", "assistant"))
    messages = [{"role": "system", "content": system_prompt}]
    for entry in history:
        for prefix, role in role_by_prefix:
            if entry.startswith(prefix):
                messages.append({"role": role, "content": entry[len(prefix):].strip()})
                break
        else:
            raise NotImplementedError(f"Unrecognised history entry: {entry!r}")
    return messages


def chatbot_generate(user_newest_input, history, model, initial_txt):
    """
    Generate the next response from the chatbot.

    :param user_newest_input: The newest input from the user
    :param history: The history of the conversation, a list[str] where each
        element starts with "User:" or "Writing Assistant:". It is extended in
        place with the new user turn and the reply, but only once a reply has
        been obtained.
    :param model: UI model name, "chatgpt" or "chatgpt4" (anything else raises KeyError)
    :param initial_txt: system prompt placed at the start of the chat messages
    :return: (conversations, history, textbox update, submit-button update),
        where conversations pairs consecutive history entries as
        (user entry, assistant entry) tuples
    """
    # Map the UI model name to the OpenAI model id.
    actual_model = {
        "chatgpt4": "gpt-4",
        "chatgpt": "gpt-3.5-turbo",
    }[model]

    user_entry = f"User: {user_newest_input.strip()}"
    chat_messages = _history_to_chat_messages(history + [user_entry], initial_txt)

    ai_newest_output = query_chatbot(actual_model, chat_messages)
    ai_entry = f"Writing Assistant: {ai_newest_output.strip()}"

    # Commit both turns together only after a reply exists. A failed request
    # therefore cannot leave an unpaired "User:" entry behind, which would
    # misalign every later (user, assistant) pair below.
    history.append(user_entry)
    history.append(ai_entry)

    conversations = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    # Hide the textbox and submit button once the conversation reaches its limit.
    keep_open = len(history) < 2 * MAX_CONVERSATION_LENGTH
    return conversations, history, gr.update(visible=keep_open), gr.update(visible=keep_open)