|
from constants import model_options, MAX_CONVERSATION_LENGTH, MAX_TOKENS_PER_GENERATION, SAMPLING_TEMPERATURE |
|
|
|
import gradio as gr |
|
import openai |
|
from openai import OpenAI |
|
import os |
|
|
|
# --- Module-level configuration (read from environment at import time) ---

# NOTE(review): str(os.getenv(...)) turns a *missing* variable into the
# literal string "None" rather than failing fast — confirm downstream
# consumers expect that, or validate these values explicitly.
collection_name = str(os.getenv('COLLECTION_NAME'))

oai_key = str(os.getenv('OAI_KEY'))

# Legacy (pre-1.0) module-level API-key assignment for the old
# `openai.*` call style.
openai.api_key = oai_key

# OpenAI v1 client; used for chat-completion requests below.
client = OpenAI(
    api_key = oai_key
)
|
|
|
|
|
def query_a_chat_completion(model, messages):
    """Run one chat completion and return the assistant's reply text.

    :param model: Chat model name; must be "gpt-3.5-turbo" or "gpt-4".
    :param messages: List of {"role": ..., "content": ...} dicts forming
        the conversation so far.
    :return: The assistant's reply content (str).
    :raises ValueError: If *model* is not one of the supported models.
    """
    # Explicit validation instead of `assert`, which is stripped under -O.
    if model not in ("gpt-3.5-turbo", "gpt-4"):
        raise ValueError(f"Unsupported model: {model!r}")

    # Bug fix: `openai.ChatCompletion.create` was removed in openai>=1.0
    # (this file already uses the v1 SDK via `OpenAI(...)`), so the old
    # call raised at runtime. Use the module-level v1 client instead.
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=MAX_TOKENS_PER_GENERATION,
        temperature=SAMPLING_TEMPERATURE,
    )
    return completion.choices[0].message.content
|
|
|
|
|
def query_chatbot(model, messages):
    """Query the OpenAI chat API and return the assistant's reply text.

    :param model: Chat model name; must be "gpt-3.5-turbo" or "gpt-4".
    :param messages: List of {"role": ..., "content": ...} message dicts.
    :return: The assistant's reply content (str).
    :raises ValueError: If *model* is not one of the supported models.
    """
    # Explicit validation instead of `assert`, which is stripped under -O.
    if model not in ("gpt-3.5-turbo", "gpt-4"):
        raise ValueError(f"Unsupported model: {model!r}")

    # Bug fix: the `model` argument was previously ignored and
    # "gpt-3.5-turbo" was always hard-coded, so selecting GPT-4 in the UI
    # had no effect. Also apply the module-wide token/temperature limits,
    # matching query_a_chat_completion.
    chat_completion = client.chat.completions.create(
        messages=messages,
        model=model,
        max_tokens=MAX_TOKENS_PER_GENERATION,
        temperature=SAMPLING_TEMPERATURE,
    )
    return chat_completion.choices[0].message.content
|
|
|
|
|
def chatbot_generate(user_newest_input, history, model, current_answer, initial_txt):
    """
    Generate the next chatbot response and updated Gradio UI state.

    :param user_newest_input: The newest text typed by the user.
    :param history: The conversation so far: list[str], where each entry
        starts with "User:" or "Writing Assistant:". Mutated in place.
    :param model: UI model key, either "chatgpt" or "chatgpt4".
    :param current_answer: The writer's current draft answer; prepended to
        every assistant turn so the model keeps it in context.
    :param initial_txt: Unused here; kept for callback-signature
        compatibility.
    :return: (conversation pairs for the chat widget, updated history,
        textbox visibility update, submit-button visibility update)
    :raises KeyError: If *model* is not a recognized UI key.
    :raises NotImplementedError: If a history entry has an unknown prefix,
        or the resolved model is unsupported.
    """
    # Map the UI key to the actual OpenAI model name; an unknown key
    # raises KeyError immediately, before history is touched.
    actual_model = {
        "chatgpt4": "gpt-4",
        "chatgpt": "gpt-3.5-turbo"
    }[model]

    history.append(f"User: {user_newest_input.strip()}")

    # Rebuild the OpenAI message list from the prefixed history strings.
    user_prefix = "User:"
    assistant_prefix = "Writing Assistant:"
    chat_messages = [{"role": "system", "content": "You are a helpful assistant to a writer."}]
    current_txt = "My current answer to the prompt is as follows: " + current_answer + '. Now help me answer my question: '
    for entry in history:
        if entry.startswith(user_prefix):
            chat_messages.append(
                {
                    "role": "user",
                    # len(prefix) instead of the former magic offsets 5/18.
                    "content": entry[len(user_prefix):].strip()
                }
            )
        elif entry.startswith(assistant_prefix):
            # Inject the writer's current draft into each assistant turn.
            chat_messages.append(
                {
                    "role": "assistant",
                    "content": current_txt + entry[len(assistant_prefix):].strip()
                }
            )
        else:
            raise NotImplementedError

    if actual_model in ["gpt-3.5-turbo", "gpt-4"]:
        ai_newest_output = query_chatbot(actual_model, chat_messages)
    else:
        raise NotImplementedError

    history.append(f"Writing Assistant: {ai_newest_output.strip()}")
    # Pair up (user, assistant) turns for the Gradio Chatbot widget.
    conversations = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    # Hide the input textbox and submit button once the conversation hits
    # its length cap (each exchange adds two history entries).
    controls_visible = len(history) < 2 * MAX_CONVERSATION_LENGTH
    return (
        conversations,
        history,
        gr.update(visible=controls_visible),
        gr.update(visible=controls_visible),
    )