File size: 3,608 Bytes
ebf3d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import asyncio
import json
from threading import Thread

from websockets.server import serve

from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply

# Websocket path of the plain-text streaming endpoint. It is also the path shown
# in the startup messages printed by _run_server (the chat endpoint,
# '/api/v1/chat-stream', is served as well but is not announced there).
PATH = '/api/v1/stream'


async def _handle_connection(websocket, path):
    """Serve one websocket connection, dispatching on its request path.

    ``PATH`` streams plain-text completions and ``/api/v1/chat-stream`` streams
    chat updates. Every incoming message is parsed as a JSON request and is
    answered with a sequence of ``text_stream`` events followed by a single
    ``stream_end`` event. For any other path an error is printed and the
    handler returns without reading any messages.

    Args:
        websocket: Connection object passed in by ``websockets.server.serve``;
            it is iterated for incoming messages and used to send replies.
        path: Request path of the connection.
    """
    if path == PATH:
        handle_request = _stream_text_reply
    elif path == '/api/v1/chat-stream':
        handle_request = _stream_chat_reply
    else:
        print(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        await handle_request(websocket, json.loads(message))


async def _stream_text_reply(websocket, request):
    """Answer one plain-text streaming request (the ``PATH`` endpoint).

    Each ``text_stream`` event carries only the text produced since the
    previous event, so a client rebuilds the reply by concatenating the
    ``text`` fields in order.

    Args:
        websocket: Connection used to send the events.
        request: Decoded JSON request; ``prompt`` is required, and the rest is
            interpreted by ``build_parameters``.
    """
    prompt = request['prompt']
    generate_params = build_parameters(request)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    generator = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # As we stream, only send the new text: skip_index is the length of the
    # reply prefix that has already been sent.
    skip_index = 0
    message_num = 0
    reply = ''
    for reply in generator:
        to_send = reply[skip_index:]
        if chr(0xfffd) in to_send:  # partial unicode character, don't send it yet.
            continue

        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': to_send
        }))

        # The generator runs synchronously on the event-loop thread; give other
        # tasks on the loop a chance to run between chunks.
        await asyncio.sleep(0)
        skip_index += len(to_send)
        message_num += 1

    # If generation finished while text was still being held back (the final
    # reply contained U+FFFD), send the remainder rather than dropping it.
    remainder = reply[skip_index:]
    if remainder:
        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': remainder
        }))
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))


async def _stream_chat_reply(websocket, request):
    """Answer one chat streaming request (``/api/v1/chat-stream``).

    Each ``text_stream`` event carries, under ``history``, the value yielded by
    ``generate_chat_reply`` at that step, unmodified.

    Args:
        websocket: Connection used to send the events.
        request: Decoded JSON request; ``user_input`` and ``history`` are
            required, ``regenerate`` and ``_continue`` default to False.
    """
    user_input = request['user_input']
    history = request['history']
    generate_params = build_parameters(request, chat=True)
    generate_params['stream'] = True
    regenerate = request.get('regenerate', False)
    _continue = request.get('_continue', False)

    generator = generate_chat_reply(
        user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

    message_num = 0
    for history_update in generator:
        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'history': history_update
        }))

        # Let other tasks on the loop run between chunks.
        await asyncio.sleep(0)
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))


async def _run(host: str, port: int):
    """Serve ``_handle_connection`` on ``host``:``port`` until cancelled.

    ``ping_interval=None`` is passed to ``serve``. Presumably this disables the
    library's keepalive pings so that long generations don't trip a ping
    timeout (confirm against the websockets docs).
    """
    server = serve(_handle_connection, host, port, ping_interval=None)
    async with server:
        # A future that is never resolved keeps the server open indefinitely.
        never_done = asyncio.get_running_loop().create_future()
        await never_done


def _run_server(port: int, share: bool = False):
    """Announce the streaming endpoint, then run the server on this thread.

    The server binds to all interfaces when ``shared.args.listen`` is set and to
    localhost otherwise. With ``share`` enabled, a cloudflared tunnel is
    attempted (up to 3 attempts) and the public ``wss://`` URL is printed once
    it starts; tunnel failures are printed rather than raised. This call blocks
    for as long as ``_run`` keeps the event loop alive.
    """
    if shared.args.listen:
        address = '0.0.0.0'
    else:
        address = '127.0.0.1'

    def on_start(public_url: str):
        # The tunnel reports an https URL; websocket clients need the wss scheme.
        ws_url = public_url.replace('https://', 'wss://')
        print(f'Starting streaming server at public url {ws_url}{PATH}')

    if not share:
        print(f'Starting streaming server at ws://{address}:{port}{PATH}')
    else:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=on_start)
        except Exception as err:
            print(err)

    asyncio.run(_run(host=address, port=port))


def start_server(port: int, share: bool = False):
    """Start the streaming server on a background daemon thread and return at once."""
    worker = Thread(target=_run_server, args=(port, share), daemon=True)
    worker.start()