File size: 4,643 Bytes
27d01c0
 
 
 
 
 
 
a287230
 
27d01c0
1fa9a79
27d01c0
 
 
 
 
 
 
 
 
 
 
 
 
 
1fa9a79
27d01c0
 
 
 
 
 
 
 
 
 
a287230
27d01c0
 
1fa9a79
 
 
 
 
 
 
 
 
 
 
27d01c0
1fa9a79
 
27d01c0
 
 
1fa9a79
 
 
 
27d01c0
1fa9a79
27d01c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fa9a79
27d01c0
1fa9a79
 
616c877
27d01c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fa9a79
27d01c0
1fa9a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json
import gradio as gr
import logging
import traceback
import requests
import importlib

# Prefer a private, untracked config module (keep your own secrets there, such
# as the API key and proxy URL); fall back to the default config only when the
# private module, or one of the expected names in it, cannot be imported.
# Any other failure (e.g. a syntax error inside config_private.py) is surfaced
# instead of being silently masked by the fallback.
try:
    from config_private import proxies, API_URL, API_KEY, TIMEOUT_SECONDS
except ImportError:
    from config import proxies, API_URL, API_KEY, TIMEOUT_SECONDS

# Placeholder reply stored in the chatbot when the HTTP request fails; such
# replies are also filtered out when the conversation history is re-sent.
timeout_bot_msg = 'Request timeout, network error. please check proxy settings in config.py.'

def predict(inputs, top_p, temperature, chatbot=None, history=None, system_prompt='', retry=False,
            stream = True, additional_fn=None):
    """Send a chat request to the configured API and stream the reply back.

    This is a generator: each yielded item is a ``(chatbot, history,
    status_text)`` tuple reflecting the reply accumulated so far.

    Args:
        inputs: The text of the current user message.
        top_p: Forwarded as the ``top_p`` field of the request payload.
        temperature: Forwarded as the ``temperature`` field of the payload.
        chatbot: List of ``(user_text, reply_text)`` tuples; mutated in place.
            A fresh list is created when omitted.
        history: Flat list alternating user text and assistant text
            (even index = user, odd index = assistant); mutated in place.
            A fresh list is created when omitted.
        system_prompt: Content of the leading ``system`` message.
        retry: When true and previous turns exist, the last assistant message
            is dropped from the outgoing messages and ``inputs`` is not added
            as a new user message.
        stream: Forwarded as the ``stream`` field of the payload. Only the
            streaming path consumes the response body.
        additional_fn: Optional key into ``functional.get_functionals()``;
            the entry's ``"Prefix"`` and ``"Suffix"`` are wrapped around
            ``inputs``.

    Yields:
        ``(chatbot, history, status_text)`` tuples.

    Raises:
        TimeoutError: If issuing the HTTP request fails; raised after one
            final yield that carries the timeout message.
    """
    # Never share list state between calls through default arguments.
    if chatbot is None:
        chatbot = []
    if history is None:
        history = []

    if additional_fn is not None:
        import functional
        # Reload so that edits to the functional module apply without restart.
        importlib.reload(functional)
        functionals = functional.get_functionals()
        inputs = functionals[additional_fn]["Prefix"] + inputs + functionals[additional_fn]["Suffix"]

    if stream:
        raw_input = inputs
        logging.info(f'[raw_input] {raw_input}')
        chatbot.append((inputs, ""))
        yield chatbot, history, "等待响应"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    # Number of complete (user, assistant) turns already stored in history.
    chat_counter = len(history) // 2

    print(f"chat_counter - {chat_counter}")

    messages = [{"role": "system", "content": system_prompt}]
    for index in range(0, 2*chat_counter, 2):
        what_i_have_asked = {"role": "user", "content": history[index]}
        what_gpt_answer = {"role": "assistant", "content": history[index+1]}
        if what_i_have_asked["content"] != "":
            # Skip turns with an empty answer (unless retrying) and turns whose
            # answer is just the timeout placeholder.
            if what_gpt_answer["content"] == "" and not retry: continue
            if what_gpt_answer["content"] == timeout_bot_msg: continue
            messages.append(what_i_have_asked)
            messages.append(what_gpt_answer)
        else:
            # Empty question: overwrite the content of the latest message
            # with this answer instead of adding a new turn.
            messages[-1]['content'] = what_gpt_answer['content']

    if retry and chat_counter:
        # Drop the last message so the previous question gets answered again.
        messages.pop()
    else:
        messages.append({"role": "user", "content": inputs})
        chat_counter += 1

    # messages
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "temperature": temperature,  # 1.0,
        "top_p": top_p,  # 1.0,
        "n": 1,
        "stream": stream,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }

    history.append(inputs)

    try:
        # The HTTP response is always opened in streaming mode so the body
        # can be consumed line by line below.
        response = requests.post(API_URL, headers=headers, proxies=proxies,
                                json=payload, stream=True, timeout=TIMEOUT_SECONDS)
    except Exception as err:
        # chatbot may be empty when stream is False; do not let an IndexError
        # hide the actual request failure.
        if chatbot:
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
        yield chatbot, history, "请求超时"
        raise TimeoutError from err

    token_counter = 0
    partial_words = ""

    if stream:
        # A for-loop ends cleanly when the stream finishes without a
        # 'data: [DONE]' line; calling next() here would turn StopIteration
        # into a RuntimeError inside this generator.
        for line_index, chunk in enumerate(response.iter_lines()):
            if chunk == b'data: [DONE]':
                break

            # The first line of the stream is always skipped.
            if line_index == 0:
                continue

            # Ignore empty lines in the stream.
            if not chunk:
                continue

            try:
                # Each line is bytes; drop the 6-character prefix before the
                # JSON document (presumably 'data: ' — confirm against API).
                chunkjson = json.loads(chunk.decode()[6:])
                choice = chunkjson['choices'][0]
                delta = choice["delta"]
                if len(delta) == 0:
                    # An empty delta marks the end of the reply.
                    logging.info(f'[response] {chatbot[-1][-1]}')
                    break

                status_text = f"finish_reason: {choice['finish_reason']}"
                partial_words = partial_words + delta["content"]
                if token_counter == 0:
                    history.append(" " + partial_words)
                else:
                    history[-1] = partial_words
                chatbot[-1] = (history[-2], history[-1])
                token_counter += 1
                yield chatbot, history, status_text

            except Exception:
                traceback.print_exc()
                print(chunk.decode())
                yield chatbot, history, "Json解析不合常规"