File size: 10,345 Bytes
9d1e12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import gradio as gr
import os
import json
import requests

# Global registry of prompt templates; populated lazily from prompts_zh.json
# by get_prompt_templates(). The first key ("默认ChatGPT") is the default/no-op.
prompt_templates = {"默认ChatGPT": ""}
# Streaming endpoint
# API_URL = "https://api.openai.com/v1/chat/completions"  # os.getenv("API_URL") + "/generate_stream"
# Chat-completions endpoints; which one is used depends on the API-key prefix
# (see predict()): 'sk-' -> OpenAI, 'fk' -> API2D proxy.
OPENAI_URL = "https://api.openai.com/v1/chat/completions"  # os.getenv("API_URL") + "/generate_stream"
API2D_URL = "https://openai.api2d.net/v1/chat/completions"  # os.getenv("API_URL") + "/generate_stream"
# Conversation identifier; currently unused by the code below.
convo_id = 'default'
# NOTE(review): removed a hex-string comment that looked like a leaked
# credential/token artifact; never commit secrets, even commented out.

def on_prompt_template_change(prompt_template):
    """Return the prompt text for the selected template name.

    Returns None (leave the preview unchanged) when the input is not a
    string, e.g. a Gradio clear event; returns '' when no template is
    selected; otherwise looks the name up in the global prompt_templates.
    """
    if not isinstance(prompt_template, str):
        return
    if prompt_template:
        return prompt_templates[prompt_template]
    # Bug fix: the original evaluated the bare expression '' here and fell
    # through, implicitly returning None instead of the empty string.
    return ''

def get_empty_state():
    """Build a fresh conversation state: zero tokens used, no messages."""
    return dict(total_tokens=0, messages=[])

def get_prompt_templates():
    """Load prompt templates from ./prompts_zh.json into the global dict.

    Each JSON entry is expected to carry 'act' (template name) and
    'prompt' (template text). Returns a gr.update selecting the first
    (default) entry, with the remaining names sorted alphabetically.
    """
    with open('./prompts_zh.json', 'r', encoding='utf8') as fp:
        for entry in json.load(fp):
            prompt_templates[entry['act']] = entry['prompt']

    names = list(prompt_templates.keys())
    # Keep the default template first; sort everything after it.
    names = names[:1] + sorted(names[1:])
    return gr.update(value=names[0], choices=names)

# Testing with my Open AI Key
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

def predict(inputs, prompt_template, temperature, openai_api_key, chat_counter, context_length, chatbot=None,
            history=None):  # repetition_penalty, top_k
    """Stream a chat completion from OpenAI (or API2D) and yield UI updates.

    Parameters:
        inputs: the user's latest message (None is treated as '').
        prompt_template: template NAME (key into prompt_templates) or falsy.
        temperature: sampling temperature forwarded to the API.
        openai_api_key: 'fk...' keys are routed to API2D, everything else
            (including 'sk-...') to the OpenAI endpoint.
        chat_counter: number of completed exchanges so far.
        context_length: how many previous (user, assistant) pairs to resend.
        chatbot: prior (user, assistant) display pairs from the Gradio chatbot.
        history: flat alternating user/assistant transcript (mutated in place).

    Yields:
        (chat_pairs, history, chat_counter) tuples as streamed tokens arrive.
    """
    # Bug fix: the original used mutable default arguments ([]) that the
    # function mutates, leaking state across calls that omit them.
    if chatbot is None:
        chatbot = []
    if history is None:
        history = []

    # Route by key prefix. Security fix: the original printed the API key
    # (a secret) to stdout; never log credentials.
    if openai_api_key.startswith('fk'):
        API_URL = API2D_URL
    else:
        API_URL = OPENAI_URL

    if inputs is None:
        inputs = ''
    # Resolve the template name to its prompt text (system message).
    prompt_template = prompt_templates[prompt_template] if prompt_template else ""

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}"
    }

    # Build the message list once: system prompt, replayed context (only when
    # a conversation is already under way), then the new user turn.
    messages = [{"role": "system", "content": prompt_template}]
    if chat_counter != 0:
        for user_msg, assistant_msg in chatbot[-context_length:]:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": f"{inputs}"})

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        # Bug fix: the first-turn payload previously hard-coded 1.0 and
        # ignored the UI slider; the temperature argument is now honored
        # on every turn.
        "temperature": temperature,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    history.append(inputs)
    response = requests.post(API_URL, headers=headers, json=payload, stream=True)

    # Bug fix: the original retry mutated payload with response.content['id'],
    # which raises TypeError (bytes indexed by str) and was swallowed by a
    # bare `except: pass`, so no retry ever actually ran. Retry the identical
    # request up to twice instead, best-effort.
    for _ in range(2):
        if response.status_code == 200:
            break
        try:
            response = requests.post(API_URL, headers=headers, json=payload, stream=True)
        except requests.RequestException:
            break

    token_counter = 0
    partial_words = ""
    if response.status_code == 200:
        chat_counter += 1
        first_line = True
        for chunk in response.iter_lines():
            # Skip the very first SSE line, matching the original behavior.
            if first_line:
                first_line = False
                continue
            chunk = chunk.decode("utf-8")[6:]  # strip the "data: " SSE prefix
            if not chunk:
                continue
            if chunk == '[DONE]':
                break
            resp: dict = json.loads(chunk)
            choices = resp.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta")
            if not delta or "content" not in delta:
                continue
            partial_words += delta["content"]
            if token_counter == 0:
                # Leading space marks a fresh assistant turn in the transcript.
                history.append(" " + partial_words)
            else:
                history[-1] = partial_words
            # Convert the flat transcript into (user, assistant) pairs.
            chat = [(history[i], history[i + 1]) for i in
                    range(0, len(history) - 1, 2)]
            token_counter += 1
            yield chat, history, chat_counter  # resembles {chatbot: chat, state: history}
    else:
        chat = [(history[i], history[i + 1]) for i in
                range(0, len(history) - 1, 2)]
        chat.append((inputs, "OpenAI服务器网络出现错误,请重试,或重启对话"))
        token_counter += 1
        yield chat, history, chat_counter  # resembles {chatbot: chat, state: history}



