File size: 6,541 Bytes
18cd591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
'''
Contributed by SagsMug. Modified by binary-husky
https://github.com/oobabooga/text-generation-webui/pull/175
'''

import asyncio
import json
import random
import string
import websockets
import logging
import time
import threading
import importlib
from toolbox import get_conf, update_ui


def random_hash():
    """Return a random 9-character session id made of lowercase ASCII letters and digits."""
    alphabet = string.ascii_lowercase + string.digits
    picked = []
    for _ in range(9):
        picked.append(random.choice(alphabet))
    return "".join(picked)

async def run(context, max_token, temperature, top_p, addr, port, fn_index=12):
    """Stream generated text from a text-generation-webui server via Gradio's websocket queue.

    Opens ``ws://{addr}:{port}/queue/join``, answers the server's ``send_hash``
    and ``send_data`` requests, and yields ``content["output"]["data"][0]`` for
    every ``process_generating`` / ``process_completed`` message, stopping after
    ``process_completed``.

    Args:
        context: prompt text, sent as the first element of the request data.
        max_token: sent as ``max_new_tokens``.
        temperature: sampling temperature forwarded to the server.
        top_p: nucleus-sampling threshold forwarded to the server.
        addr: server host name or IP address.
        port: server port.
        fn_index: Gradio function index of the generation endpoint. Defaults to
            12, the value previously hard-coded here; presumably it depends on
            the server's UI layout/version -- confirm against the target server.

    Yields:
        The text reported by the server for each progress update. The callers
        in this module treat it as the cumulative reply so far
        (NOTE(review): confirm against the server version in use).

    Raises:
        RuntimeError: if the server reports a full queue or a failed generation.
    """
    params = {
        'max_new_tokens': max_token,
        'do_sample': True,
        'temperature': temperature,
        'top_p': top_p,
        'typical_p': 1,
        'repetition_penalty': 1.05,
        'encoder_repetition_penalty': 1.0,
        'top_k': 0,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': True,
        'seed': -1,
    }
    session = random_hash()

    async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
        # The walrus operator needs Python 3.8+; the loop also stops on an empty message.
        while content := json.loads(await websocket.recv()):
            msg = content["msg"]
            if msg == "send_hash":
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": fn_index
                }))
            elif msg == "send_data":
                # The server expects the arguments positionally, in exactly this order.
                await websocket.send(json.dumps({
                    "session_hash": session,
                    "fn_index": fn_index,
                    "data": [
                        context,
                        params['max_new_tokens'],
                        params['do_sample'],
                        params['temperature'],
                        params['top_p'],
                        params['typical_p'],
                        params['repetition_penalty'],
                        params['encoder_repetition_penalty'],
                        params['top_k'],
                        params['min_length'],
                        params['no_repeat_ngram_size'],
                        params['num_beams'],
                        params['penalty_alpha'],
                        params['length_penalty'],
                        params['early_stopping'],
                        params['seed'],
                    ]
                }))
            elif msg == "queue_full":
                raise RuntimeError("text-generation-webui queue is full")
            elif msg in ("process_generating", "process_completed"):
                # A failed step carries an error payload instead of output data;
                # report it instead of crashing with an opaque KeyError.
                if not content.get("success", True):
                    raise RuntimeError(
                        f"text-generation-webui generation failed: {content.get('output')}"
                    )
                yield content["output"]["data"][0]
                # To stop early on a custom end marker, check the text here and
                # break out (which closes the websocket).
                if msg == "process_completed":
                    break
            # "estimation", "process_starts" and any other messages are
            # informational and are ignored.





