from gradio_client import Client | |
# Hugging Face Space that serves the self-chat app.
_SPACE_ID = "xu-song/self-chat"
# Shared client used by self_chat_demo. It is built at import time;
# NOTE(review): gradio_client presumably contacts the Space here, so importing
# this module likely needs network access — confirm.
client = Client(_SPACE_ID)
def self_chat_demo(system_message, num_turn=4, chat_client=None):
    """Let the remote self-chat Space converse with itself and print each turn.

    Args:
        system_message: System prompt sent to the ``/reset_state`` endpoint
            before the conversation starts.
        num_turn: Number of ``/generate`` calls. Even-numbered calls (0, 2, ...)
            are expected to produce a question, odd-numbered calls its answer.
        chat_client: Object with a gradio-style ``predict(**kwargs)`` method.
            Defaults to the module-level ``client``; pass a stub for testing.

    Returns:
        The chatbot history returned by the last ``/generate`` call (each
        entry is indexed as ``[question, answer]``), or ``[]`` when
        ``num_turn`` is 0 or less.

    Raises:
        RuntimeError: If ``/generate`` returns an empty history, or a question
            turn unexpectedly already carries an answer.
    """
    if chat_client is None:
        chat_client = client  # module-level gradio Client

    # 1. Set the system message for the agent.
    chat_client.predict(system=system_message, api_name="/reset_state")

    messages = []
    # 2. Self chat: even iterations print a question, odd ones the answer.
    for num in range(num_turn):
        # Note: the conversation history is a gr.State on the Space and cannot
        # be passed through the API, so only the chatbot list is sent back.
        messages = chat_client.predict(chatbot=messages, api_name="/generate")
        if not messages:
            raise RuntimeError("/generate returned an empty chat history")
        last = messages[-1]
        if num % 2 == 0:
            # Question turn: the answer slot must still be empty. Explicit
            # raise instead of `assert` so the check survives `python -O`.
            if last[1] is not None:
                raise RuntimeError(
                    f"expected an unanswered question on turn {num}, "
                    f"got answer {last[1]!r}"
                )
            print(f"Q: {last[0]}")
        else:
            print(f"A: {last[1]}")
    return messages
if __name__ == "__main__":
    # Demo prompt (Chinese): "You are a novelist, skilled at writing wuxia
    # (martial-arts) novels."
    demo_prompt = "你是一个小说家,擅长写武侠小说"
    self_chat_demo(system_message=demo_prompt)