self-chat / client.py
xu song
update
dbf8ee3
raw
history blame
800 Bytes
from gradio_client import Client
def self_chat_demo(system_message, num_turn=4, space="xu-song/self-chat"):
    """Drive a self-chat session on a Gradio Space and print each message.

    The Space is first given a system message via its ``/reset_state``
    endpoint. Then ``/chat`` is called ``num_turn`` times, each time sending
    back the transcript returned by the previous call. This client expects
    the Space to alternate roles:

    * On even turns (0, 2, ...), the last entry is a new user message whose
      assistant slot is still ``None``. Its user text is printed as ``USER: ...``.
    * On odd turns, the assistant reply of the last entry is printed as
      ``ASSISTANT: ...``.

    Args:
        system_message: System prompt sent to ``/reset_state``.
        num_turn: Number of ``/chat`` calls to make (one printed message each).
            Zero or negative values make no ``/chat`` calls.
        space: Space id or URL passed to ``gradio_client.Client``. The
            default preserves the original hard-coded target.

    Returns:
        The transcript returned by the last ``/chat`` call, or ``[]`` if no
        call was made. The code indexes each entry as ``entry[0]``/``entry[1]``,
        so entries are presumably ``[user, assistant]`` pairs (confirm against
        the Space's API).

    Raises:
        RuntimeError: If ``/chat`` returns an empty transcript, or if an
            even turn comes back with an assistant reply already filled in.
            This is an explicit check rather than ``assert``, so it still
            runs under ``python -O``.
    """
    client = Client(space)

    # 1. Set the system message for the agent. Judging by the endpoint name,
    #    this presumably also resets the Space's conversation state.
    client.predict(system=system_message, api_name="/reset_state")

    messages = []
    # 2. Start self-chatting.
    # NOTE: the server-side `history` is a gr.State, which cannot be passed
    # through the API. The visible chatbot transcript is therefore sent back
    # on every call instead.
    for turn in range(num_turn):
        messages = client.predict(chatbot=messages, api_name="/chat")
        if not messages:
            raise RuntimeError(f"/chat returned an empty transcript on turn {turn}")

        last = messages[-1]
        if turn % 2 == 0:
            # Even turns produce the user side; the assistant slot must still be empty.
            if last[1] is not None:
                raise RuntimeError(
                    f"expected no assistant reply yet on turn {turn}, got {last[1]!r}"
                )
            print(f"USER: {last[0]}")
        else:
            print(f"ASSISTANT: {last[1]}")

    return messages
if __name__ == "__main__":
    # Demo persona (English: "You are a novelist who excels at writing
    # wuxia / martial-arts novels").
    self_chat_demo(
        system_message="你是一个小说家,擅长写武侠小说",
    )