def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream = True, additional_fn=None):
    """
    Send a query to a text-generation-webui server and stream the reply into the chat UI.

    Used for the basic chat feature. This is a generator: every UI refresh is
    produced with ``yield from update_ui(...)``.

    Args:
        inputs: the user's input for this turn.
        llm_kwargs: must provide 'llm_model' (e.g. "tgui:galactica-1.3b@localhost:7860";
            the part after '@' must be "<addr>:<port>"), plus 'max_length',
            'temperature' and 'top_p', which are forwarded to the server.
        plugin_kwargs: not used by this function.
        chatbot: list of conversation pairs shown in the web UI; it is modified
            in place and re-yielded through update_ui to refresh the interface.
        history: list of previous messages; this turn's input and reply are
            appended in place. Defaults to a new empty list. Very long inputs or
            history may overflow the model's token limit.
        system_prompt, stream: accepted for interface compatibility; not used here.
        additional_fn: identifies which functional button was clicked; when
            given, inputs and history are rewritten by
            core_functional.handle_core_functionality.
    """
    if history is None:
        # Fresh list per call: a shared `[]` default would accumulate turns across calls.
        history = []
    if additional_fn is not None:
        from core_functional import handle_core_functionality
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)

    prompt = "What I would like to say is the following: " + inputs
    history.extend([inputs, ""])
    chatbot.append([inputs, ""])
    yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # refresh the UI

    tgui_say = ""

    _model_name, addr_port = llm_kwargs['llm_model'].split('@')
    assert ':' in addr_port, "LLM_MODEL 格式不正确!" + llm_kwargs['llm_model']
    addr, port = addr_port.split(':')

    # Shared with the worker thread: [latest reply text, heartbeat timestamp].
    # The worker stops streaming once the heartbeat is more than 3 s old, e.g.
    # when the UI stops consuming this generator.
    mutable = ["", time.time()]
    def run_coorotine(mutable):
        async def get_result(mutable):
            async for response in run(context=prompt, max_token=llm_kwargs['max_length'],
                                      temperature=llm_kwargs['temperature'],
                                      top_p=llm_kwargs['top_p'], addr=addr, port=port):
                print(response[len(mutable[0]):])
                mutable[0] = response
                if (time.time() - mutable[1]) > 3:
                    print('exit when no listener')
                    break
        # The coroutine gets its own event loop inside the worker thread.
        asyncio.run(get_result(mutable))

    thread_listen = threading.Thread(target=run_coorotine, args=(mutable,), daemon=True)
    thread_listen.start()

    while True:
        # Sample liveness *before* syncing. If the worker has already finished,
        # its last write is visible, and one final sync shows the tail of the reply.
        finished = not thread_listen.is_alive()
        if not finished:
            time.sleep(1)
            mutable[1] = time.time()  # heartbeat for the worker
        if tgui_say != mutable[0]:
            tgui_say = mutable[0]
            history[-1] = tgui_say
            chatbot[-1] = (history[-2], history[-1])
            yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
        if finished:
            break




def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience=False):
    """Query a text-generation-webui server without a UI and return the full reply.

    Blocks until the server finishes generating, or until the caller's watchdog
    expires.

    Args:
        inputs: the user's input; it is sent with a fixed English preamble.
        llm_kwargs: must provide 'llm_model' (the part after '@' must be
            "<addr>:<port>"), plus 'max_length', 'temperature' and 'top_p'.
        history, sys_prompt: accepted for interface compatibility; NOT included
            in the prompt sent to the server.
        observe_window: optional list shared with the caller. When it has at
            least one slot, slot 0 receives the reply text so far. When it has at
            least two, slot 1 is treated as a watchdog timestamp: generation
            stops once it is more than 5 s old. May be None.
        console_slience: when true, streamed text is not printed to stdout.

    Returns:
        The last text received from the server (the reply so far if the
        watchdog stopped generation early).

    Raises:
        Any exception raised while talking to the server, re-raised in the
        caller's thread.
    """
    prompt = "What I would like to say is the following: " + inputs
    _model_name, addr_port = llm_kwargs['llm_model'].split('@')
    assert ':' in addr_port, "LLM_MODEL 格式不正确!" + llm_kwargs['llm_model']
    addr, port = addr_port.split(':')

    has_window = observe_window is not None and len(observe_window) >= 1
    has_watchdog = observe_window is not None and len(observe_window) >= 2
    result = [""]   # latest reply text, written by the worker thread
    errors = []     # exception raised in the worker thread, if any

    async def get_result():
        async for response in run(context=prompt, max_token=llm_kwargs['max_length'],
                                  temperature=llm_kwargs['temperature'],
                                  top_p=llm_kwargs['top_p'], addr=addr, port=port):
            if not console_slience:
                print(response[len(result[0]):])
            result[0] = response
            if has_window:
                observe_window[0] = response
            if has_watchdog and (time.time() - observe_window[1]) > 5:
                print('exit when no listener')
                break

    def run_coorotine():
        # Run in a dedicated thread with its own event loop. This presumably
        # avoids clashing with an event loop already running in the caller's
        # thread (asyncio.run cannot be nested) -- confirm against callers.
        try:
            asyncio.run(get_result())
        except Exception as exc:  # surfaced to the caller below
            errors.append(exc)

    thread_listen = threading.Thread(target=run_coorotine)
    thread_listen.start()
    # Wait for the reply; previously this returned before any text arrived.
    thread_listen.join()
    if errors:
        raise errors[0]
    return result[0]