xu song committed
Commit dbf8ee3
1 Parent(s): aad5245
Files changed (4):
  1. README.md +23 -6
  2. client.py +6 -6
  3. client_streaming.py +123 -0
  4. models/vllm_qwen2.py +4 -0
README.md CHANGED
@@ -13,20 +13,37 @@ tags:
 short_description: Generating synthetic data via self-chat
 ---
 
-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
 
-## Installation issues
+## Dependency
 
-Installing directly from source gives slow inference, so the following build flags are added.
+Install llama-cpp-python with the following build flags:
 ```sh
 pip install git+https://github.com/abetlen/llama-cpp-python.git -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
 ```
 
 
-## Serverless Inference API
+## Local inference
+
+```sh
+python models/cpp_qwen2.py
+```
+
+## Serverless Inference
 
 
-client.py
+```sh
+python client_gradio.py
+```
+
+For streaming inference:
+```sh
+python client_streaming.py
+```
 
-## Reference
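As a quick sanity check that the build flags above took effect, the installed wheel can be inspected for BLAS support. A minimal sketch; the exact wording of the system-info string varies across llama.cpp versions:

```python
# Check whether llama-cpp-python was compiled with BLAS enabled.
import llama_cpp

# llama_print_system_info() returns a bytes summary of compile-time features;
# an OpenBLAS build should report "BLAS = 1" (exact format varies by version).
print(llama_cpp.llama_print_system_info().decode())
```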
client.py CHANGED
@@ -1,27 +1,27 @@
 from gradio_client import Client
 
-client = Client("xu-song/self-chat")
-
 
 def self_chat_demo(system_message, num_turn=4):
+    client = Client("xu-song/self-chat")
+
     # 1. set system message for the agent
     client.predict(
         system=system_message,
         api_name="/reset_state"
     )
     messages = []
-    # 2. self chat
+    # 2. start self-chatting
     for num in range(num_turn):
         # Note: history is a gr.State value and cannot be passed as an API argument
         messages = client.predict(
             chatbot=messages,
-            api_name="/generate"
+            api_name="/chat"
         )
         if num % 2 == 0:
             assert messages[-1][1] is None
-            print(f"Q: {messages[-1][0]}")
+            print(f"USER: {messages[-1][0]}")
         else:
-            print(f"A: {messages[-1][1]}")
+            print(f"ASSISTANT: {messages[-1][1]}")
 
 
 if __name__ == "__main__":
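The hunk ends at the script's entry point, whose body is not shown in this diff. A plausible invocation of `self_chat_demo` might look like the following sketch; the system message here is illustrative, not the repository's actual one:

```python
# Hypothetical entry point; the real body after `if __name__ == "__main__":`
# is outside this hunk.
if __name__ == "__main__":
    self_chat_demo(
        system_message="You are a novelist who excels at wuxia fiction.",
        num_turn=4,
    )
```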
client_streaming.py ADDED
@@ -0,0 +1,123 @@
+"""
+## Introduction
+Streaming mode is not supported by the latest Gradio client API, but it can be
+implemented with plain HTTP requests.
+
+## gradio api
+from gradio_client import Client
+client = Client("xu-song/self-chat")
+fn_index = self._infer_fn_index("set_state")
+"""
+import string
+import requests
+import json
+import random
+
+
+def set_system(session_hash, system):
+    # fn_index=0 corresponds to the /reset_state endpoint used in client.py
+    url = 'https://xu-song-self-chat.hf.space/queue/join'
+    payload = {"data": [system], "event_data": None, "fn_index": 0, "trigger_id": 4,
+               "session_hash": session_hash}
+    headers = {'Content-Type': 'application/json', 'Accept': '*/*'}
+    response = requests.post(url,
+                             data=json.dumps(payload),
+                             headers=headers,
+                             stream=True)
+    resp = response.json()
+    event_id = resp["event_id"]
+    return event_id
+
+
+def submit(session_hash, messages):
+    """Submit the current chat history; fn_index=1 corresponds to the /chat endpoint."""
+    url = 'https://xu-song-self-chat.hf.space/queue/join'
+    payload = {"data": messages, "event_data": None, "fn_index": 1, "trigger_id": 8, "session_hash": session_hash}
+    headers = {'Content-Type': 'application/json', 'Accept': '*/*'}
+    response = requests.post(url,
+                             data=json.dumps(payload),
+                             headers=headers,
+                             stream=True)
+    resp = response.json()
+    event_id = resp["event_id"]
+    return event_id
+
+
+def stream_message(session_hash):
+    """
+    - javascript_client: https://github.com/gradio-app/gradio/blob/9f0fe392c9f2604b9f937b9414e67d9b71b69109/client/js/src/utils/stream.ts#L42
+    - python_client: https://github.com/gradio-app/gradio/blob/9f0fe392c9f2604b9f937b9414e67d9b71b69109/client/python/gradio_client/client.py#L248
+    """
+    sse_url = 'https://xu-song-self-chat.hf.space/queue/data'
+    # sse_url = f'https://xu-song-self-chat.hf.space/queue/data?session_hash={session_hash}'  # equivalent
+    payload = {'session_hash': session_hash}
+    headers = {'Content-Type': 'application/json', 'Accept': 'text/event-stream'}  # headers for the SSE request
+    response = requests.get(sse_url,
+                            params=payload,  # equivalent to appending the query string to the URL
+                            headers=headers,
+                            stream=True)
+
+    final_output = None
+    # Check if the request was successful
+    if response.status_code == 200:
+        for line in response.iter_lines():
+            if not line:
+                continue
+            decoded_line = line.decode('utf-8')
+            # print(decoded_line)
+            if not decoded_line.startswith("data:"):
+                continue
+            # drop the "data:" prefix (str.strip removes characters, not a prefix)
+            data = json.loads(decoded_line[len("data:"):].strip())
+            if "output" not in data:
+                continue
+
+            messages = data["output"]["data"][0]
+            if not messages:
+                continue
+            message = messages[-1]
+            if len(message) == 2:
+                q, a = message
+                content = a if a else q
+            elif len(message) == 3:
+                action, _, content = message
+            else:
+                raise Exception("response error")
+
+            if data['msg'] == "process_completed":
+                final_output = data["output"]["data"]
+            else:
+                print(content, end="")
+    else:
+        print(f"Request failed with status code: {response.status_code}")
+
+    print("")
+    response.close()  # close the connection
+    return final_output
+
+
+def self_chat_demo(system_message, num_turn=4):
+    session_hash = create_session_hash()
+    print(f"SYSTEM: {system_message}")
+    set_system(session_hash, system_message)
+    messages = [[], None]
+
+    for num in range(num_turn):
+        if num % 2 == 0:
+            print("===" * 10)
+            print("USER: ", end="")
+        else:
+            print("ASSISTANT: ", end="")
+        submit(session_hash, messages)
+        messages = stream_message(session_hash)
+
+
+def create_session_hash(hash_size=10):
+    """
+    Equivalent to JS Math.random().toString(36).substring(2), as implemented in
+    https://github.com/gradio-app/gradio/blob/v3.41.0/client/js/src/client.ts#L258
+    """
+    chars = string.ascii_letters + string.digits
+    return ''.join(random.choice(chars) for _ in range(hash_size))
+
+
+if __name__ == "__main__":
+    self_chat_demo(system_message="你是一个小说家,擅长写武侠小说")  # "You are a novelist who excels at wuxia fiction"
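For reference, the loop in `stream_message` relies on just two fields of each SSE event: `msg` and `output.data`. Below is a self-contained sketch of the event shapes it consumes; the payloads are assumptions inferred from the fields the script reads, not captured server traffic:

```python
import json

# Two illustrative "data:" lines as they might arrive over the SSE stream:
# an intermediate generation chunk and the final completion event.
events = [
    'data: {"msg": "process_generating", "output": {"data": [[["hi", null]]]}}',
    'data: {"msg": "process_completed", "output": {"data": [[["hi", "hello!"]]]}}',
]

for line in events:
    event = json.loads(line[len("data:"):].strip())  # same prefix handling as above
    messages = event["output"]["data"][0]            # list of [user, assistant] pairs
    print(event["msg"], "->", messages[-1])
```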
models/vllm_qwen2.py ADDED
@@ -0,0 +1,4 @@
+"""
+
+https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_with_prefix.py
+"""