|
""" |
|
## introduction |
|
Streaming mode is not supported by the latest Gradio client API, but it can be implemented with plain HTTP requests, as this script does. |
|
|
|
## gradio api |
|
from gradio_client import Client |
|
client = Client("xu-song/self-chat") |
|
fn_index = self._infer_fn_index("set_state") |
|
""" |
|
import string |
|
import requests |
|
import json |
|
import random |
|
|
|
|
|
def set_system(session_hash, system, timeout=30):
    """Queue the system-prompt event for a session and return its event id.

    Sends a POST to the Space's ``/queue/join`` endpoint with ``fn_index`` 0
    and ``trigger_id`` 4 (presumably the "set_state" handler of the remote
    app -- confirm against the Space's config).

    Args:
        session_hash: Session identifier, sent as ``session_hash``.
        system: System prompt, sent as the only element of ``data``.
        timeout: Seconds to wait for the server before ``requests`` gives up.

    Returns:
        The ``event_id`` value from the JSON response.

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    url = 'https://xu-song-self-chat.hf.space/queue/join'
    payload = {
        "data": [system],
        "event_data": None,
        "fn_index": 0,
        "trigger_id": 4,
        "session_hash": session_hash,
    }
    headers = {'Content-Type': 'application/json', 'Accept': '*/*'}
    # The reply is a small JSON document that is read in full right away, so
    # there is no reason to request a streamed body here.
    response = requests.post(url,
                             data=json.dumps(payload),
                             headers=headers,
                             timeout=timeout)
    # Report HTTP failures as such instead of as a confusing KeyError or JSON
    # decoding error on the lines below.
    response.raise_for_status()
    return response.json()["event_id"]
|
|
|
|
|
def submit(session_hash, messages, timeout=30):
    """Queue a generation event for a session and return its event id.

    Sends a POST to the Space's ``/queue/join`` endpoint with ``fn_index`` 1
    and ``trigger_id`` 8, passing ``messages`` unchanged as the ``data`` field.

    Args:
        session_hash: Session identifier, sent as ``session_hash``.
        messages: Value sent as ``data``; must be JSON-serializable.
        timeout: Seconds to wait for the server before ``requests`` gives up.

    Returns:
        The ``event_id`` value from the JSON response.

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    url = 'https://xu-song-self-chat.hf.space/queue/join'
    payload = {
        "data": messages,
        "event_data": None,
        "fn_index": 1,
        "trigger_id": 8,
        "session_hash": session_hash,
    }
    headers = {'Content-Type': 'application/json', 'Accept': '*/*'}
    # The reply is a small JSON document that is read in full right away, so
    # there is no reason to request a streamed body here.
    response = requests.post(url,
                             data=json.dumps(payload),
                             headers=headers,
                             timeout=timeout)
    # Report HTTP failures as such instead of as a confusing KeyError or JSON
    # decoding error on the lines below.
    response.raise_for_status()
    return response.json()["event_id"]
