File size: 3,210 Bytes
848090a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21d6e86
848090a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e15a84
848090a
 
21d6e86
848090a
 
 
 
 
21d6e86
848090a
 
 
7e15a84
848090a
 
 
7e15a84
848090a
 
 
 
 
 
 
21d6e86
 
7e15a84
848090a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from constants import model_options, MAX_CONVERSATION_LENGTH, MAX_TOKENS_PER_GENERATION, SAMPLING_TEMPERATURE

import gradio as gr
import openai
from openai import OpenAI
import os

# Vector-store collection name, read from the environment.
# NOTE(review): str(os.getenv(...)) turns a missing variable into the literal
# string "None" instead of failing fast — confirm this is intended.
collection_name = str(os.getenv('COLLECTION_NAME'))

# OpenAI API key from the environment; also set on the legacy module-level
# attribute for any remaining pre-v1 call sites.
oai_key = str(os.getenv('OAI_KEY'))
openai.api_key = oai_key


# v1-style OpenAI client used by the chat-completion helpers below.
client = OpenAI(
    api_key = oai_key
)


def query_a_chat_completion(model, messages):
    """Request a single chat completion from OpenAI.

    :param model: OpenAI model name; must be "gpt-3.5-turbo" or "gpt-4".
    :param messages: List of chat message dicts ({"role": ..., "content": ...}).
    :return: The content string of the first completion choice.
    :raises ValueError: If ``model`` is not one of the supported names.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if model not in ("gpt-3.5-turbo", "gpt-4"):
        raise ValueError(f"Unsupported model: {model!r}")
    # Use the v1 client API. The legacy `openai.ChatCompletion.create` call
    # was removed in openai>=1.0, which this file already uses (OpenAI client),
    # so the old call would raise at runtime.
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=MAX_TOKENS_PER_GENERATION,
        temperature=SAMPLING_TEMPERATURE,
    )
    return completion.choices[0].message.content


def query_chatbot(model, messages):
    """Query the chatbot model for the next assistant reply.

    :param model: OpenAI model name; must be "gpt-3.5-turbo" or "gpt-4".
    :param messages: List of chat message dicts ({"role": ..., "content": ...}).
    :return: The content string of the first completion choice.
    :raises ValueError: If ``model`` is not one of the supported names.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if model not in ("gpt-3.5-turbo", "gpt-4"):
        raise ValueError(f"Unsupported model: {model!r}")
    chat_completion = client.chat.completions.create(
        messages=messages,
        # BUG FIX: the model was hard-coded to "gpt-3.5-turbo", silently
        # ignoring the validated `model` parameter (GPT-4 was never used).
        model=model,
        # Match the generation limits used by query_a_chat_completion.
        max_tokens=MAX_TOKENS_PER_GENERATION,
        temperature=SAMPLING_TEMPERATURE,
    )
    return chat_completion.choices[0].message.content


def chatbot_generate(user_newest_input, history, model, initial_txt):
    """
    Generate the next response from the chatbot and update the UI state.

    :param user_newest_input: The newest input from the user
    :param history: The history of the conversation;
        list[str], where each element starts with "User:" or
        "Writing Assistant:". Mutated in place with the new turn.
    :param model: UI model key, either "chatgpt" (GPT-3.5) or "chatgpt4" (GPT-4)
    :param initial_txt: System prompt placed at the start of the conversation
    :return: (conversation pairs, history, textbox update, submit-button update)
    :raises KeyError: If ``model`` is not a recognized UI key.
    :raises NotImplementedError: If a history entry has an unknown prefix.
    """
    # Prefixes used to tag conversation turns in `history`.
    user_prefix = "User:"
    assistant_prefix = "Writing Assistant:"

    # Map UI model keys to OpenAI model names.
    actual_model = {
        "chatgpt4": "gpt-4",
        "chatgpt": "gpt-3.5-turbo",
    }[model]

    # Record the newest user input.
    history.append(f"{user_prefix} {user_newest_input.strip()}")

    # Build the OpenAI message list: system prompt first, then the history.
    chat_messages = [{"role": "system", "content": initial_txt}]
    for hist in history:
        # Slice by prefix length instead of magic offsets (5 / 18) so the
        # slicing cannot drift out of sync with the prefix strings.
        if hist.startswith(user_prefix):
            chat_messages.append({
                "role": "user",
                "content": hist[len(user_prefix):].strip(),
            })
        elif hist.startswith(assistant_prefix):
            chat_messages.append({
                "role": "assistant",
                "content": hist[len(assistant_prefix):].strip(),
            })
        else:
            raise NotImplementedError(f"Unrecognized history entry: {hist!r}")

    # Query OpenAI for the next assistant turn.
    if actual_model in ("gpt-3.5-turbo", "gpt-4"):
        ai_newest_output = query_chatbot(actual_model, chat_messages)
    else:
        raise NotImplementedError(f"Unsupported model: {actual_model!r}")

    # Append the AI reply, then pair turns as (user, assistant) for display.
    history.append(f"{assistant_prefix} {ai_newest_output.strip()}")
    conversations = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    # Hide the textbox and submit button once the conversation limit is hit.
    at_limit = len(history) >= 2 * MAX_CONVERSATION_LENGTH
    visible = gr.update(visible=not at_limit)
    return conversations, history, visible, visible