File size: 4,421 Bytes
dbf8ee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
## introduction
Streaming mode is not supported by the latest Gradio client API, but it can be implemented with raw HTTP requests.

## gradio api
from gradio_client import Client
client = Client("xu-song/self-chat")
fn_index = self._infer_fn_index("set_state")
"""
import json
import random
import secrets
import string

import requests


def _join_queue(session_hash, data, fn_index, trigger_id, timeout=30):
    """Enqueue one call of the Space's function ``fn_index`` and return its event id.

    This sends the same ``POST /queue/join`` request the Gradio front end sends.
    The call's results are not returned here; they are read later from the
    session's ``/queue/data`` event stream.

    Args:
        session_hash: Client-chosen id that ties this call to the session.
        data: Input list for the function.
        fn_index: Index of the Space function to run.
        trigger_id: Id of the UI component that would have triggered the call.
        timeout: Seconds ``requests`` waits for the server (connect and read).

    Returns:
        The ``event_id`` from the server's JSON reply.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        KeyError: if the JSON reply has no ``event_id``.
    """
    url = 'https://xu-song-self-chat.hf.space/queue/join'
    payload = {"data": data, "event_data": None, "fn_index": fn_index,
               "trigger_id": trigger_id, "session_hash": session_hash}
    headers = {'Content-Type': 'application/json', 'Accept': '*/*'}
    # The reply is a small JSON object, so read it eagerly (no ``stream=True``).
    # ``with`` then releases the connection as soon as it has been parsed.
    with requests.post(url,
                       data=json.dumps(payload),
                       headers=headers,
                       timeout=timeout) as response:
        # Fail loudly on HTTP errors instead of with a confusing KeyError or
        # JSONDecodeError while reading the body.
        response.raise_for_status()
        return response.json()["event_id"]


def set_system(session_hash, system, timeout=30):
    """Queue Space function 0 (presumably the system-prompt setter) with ``system``.

    Args:
        session_hash: Session id shared with the later ``submit`` and ``stream_message`` calls.
        system: System prompt text, sent as the function's single input.
        timeout: Seconds to wait for the HTTP reply.

    Returns:
        The ``event_id`` assigned to the queued call.
    """
    return _join_queue(session_hash, [system], fn_index=0, trigger_id=4, timeout=timeout)


def submit(session_hash, messages, timeout=30):
    """Queue one chat turn (Space function 1) for this session.

    Args:
        session_hash: Session id shared with ``set_system`` and ``stream_message``.
        messages: Complete input list for the function. The demo starts with
            ``[[], None]`` and afterwards passes back the previous turn's output list.
        timeout: Seconds to wait for the HTTP reply.

    Returns:
        The ``event_id`` assigned to the queued call.
    """
    return _join_queue(session_hash, messages, fn_index=1, trigger_id=8, timeout=timeout)


def _parse_sse_data(line):
    """Return the JSON payload of a server-sent-events ``data:`` line, or None for any other line."""
    if not line:
        return None
    decoded_line = line.decode('utf-8')
    if not decoded_line.startswith("data:"):
        return None
    # Slice off the literal prefix. ``str.strip("data:")`` would instead strip
    # any of the characters d, a, t and ':' from BOTH ends of the line.
    return json.loads(decoded_line[len("data:"):].strip())


def _entry_content(entry):
    """Return the text to display for one chat-history entry.

    Two-element entries are unpacked as (question, answer), presumably a
    [user, bot] pair; the answer is used once it is non-empty. Three-element
    entries carry the text in their last slot.

    Raises:
        ValueError: if the entry has neither 2 nor 3 elements.
    """
    if len(entry) == 2:
        question, answer = entry
        return answer if answer else question
    if len(entry) == 3:
        return entry[2]
    raise ValueError("response error")


def stream_message(session_hash):
    """Print streamed chat output for ``session_hash`` and return the completed result.

    Opens the session's event stream (``GET /queue/data``). For every
    intermediate message carrying an ``output``, it prints the text of the newest
    chat entry without a newline. It keeps ``output["data"]`` from the
    ``process_completed`` message.

    Protocol references:
    - javascript_client: https://github.com/gradio-app/gradio/blob/9f0fe392c9f2604b9f937b9414e67d9b71b69109/client/js/src/utils/stream.ts#L42
    - python_client: https://github.com/gradio-app/gradio/blob/9f0fe392c9f2604b9f937b9414e67d9b71b69109/client/python/gradio_client/client.py#L248

    Args:
        session_hash: The id previously passed to ``set_system`` / ``submit``.

    Returns:
        The completed ``output["data"]`` list. Returns None if the HTTP status was
        not 200 or the stream ended without a ``process_completed`` message.

    Raises:
        ValueError: on malformed JSON or a chat entry of unexpected length.
    """
    sse_url = 'https://xu-song-self-chat.hf.space/queue/data'
    # Sent as a query string; equivalent to f"{sse_url}?session_hash={session_hash}".
    params = {'session_hash': session_hash}
    headers = {'Content-Type': 'application/json', 'Accept': 'text/event-stream'}

    final_output = None
    # ``with`` closes the streaming connection even if parsing raises midway.
    with requests.get(sse_url, params=params, headers=headers, stream=True) as response:
        if response.status_code != 200:
            print(f"Request failed with status code: {response.status_code}")
        else:
            for line in response.iter_lines():
                data = _parse_sse_data(line)
                if data is None or "output" not in data:
                    continue

                messages = data["output"]["data"][0]
                if not messages:
                    continue
                content = _entry_content(messages[-1])

                if data['msg'] == "process_completed":
                    final_output = data["output"]["data"]
                else:
                    print(content, end="")

    print("")
    return final_output


def self_chat_demo(system_message, num_turn=4):
    """Run a scripted self-chat against the Space and print it to stdout.

    Every round submits the previous round's output list as the next input and
    streams the reply. The printed labels alternate between USER (even rounds)
    and ASSISTANT (odd rounds).

    Args:
        system_message: System prompt queued once via ``set_system`` before the first round.
        num_turn: Maximum number of submit/stream rounds.
    """
    session_hash = create_session_hash()
    print(f"SYSTEM: {system_message}")
    set_system(session_hash, system_message)
    # Initial input for Space function 1 -- presumably (empty history, no
    # extra input); TODO confirm against the Space's function signature.
    messages = [[], None]

    for num in range(num_turn):
        if num % 2 == 0:
            print("===" * 10)
            print("USER: ", end="")
        else:
            print("ASSISTANT: ", end="")
        submit(session_hash, messages)
        messages = stream_message(session_hash)
        if messages is None:
            # The stream failed or ended without a completed result. Submitting
            # None as the next round's input would only produce a bad request.
            print("No completed output received; stopping the self-chat early.")
            break


def create_session_hash(hash_size=10):
    """Return a random session id of ``hash_size`` ASCII letters and digits.

    Gradio's JS client builds its session hash with
    ``Math.random().toString(36).substring(2)`` (lowercase base-36), see
    https://github.com/gradio-app/gradio/blob/v3.41.0/client/js/src/client.ts#L258 .
    This version draws from upper- and lower-case letters plus digits instead.

    ``secrets`` is used rather than ``random`` because the hash is the only
    parameter needed to read the session's ``/queue/data`` stream, so it
    should not be predictable.

    Args:
        hash_size: Number of characters; 0 or less yields an empty string.
    """
    chars = string.ascii_letters + string.digits
    return ''.join(secrets.choice(chars) for _ in range(hash_size))


if __name__ == "__main__":
    # Demo system prompt, in Chinese: "You are a novelist who is good at
    # writing wuxia (martial-arts) fiction."
    self_chat_demo("你是一个小说家,擅长写武侠小说", num_turn=4)