|
|
|
|
|
def stream_message(session_hash):
    """Print the streamed output of a session and return the completed data.

    Opens the Space's ``/queue/data`` server-sent-event stream for
    ``session_hash`` and reads it until the server closes it. Every ``data:``
    line is parsed as JSON; events without an ``output`` field, or whose
    ``output["data"][0]`` is empty, are skipped. For the remaining events the
    last entry of ``output["data"][0]`` is inspected:

    - a 2-item entry ``(q, a)`` yields ``a`` if it is truthy, else ``q``;
    - a 3-item entry ``(action, _, content)`` yields ``content``
      (presumably an incremental update operation -- confirm against the
      Gradio version in use).

    That text is printed without a newline unless the event's ``msg`` is
    ``"process_completed"``, in which case ``output["data"]`` is remembered
    as the return value instead.

    References:
    - javascript_client: https://github.com/gradio-app/gradio/blob/9f0fe392c9f2604b9f937b9414e67d9b71b69109/client/js/src/utils/stream.ts#L42
    - python_client: https://github.com/gradio-app/gradio/blob/9f0fe392c9f2604b9f937b9414e67d9b71b69109/client/python/gradio_client/client.py#L248

    Args:
        session_hash: Session identifier, sent as the ``session_hash`` query
            parameter.

    Returns:
        ``output["data"]`` of the last ``process_completed`` event seen, or
        ``None`` if the request failed or no such event arrived.

    Raises:
        ValueError: If the last entry of an event has neither 2 nor 3 items.
    """
    sse_url = 'https://xu-song-self-chat.hf.space/queue/data'
    payload = {'session_hash': session_hash}
    headers = {'Content-Type': 'application/json', 'Accept': 'text/event-stream'}
    prefix = "data:"
    final_output = None

    response = requests.get(sse_url,
                            params=payload,
                            headers=headers,
                            stream=True)
    try:
        if response.status_code == 200:
            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                if not decoded_line.startswith(prefix):
                    continue
                # Slice the prefix off: str.strip("data:") would remove any of
                # the characters d/a/t/: from BOTH ends of the line, which can
                # eat into the payload itself.
                data = json.loads(decoded_line[len(prefix):].strip())
                if "output" not in data:
                    continue

                messages = data["output"]["data"][0]
                if not messages:
                    continue
                message = messages[-1]
                if len(message) == 2:
                    q, a = message
                    content = a if a else q
                elif len(message) == 3:
                    action, _, content = message
                else:
                    raise ValueError("response error")

                if data['msg'] == "process_completed":
                    final_output = data["output"]["data"]
                else:
                    # Flush so partial text shows up as it arrives; without it
                    # a line-buffered stdout holds everything until a newline.
                    print(content, end="", flush=True)
        else:
            print(f"Request failed with status code: {response.status_code}")

        print("")
    finally:
        # Release the connection even when parsing above raises.
        response.close()
    return final_output
|
|
|
|
|
def self_chat_demo(system_message, num_turn=4):
    """Run a self-chat session against the Space and print it to stdout.

    Creates a fresh session hash, registers ``system_message`` through
    ``set_system``, then repeats ``num_turn`` times: submit the current
    messages, stream the reply, and use the streamed result as the input of
    the next turn. Even-numbered turns are labelled ``USER`` and odd-numbered
    turns ``ASSISTANT``.

    Args:
        system_message: System prompt for the conversation.
        num_turn: Number of submit/stream rounds to run.
    """
    session_hash = create_session_hash()
    print(f"SYSTEM: {system_message}")
    set_system(session_hash, system_message)
    # Initial payload for the first submit (presumably an empty chat history
    # plus an empty second input -- confirm against the Space's inputs).
    messages = [[], None]

    for num in range(num_turn):
        if num % 2 == 0:
            print("===" * 10)
            # Flush so the role label is visible before the reply streams in.
            print("USER: ", end="", flush=True)
        else:
            print("ASSISTANT: ", end="", flush=True)
        submit(session_hash, messages)
        messages = stream_message(session_hash)
        if messages is None:
            # stream_message returns None when the request failed or no
            # completed output arrived; submitting None as the next payload
            # would only send a malformed request, so stop here.
            print("No output received; stopping the conversation.")
            break
|
|
|
|
|
def create_session_hash(hash_size=10):
    """Return a random string of ``hash_size`` ASCII letters and digits.

    Plays the role of the JavaScript client's
    ``random().toString(36).substring(2)``, implemented in
    https://github.com/gradio-app/gradio/blob/v3.41.0/client/js/src/client.ts#L258
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    # One random.choice call per output character.
    for _ in range(hash_size):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the demo with a Chinese system prompt
    # ("You are a novelist who is good at writing wuxia novels").
    novelist_prompt = "你是一个小说家,擅长写武侠小说"
    self_chat_demo(system_message=novelist_prompt)
|
|