def reset_textbox():
    """Produce the Gradio update that empties the input textbox."""
    empty_value = ''
    return gr.update(value=empty_value)

def clear_conversation(chatbot):
    """Reset input box, chat display, state, and turn counter for a new chat.

    The chatbot argument is accepted (Gradio wiring) but not used.
    """
    cleared_input = gr.update(value=None, visible=True)
    reset_counter = gr.update(value=0)
    return cleared_input, [], [], reset_counter


title = """<h1 align="center">🔥覃秉丰的ChatGPT🔥</h1>"""
description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```
In this app, you can explore the outputs of a gpt-3.5-turbo LLM.
"""

# Build the Gradio UI: a single centered column with API-key box, chatbot,
# input, buttons, and an advanced-settings accordion.
with gr.Blocks(css="""#col_container {width: 800px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 500px; overflow: auto;}
                #inputs {font-size: 20px;}
                #prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px;}""") as demo:
    gr.HTML(title)
    # gr.HTML(
    #     '''<center><a href="https://huggingface.co/spaces/QinBingFeng/chatgpt?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>点击图标复制App</center>''')
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="输入你的API Key",placeholder="OpenAI API Key 或者 API2D")
        chatbot = gr.Chatbot(elem_id='chatbot')  # c
        inputs = gr.Textbox(show_label=False, placeholder="在这里输入内容",elem_id="inputs",value='')  # t
        state = gr.State([])  # s
        # state = gr.State(get_empty_state())
        b1 = gr.Button("提交")
        btn_clear_conversation = gr.Button("🔃 开启新的对话")

        # inputs, top_p, temperature, top_k, repetition_penalty
        with gr.Accordion("高级设置", open=False):
            context_length = gr.Slider(minimum=1, maximum=6, value=2, step=1, label="对话长度",
                                       info="关联之前的几轮对话,数值越高tokens消耗越多")
            temperature = gr.Slider(minimum=0, maximum=2.0, value=0.7, step=0.1, label="Temperature",
                                    info="数值越高创造性越强")
            prompt_template = gr.Dropdown(label="选择机器人类型",
                                          choices=list(prompt_templates.keys()))
            prompt_template_preview = gr.Markdown(elem_id="prompt_template_preview")
            # top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
            # repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
            chat_counter = gr.Number(value=0, visible=False, precision=0)

    # Submitting via Enter or the button both stream through predict().
    inputs.submit(predict, [inputs, prompt_template, temperature, openai_api_key, chat_counter, context_length, chatbot, state],
                  [chatbot, state, chat_counter], )
    b1.click(predict, [inputs, prompt_template, temperature, openai_api_key, chat_counter, context_length, chatbot, state],
             [chatbot, state, chat_counter], )
    b1.click(reset_textbox, [], [inputs])

    btn_clear_conversation.click(clear_conversation, [], [inputs, chatbot, state, chat_counter])

    inputs.submit(reset_textbox, [], [inputs])
    prompt_template.change(on_prompt_template_change, inputs=[prompt_template], outputs=[prompt_template_preview])
    # Bug fix: the original passed the misspelled keyword `queur=False`;
    # the intended keyword is `queue=False` (load templates without queueing).
    demo.load(get_prompt_templates, inputs=None, outputs=[prompt_template], queue=False)

    # gr.Markdown(description)
    demo.queue(concurrency_count=10)
    demo.launch(debug=True)