Jar2023 committed on
Commit 25be583
1 Parent(s): 45b7a2a

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,12 +1,143 @@
- ---
- title: Basic Demo
- emoji: 🌖
- colorFrom: gray
- colorTo: pink
- sdk: gradio
- sdk_version: 4.36.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: basic_demo
+ app_file: trans_web_demo.py
+ sdk: gradio
+ sdk_version: 4.36.0
+ ---
+ # Basic Demo
+
+ Read this in [English](README_en.md).
+
+ In this demo, you will experience how to use the GLM-4-9B open-source model to perform basic tasks.
+
+ Please follow the steps in this document strictly to avoid unnecessary errors.
+
+ ## Device and dependency check
+
+ ### Related inference test data
+
+ **The data in this document were measured in the following hardware environment. Actual environment requirements and GPU memory usage may differ slightly; please refer to your own environment.**
+
+ Test hardware information:
+
+ + OS: Ubuntu 22.04
+ + Memory: 512GB
+ + Python: 3.12.3
+ + CUDA Version: 12.3
+ + GPU Driver: 535.104.05
+ + GPU: NVIDIA A100-SXM4-80GB * 8
+
+ The stress test data for inference are as follows:
+
+ **All tests were performed on a single GPU, and all GPU memory figures are approximate peak consumption.**
+
+ #### GLM-4-9B-Chat
+
+ | Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks                |
+ |-------|------------|------------|---------------|------------------------|
+ | BF16  | 19 GB      | 0.2s       | 27.8 tokens/s | Input length is 1000   |
+ | BF16  | 21 GB      | 0.8s       | 31.8 tokens/s | Input length is 8000   |
+ | BF16  | 28 GB      | 4.3s       | 14.4 tokens/s | Input length is 32000  |
+ | BF16  | 58 GB      | 38.1s      | 3.4 tokens/s  | Input length is 128000 |
+
+ | Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks               |
+ |-------|------------|------------|---------------|-----------------------|
+ | INT4  | 8 GB       | 0.2s       | 23.3 tokens/s | Input length is 1000  |
+ | INT4  | 10 GB      | 0.8s       | 23.4 tokens/s | Input length is 8000  |
+ | INT4  | 17 GB      | 4.3s       | 14.6 tokens/s | Input length is 32000 |
+
+ #### GLM-4-9B-Chat-1M
+
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks                |
+ |-------|------------|------------|--------------|------------------------|
+ | BF16  | 75 GB      | 98.4s      | 2.3 tokens/s | Input length is 200000 |
+
+ If your input exceeds 200K tokens, we recommend using the vLLM backend with multiple GPUs for inference to get better performance.
+
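+ For reference, a minimal multi-GPU vLLM engine setup might look like the sketch below. It follows the engine-argument pattern used in `openai_api_server.py`; the checkpoint name, `tensor_parallel_size` and `max_model_len` values are illustrative assumptions rather than tested settings.
+
+ ```python
+ from vllm import AsyncEngineArgs, AsyncLLMEngine
+
+ # Sketch: shard a long-context GLM-4 checkpoint across 4 GPUs for >200K-token inputs.
+ engine_args = AsyncEngineArgs(
+     model="THUDM/glm-4-9b-chat-1m",      # assumed long-context checkpoint
+     tokenizer="THUDM/glm-4-9b-chat-1m",
+     tensor_parallel_size=4,              # split the model over 4 GPUs
+     dtype="bfloat16",
+     trust_remote_code=True,
+     max_model_len=262144,                # leave room for ~200K-token prompts
+     enforce_eager=True,
+ )
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
+ ```
+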
+ #### GLM-4V-9B
+
+ | Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
+ |-------|------------|------------|---------------|----------------------|
+ | BF16  | 28 GB      | 0.1s       | 33.4 tokens/s | Input length is 1000 |
+ | BF16  | 33 GB      | 0.7s       | 39.2 tokens/s | Input length is 8000 |
+
+ | Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
+ |-------|------------|------------|---------------|----------------------|
+ | INT4  | 10 GB      | 0.1s       | 28.7 tokens/s | Input length is 1000 |
+ | INT4  | 15 GB      | 0.8s       | 24.2 tokens/s | Input length is 8000 |
+
+ ### Minimum hardware requirements
+
+ If you want to run the most basic official code (transformers backend), you need:
+
+ + Python >= 3.10
+ + At least 32 GB of memory
+
+ If you want to run all of the code in this folder, you also need:
+
+ + A Linux operating system (Debian series is best)
+ + A GPU with more than 8GB of memory that supports CUDA or ROCm and `BF16` inference (`FP16` precision cannot be used for training, and inference with it has a small chance of problems)
+
+ Install dependencies:
+
+ ```shell
+ pip install -r requirements.txt
+ ```
+
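+ As a quick sanity check before running the demos (a minimal sketch, not part of the official scripts), you can confirm that PyTorch sees your GPU and that it supports `BF16`:
+
+ ```python
+ import torch
+
+ # Minimal environment check before running the demos (illustrative only).
+ print("CUDA available:", torch.cuda.is_available())
+ if torch.cuda.is_available():
+     props = torch.cuda.get_device_properties(0)
+     print(f"GPU: {props.name}, {props.total_memory / 1024**3:.1f} GB")
+     print("BF16 supported:", torch.cuda.is_bf16_supported())
+ ```
+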
+ ## Basic function calls
+
+ **Unless otherwise specified, none of the demos in this folder support advanced usage such as Function Call and All Tools.**
+
+ ### Using the transformers backend code
+
+ Use the command line to chat with the GLM-4-9B model:
+
+ ```shell
+ python trans_cli_demo.py # GLM-4-9B-Chat
+ python trans_cli_vision_demo.py # GLM-4V-9B
+ ```
+
+ Use the Gradio web UI to chat with the GLM-4-9B-Chat model:
+
+ ```shell
+ python trans_web_demo.py
+ ```
+
+ Use batch inference:
+
+ ```shell
+ python trans_batch_demo.py
+ ```
+
+ ### Using the vLLM backend code
+
+ Use the command line to chat with the GLM-4-9B-Chat model:
+
+ ```shell
+ python vllm_cli_demo.py
+ ```
+
+ Build the server yourself and use the `OpenAI API` request format to chat with the GLM-4-9B-Chat model. This demo supports Function Call and All Tools.
+
+ Start the server:
+
+ ```shell
+ python openai_api_server.py
+ ```
+
+ Client request:
+
+ ```shell
+ python openai_api_request.py
+ ```
+
+ ## Stress test
+
+ Users can run this script on their own devices to measure the model's generation speed with the transformers backend:
+
+ ```shell
+ python trans_stress_test.py
+ ```
+
README_en.md ADDED
@@ -0,0 +1,140 @@
1
+ # Basic Demo
2
+
3
+ In this demo, you will experience how to use the GLM-4-9B open source model to perform basic tasks.
4
+
5
+ Please follow the steps in the document strictly to avoid unnecessary errors.
6
+
7
+ ## Device and dependency check
8
+
9
+ ### Related inference test data
10
+
11
+ **The data in this document are tested in the following hardware environment. The actual operating environment
12
+ requirements and the GPU memory occupied by the operation are slightly different. Please refer to the actual operating
13
+ environment.**
14
+
15
+ Test hardware information:
16
+
17
+ + OS: Ubuntu 22.04
18
+ + Memory: 512GB
19
+ + Python: 3.12.3
20
+ + CUDA Version: 12.3
21
+ + GPU Driver: 535.104.05
22
+ + GPU: NVIDIA A100-SXM4-80GB * 8
23
+
24
+ The stress test data for inference are as follows:
25
+
26
+ **All tests are performed on a single GPU, and all GPU memory consumption is calculated based on the peak value**
27
+
28
+ #### GLM-4-9B-Chat
31
+
32
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
33
+ |-------|------------|------------|---------------|------------------------|
34
+ | BF16 | 19 GB | 0.2s | 27.8 tokens/s | Input length is 1000 |
35
+ | BF16 | 21 GB | 0.8s | 31.8 tokens/s | Input length is 8000 |
36
+ | BF16 | 28 GB | 4.3s | 14.4 tokens/s | Input length is 32000 |
37
+ | BF16 | 58 GB | 38.1s | 3.4 tokens/s | Input length is 128000 |
38
+
39
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
40
+ |-------|------------|------------|---------------|-----------------------|
41
+ | INT4 | 8 GB | 0.2s | 23.3 tokens/s | Input length is 1000 |
42
+ | INT4 | 10 GB | 0.8s | 23.4 tokens/s | Input length is 8000 |
43
+ | INT4 | 17 GB | 4.3s | 14.6 tokens/s | Input length is 32000 |
44
+
45
+ #### GLM-4-9B-Chat-1M
46
+
47
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
48
+ |-------|------------|------------|------------------|------------------------|
49
+ | BF16 | 74497 MiB | 98.4s | 2.3653 tokens/s | Input length is 200000 |
50
+
51
+ If your input exceeds 200K, we recommend that you use the vLLM backend with multiple GPUs for inference to get better
52
+ performance.
53
+
54
+ #### GLM-4V-9B
55
+
56
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
57
+ |-------|------------|------------|---------------|----------------------|
58
+ | BF16 | 28 GB | 0.1s | 33.4 tokens/s | Input length is 1000 |
59
+ | BF16 | 33 GB | 0.7s | 39.2 tokens/s | Input length is 8000 |
60
+
61
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
62
+ |-------|------------|------------|---------------|----------------------|
63
+ | INT4 | 10 GB | 0.1s | 28.7 tokens/s | Input length is 1000 |
64
+ | INT4 | 15 GB | 0.8s | 24.2 tokens/s | Input length is 8000 |
65
+
66
+ ### Minimum hardware requirements
67
+
68
+ If you want to run the most basic official code (transformers backend), you need:
69
+
70
+ + Python >= 3.10
71
+ + Memory of at least 32 GB
72
+
73
+ If you want to run all of the code in this folder, you also need:
74
+
75
+ + Linux operating system (Debian series is best)
76
+ + GPU device with more than 8GB of GPU memory, supporting CUDA or ROCm and `BF16` inference (`FP16` precision
+ cannot be used for fine-tuning, and inference with it has a small chance of problems)
78
+
79
+ Install dependencies
80
+
81
+ ```shell
82
+ pip install -r requirements.txt
83
+ ```
84
+
85
+ ## Basic function calls
86
+
87
+ **Unless otherwise specified, all demos in this folder do not support advanced usage such as Function Call and All Tools.**
89
+
90
+ ### Use transformers backend code
91
+
92
+ + Use the command line to communicate with the GLM-4-9B model.
93
+
94
+ ```shell
95
+ python trans_cli_demo.py # GLM-4-9B-Chat
96
+ python trans_cli_vision_demo.py # GLM-4V-9B
97
+ ```
98
+
99
+ + Use the Gradio web client to communicate with the GLM-4-9B-Chat model.
100
+
101
+ ```shell
102
+ python trans_web_demo.py
103
+ ```
104
+
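+ Under the hood, the transformers-backend demos all follow the same pattern: apply the chat template, call `generate`, and decode the new tokens. A stripped-down, non-streaming sketch (illustrative defaults, not one of the provided scripts):
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Minimal single-turn chat with the transformers backend (illustrative only).
+ MODEL_PATH = "THUDM/glm-4-9b-chat"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
+ ).eval()
+
+ messages = [{"role": "user", "content": "Hello, who are you?"}]
+ inputs = tokenizer.apply_chat_template(
+     messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
+ ).to(model.device)
+ outputs = model.generate(inputs, max_new_tokens=256, do_sample=True, top_p=0.8, temperature=0.6)
+ print(tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True))
+ ```
+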
105
+ + Use Batch inference.
106
+
107
+ ```shell
108
+ python trans_batch_demo.py
109
+ ```
110
+
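+ To drive batch inference from your own code instead of the script's `__main__` block, the pattern in `trans_batch_demo.py` is to render each conversation with the chat template and pass the resulting strings to `batch()`. A sketch (the prompts are arbitrary examples; importing the module loads the model):
+
+ ```python
+ from trans_batch_demo import batch, model, tokenizer
+
+ # Each conversation is rendered to a prompt string first, then generated in one batch.
+ conversations = [
+     [{"role": "user", "content": "Hello, who are you?"}],
+     [{"role": "user", "content": "Give me three uses for a paperclip."}],
+ ]
+ prompts = [
+     tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
+     for msgs in conversations
+ ]
+ for reply in batch(model, tokenizer, prompts, max_input_tokens=1024, max_new_tokens=512):
+     print("=" * 10)
+     print(reply)
+ ```
+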
111
+ ### Use vLLM backend code
112
+
113
+ + Use the command line to communicate with the GLM-4-9B-Chat model.
114
+
115
+ ```shell
116
+ python vllm_cli_demo.py
117
+ ```
118
+
119
+ Build the server yourself and use the `OpenAI API` request format to communicate with the GLM-4-9B-Chat model. This
120
+ demo supports Function Call and All Tools functions.
121
+
122
+ Start the server:
123
+
124
+ ```shell
125
+ python openai_api_server.py
126
+ ```
127
+
128
+ Client request:
129
+
130
+ ```shell
131
+ python openai_api_request.py
132
+ ```
133
+
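+ For reference, a minimal streaming request against the locally started server looks roughly like the snippet below; the port and model id come from `openai_api_server.py`, and the prompt is an arbitrary example:
+
+ ```python
+ from openai import OpenAI
+
+ # Talk to the locally started openai_api_server.py (port 8000, model id "glm-4").
+ client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")
+ stream = client.chat.completions.create(
+     model="glm-4",
+     messages=[{"role": "user", "content": "Tell me a short story."}],
+     stream=True,
+     max_tokens=256,
+ )
+ for chunk in stream:
+     print(chunk.choices[0].delta.content or "", end="", flush=True)
+ ```
+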
134
+ ## Stress test
135
+
136
+ Users can use this code to test the generation speed of the model on the transformers backend on their own devices:
137
+
138
+ ```shell
139
+ python trans_stress_test.py
140
+ ```
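+
+ `trans_stress_test.py` accepts `--token_len`, `--n` and `--num_gpu` arguments; it can also be driven programmatically, as in this sketch (single-GPU assumption):
+
+ ```python
+ from trans_stress_test import stress_test
+
+ # Run 3 iterations with 1000-token prompts on one GPU; returns per-iteration timings.
+ times, avg_prefill, decode_speeds, avg_decode_speed = stress_test(token_len=1000, n=3, num_gpu=1)
+ print(f"Avg prefill: {avg_prefill:.2f}s, avg decode: {avg_decode_speed:.2f} tokens/s")
+ ```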
openai_api_request.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ This script creates an OpenAI-API request demo for the glm-4-9b model; it simply uses the OpenAI client to interact with the model.
3
+ """
4
+
5
+ from openai import OpenAI
6
+
7
+ base_url = "http://127.0.0.1:8000/v1/"
8
+ client = OpenAI(api_key="EMPTY", base_url=base_url)
9
+
10
+
11
+ def function_chat():
12
+ messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
13
+ tools = [
14
+ {
15
+ "type": "function",
16
+ "function": {
17
+ "name": "get_current_weather",
18
+ "description": "Get the current weather in a given location",
19
+ "parameters": {
20
+ "type": "object",
21
+ "properties": {
22
+ "location": {
23
+ "type": "string",
24
+ "description": "The city and state, e.g. San Francisco, CA",
25
+ },
26
+ "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
27
+ },
28
+ "required": ["location"],
29
+ },
30
+ },
31
+ }
32
+ ]
33
+
34
+ # All Tools capability: image generation (CogView)
35
+ # messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
36
+ # tools = [{"type": "cogview"}]
37
+ #
38
+ # All Tools capability: web search
39
+ # messages = [{"role": "user", "content": "今天黄金的价格"}]
40
+ # tools = [{"type": "simple_browser"}]
41
+
42
+ response = client.chat.completions.create(
43
+ model="glm-4",
44
+ messages=messages,
45
+ tools=tools,
46
+ tool_choice="auto", # use "auto" to let the model choose the tool automatically
47
+ # tool_choice={"type": "function", "function": {"name": "my_function"}},
48
+ )
49
+ if response:
50
+ content = response.choices[0].message.content
51
+ print(content)
52
+ else:
53
+ print("Error: request failed or returned an empty response")
54
+
55
+
56
+ def simple_chat(use_stream=False):
57
+ messages = [
58
+ {
59
+ "role": "system",
60
+ "content": "你是 GLM-4,请你热情回答用户的问题。",
61
+ },
62
+ {
63
+ "role": "user",
64
+ "content": "你好,请你用生动的话语给我讲一个小故事吧"
65
+ }
66
+ ]
67
+ response = client.chat.completions.create(
68
+ model="glm-4",
69
+ messages=messages,
70
+ stream=use_stream,
71
+ max_tokens=1024,
72
+ temperature=0.8,
73
+ presence_penalty=1.1,
74
+ top_p=0.8)
75
+ if response:
76
+ if use_stream:
77
+ for chunk in response:
78
+ print(chunk.choices[0].delta.content)
79
+ else:
80
+ content = response.choices[0].message.content
81
+ print(content)
82
+ else:
83
+ print("Error: request failed or returned an empty response")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ simple_chat()
88
+ function_chat()
openai_api_server.py ADDED
@@ -0,0 +1,549 @@
1
+ import os
2
+ import time
3
+ from asyncio.log import logger
4
+
5
+ import uvicorn
6
+ import gc
7
+ import json
8
+ import torch
9
+
10
+ from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
11
+ from fastapi import FastAPI, HTTPException, Response
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from contextlib import asynccontextmanager
14
+ from typing import List, Literal, Optional, Union
15
+ from pydantic import BaseModel, Field
16
+ from transformers import AutoTokenizer, LogitsProcessor
17
+ from sse_starlette.sse import EventSourceResponse
18
+
19
+ EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
20
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
21
+ MAX_MODEL_LENGTH = 8192
22
+
23
+
24
+ @asynccontextmanager
25
+ async def lifespan(app: FastAPI):
26
+ yield
27
+ if torch.cuda.is_available():
28
+ torch.cuda.empty_cache()
29
+ torch.cuda.ipc_collect()
30
+
31
+
32
+ app = FastAPI(lifespan=lifespan)
33
+
34
+ app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"],
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
+
42
+
43
+ class ModelCard(BaseModel):
44
+ id: str
45
+ object: str = "model"
46
+ created: int = Field(default_factory=lambda: int(time.time()))
47
+ owned_by: str = "owner"
48
+ root: Optional[str] = None
49
+ parent: Optional[str] = None
50
+ permission: Optional[list] = None
51
+
52
+
53
+ class ModelList(BaseModel):
54
+ object: str = "list"
55
+ data: List[ModelCard] = []
56
+
57
+
58
+ class FunctionCallResponse(BaseModel):
59
+ name: Optional[str] = None
60
+ arguments: Optional[str] = None
61
+
62
+
63
+ class ChatMessage(BaseModel):
64
+ role: Literal["user", "assistant", "system", "tool"]
65
+ content: str = None
66
+ name: Optional[str] = None
67
+ function_call: Optional[FunctionCallResponse] = None
68
+
69
+
70
+ class DeltaMessage(BaseModel):
71
+ role: Optional[Literal["user", "assistant", "system"]] = None
72
+ content: Optional[str] = None
73
+ function_call: Optional[FunctionCallResponse] = None
74
+
75
+
76
+ class EmbeddingRequest(BaseModel):
77
+ input: Union[List[str], str]
78
+ model: str
79
+
80
+
81
+ class CompletionUsage(BaseModel):
82
+ prompt_tokens: int
83
+ completion_tokens: int
84
+ total_tokens: int
85
+
86
+
87
+ class EmbeddingResponse(BaseModel):
88
+ data: list
89
+ model: str
90
+ object: str
91
+ usage: CompletionUsage
92
+
93
+
94
+ class UsageInfo(BaseModel):
95
+ prompt_tokens: int = 0
96
+ total_tokens: int = 0
97
+ completion_tokens: Optional[int] = 0
98
+
99
+
100
+ class ChatCompletionRequest(BaseModel):
101
+ model: str
102
+ messages: List[ChatMessage]
103
+ temperature: Optional[float] = 0.8
104
+ top_p: Optional[float] = 0.8
105
+ max_tokens: Optional[int] = None
106
+ stream: Optional[bool] = False
107
+ tools: Optional[Union[dict, List[dict]]] = None
108
+ tool_choice: Optional[Union[str, dict]] = "None"
109
+ repetition_penalty: Optional[float] = 1.1
110
+
111
+
112
+ class ChatCompletionResponseChoice(BaseModel):
113
+ index: int
114
+ message: ChatMessage
115
+ finish_reason: Literal["stop", "length", "function_call"]
116
+
117
+
118
+ class ChatCompletionResponseStreamChoice(BaseModel):
119
+ delta: DeltaMessage
120
+ finish_reason: Optional[Literal["stop", "length", "function_call"]]
121
+ index: int
122
+
123
+
124
+ class ChatCompletionResponse(BaseModel):
125
+ model: str
126
+ id: str
127
+ object: Literal["chat.completion", "chat.completion.chunk"]
128
+ choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
129
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
130
+ usage: Optional[UsageInfo] = None
131
+
132
+
133
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
134
+ def __call__(
135
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
136
+ ) -> torch.FloatTensor:
137
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
138
+ scores.zero_()
139
+ scores[..., 5] = 5e4
140
+ return scores
141
+
142
+
143
+ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
144
+ content = ""
145
+ for response in output.split("<|assistant|>"):
146
+ if "\n" in response:
147
+ metadata, content = response.split("\n", maxsplit=1)
148
+ else:
149
+ metadata, content = "", response
150
+ if not metadata.strip():
151
+ content = content.strip()
152
+ else:
153
+ if use_tool:
154
+ parameters = eval(content.strip())
155
+ content = {
156
+ "name": metadata.strip(),
157
+ "arguments": json.dumps(parameters, ensure_ascii=False)
158
+ }
159
+ else:
160
+ content = {
161
+ "name": metadata.strip(),
162
+ "content": content
163
+ }
164
+ return content
165
+
166
+
167
+ @torch.inference_mode()
168
+ async def generate_stream_glm4(params):
169
+ messages = params["messages"]
170
+ tools = params["tools"]
171
+ tool_choice = params["tool_choice"]
172
+ temperature = float(params.get("temperature", 1.0))
173
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
174
+ top_p = float(params.get("top_p", 1.0))
175
+ max_new_tokens = int(params.get("max_tokens", 8192))
176
+ messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
177
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
178
+ params_dict = {
179
+ "n": 1,
180
+ "best_of": 1,
181
+ "presence_penalty": 1.0,
182
+ "frequency_penalty": 0.0,
183
+ "temperature": temperature,
184
+ "top_p": top_p,
185
+ "top_k": -1,
186
+ "repetition_penalty": repetition_penalty,
187
+ "use_beam_search": False,
188
+ "length_penalty": 1,
189
+ "early_stopping": False,
190
+ "stop_token_ids": [151329, 151336, 151338],
191
+ "ignore_eos": False,
192
+ "max_tokens": max_new_tokens,
193
+ "logprobs": None,
194
+ "prompt_logprobs": None,
195
+ "skip_special_tokens": True,
196
+ }
197
+ sampling_params = SamplingParams(**params_dict)
198
+ async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
199
+ output_len = len(output.outputs[0].token_ids)
200
+ input_len = len(output.prompt_token_ids)
201
+ ret = {
202
+ "text": output.outputs[0].text,
203
+ "usage": {
204
+ "prompt_tokens": input_len,
205
+ "completion_tokens": output_len,
206
+ "total_tokens": output_len + input_len
207
+ },
208
+ "finish_reason": output.outputs[0].finish_reason,
209
+ }
210
+ yield ret
211
+ gc.collect()
212
+ torch.cuda.empty_cache()
213
+
214
+
215
+ def process_messages(messages, tools=None, tool_choice="none"):
216
+ _messages = messages
217
+ messages = []
218
+ msg_has_sys = False
219
+
220
+ def filter_tools(tool_choice, tools):
221
+ function_name = tool_choice.get('function', {}).get('name', None)
222
+ if not function_name:
223
+ return []
224
+ filtered_tools = [
225
+ tool for tool in tools
226
+ if tool.get('function', {}).get('name') == function_name
227
+ ]
228
+ return filtered_tools
229
+
230
+ if tool_choice != "none":
231
+ if isinstance(tool_choice, dict):
232
+ tools = filter_tools(tool_choice, tools)
233
+ if tools:
234
+ messages.append(
235
+ {
236
+ "role": "system",
237
+ "content": None,
238
+ "tools": tools
239
+ }
240
+ )
241
+ msg_has_sys = True
242
+
243
+ # add to metadata
244
+ if isinstance(tool_choice, dict) and tools:
245
+ messages.append(
246
+ {
247
+ "role": "assistant",
248
+ "metadata": tool_choice["function"]["name"],
249
+ "content": ""
250
+ }
251
+ )
252
+
253
+ for m in _messages:
254
+ role, content, func_call = m.role, m.content, m.function_call
255
+ if role == "function":
256
+ messages.append(
257
+ {
258
+ "role": "observation",
259
+ "content": content
260
+ }
261
+ )
262
+ elif role == "assistant" and func_call is not None:
263
+ for response in content.split("<|assistant|>"):
264
+ if "\n" in response:
265
+ metadata, sub_content = response.split("\n", maxsplit=1)
266
+ else:
267
+ metadata, sub_content = "", response
268
+ messages.append(
269
+ {
270
+ "role": role,
271
+ "metadata": metadata,
272
+ "content": sub_content.strip()
273
+ }
274
+ )
275
+ else:
276
+ if role == "system" and msg_has_sys:
277
+ msg_has_sys = False
278
+ continue
279
+ messages.append({"role": role, "content": content})
280
+
281
+ return messages
282
+
283
+
284
+ @app.get("/health")
285
+ async def health() -> Response:
286
+ """Health check."""
287
+ return Response(status_code=200)
288
+
289
+
290
+ @app.get("/v1/models", response_model=ModelList)
291
+ async def list_models():
292
+ model_card = ModelCard(id="glm-4")
293
+ return ModelList(data=[model_card])
294
+
295
+
296
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
297
+ async def create_chat_completion(request: ChatCompletionRequest):
298
+ if len(request.messages) < 1 or request.messages[-1].role == "assistant":
299
+ raise HTTPException(status_code=400, detail="Invalid request")
300
+
301
+ gen_params = dict(
302
+ messages=request.messages,
303
+ temperature=request.temperature,
304
+ top_p=request.top_p,
305
+ max_tokens=request.max_tokens or 1024,
306
+ echo=False,
307
+ stream=request.stream,
308
+ repetition_penalty=request.repetition_penalty,
309
+ tools=request.tools,
310
+ tool_choice=request.tool_choice,
311
+ )
312
+ logger.debug(f"==== request ====\n{gen_params}")
313
+
314
+ if request.stream:
315
+ predict_stream_generator = predict_stream(request.model, gen_params)
316
+ output = await anext(predict_stream_generator)
317
+ if output:
318
+ return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
319
+ logger.debug(f"First result output:\n{output}")
320
+
321
+ function_call = None
322
+ if output and request.tools:
323
+ try:
324
+ function_call = process_response(output, use_tool=True)
325
+ except:
326
+ logger.warning("Failed to parse tool call")
327
+
328
+ # CallFunction
329
+ if isinstance(function_call, dict):
330
+ function_call = FunctionCallResponse(**function_call)
331
+ tool_response = ""
332
+ if not gen_params.get("messages"):
333
+ gen_params["messages"] = []
334
+ gen_params["messages"].append(ChatMessage(role="assistant", content=output))
335
+ gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
336
+ generate = predict(request.model, gen_params)
337
+ return EventSourceResponse(generate, media_type="text/event-stream")
338
+ else:
339
+ generate = parse_output_text(request.model, output)
340
+ return EventSourceResponse(generate, media_type="text/event-stream")
341
+
342
+ response = ""
343
+ async for response in generate_stream_glm4(gen_params):
344
+ pass
345
+
346
+ if response["text"].startswith("\n"):
347
+ response["text"] = response["text"][1:]
348
+ response["text"] = response["text"].strip()
349
+
350
+ usage = UsageInfo()
351
+ function_call, finish_reason = None, "stop"
352
+ if request.tools:
353
+ try:
354
+ function_call = process_response(response["text"], use_tool=True)
355
+ except:
356
+ logger.warning(
357
+ "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
358
+
359
+ if isinstance(function_call, dict):
360
+ finish_reason = "function_call"
361
+ function_call = FunctionCallResponse(**function_call)
362
+
363
+ message = ChatMessage(
364
+ role="assistant",
365
+ content=response["text"],
366
+ function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
367
+ )
368
+
369
+ logger.debug(f"==== message ====\n{message}")
370
+
371
+ choice_data = ChatCompletionResponseChoice(
372
+ index=0,
373
+ message=message,
374
+ finish_reason=finish_reason,
375
+ )
376
+ task_usage = UsageInfo.model_validate(response["usage"])
377
+ for usage_key, usage_value in task_usage.model_dump().items():
378
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
379
+
380
+ return ChatCompletionResponse(
381
+ model=request.model,
382
+ id="", # for open_source model, id is empty
383
+ choices=[choice_data],
384
+ object="chat.completion",
385
+ usage=usage
386
+ )
387
+
388
+
389
+ async def predict(model_id: str, params: dict):
390
+ choice_data = ChatCompletionResponseStreamChoice(
391
+ index=0,
392
+ delta=DeltaMessage(role="assistant"),
393
+ finish_reason=None
394
+ )
395
+ chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
396
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
397
+
398
+ previous_text = ""
399
+ async for new_response in generate_stream_glm4(params):
400
+ decoded_unicode = new_response["text"]
401
+ delta_text = decoded_unicode[len(previous_text):]
402
+ previous_text = decoded_unicode
403
+
404
+ finish_reason = new_response["finish_reason"]
405
+ if len(delta_text) == 0 and finish_reason != "function_call":
406
+ continue
407
+
408
+ function_call = None
409
+ if finish_reason == "function_call":
410
+ try:
411
+ function_call = process_response(decoded_unicode, use_tool=True)
412
+ except:
413
+ logger.warning(
414
+ "Failed to parse tool call, maybe the response is not a tool call or have been answered.")
415
+
416
+ if isinstance(function_call, dict):
417
+ function_call = FunctionCallResponse(**function_call)
418
+
419
+ delta = DeltaMessage(
420
+ content=delta_text,
421
+ role="assistant",
422
+ function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
423
+ )
424
+
425
+ choice_data = ChatCompletionResponseStreamChoice(
426
+ index=0,
427
+ delta=delta,
428
+ finish_reason=finish_reason
429
+ )
430
+ chunk = ChatCompletionResponse(
431
+ model=model_id,
432
+ id="",
433
+ choices=[choice_data],
434
+ object="chat.completion.chunk"
435
+ )
436
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
437
+
438
+ choice_data = ChatCompletionResponseStreamChoice(
439
+ index=0,
440
+ delta=DeltaMessage(),
441
+ finish_reason="stop"
442
+ )
443
+ chunk = ChatCompletionResponse(
444
+ model=model_id,
445
+ id="",
446
+ choices=[choice_data],
447
+ object="chat.completion.chunk"
448
+ )
449
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
450
+ yield '[DONE]'
451
+
452
+
453
+ async def predict_stream(model_id, gen_params):
454
+ output = ""
455
+ is_function_call = False
456
+ has_send_first_chunk = False
457
+ async for new_response in generate_stream_glm4(gen_params):
458
+ decoded_unicode = new_response["text"]
459
+ delta_text = decoded_unicode[len(output):]
460
+ output = decoded_unicode
461
+
462
+ if not is_function_call and len(output) > 7:
463
+ is_function_call = output and 'get_' in output
464
+ if is_function_call:
465
+ continue
466
+
467
+ finish_reason = new_response["finish_reason"]
468
+ if not has_send_first_chunk:
469
+ message = DeltaMessage(
470
+ content="",
471
+ role="assistant",
472
+ function_call=None,
473
+ )
474
+ choice_data = ChatCompletionResponseStreamChoice(
475
+ index=0,
476
+ delta=message,
477
+ finish_reason=finish_reason
478
+ )
479
+ chunk = ChatCompletionResponse(
480
+ model=model_id,
481
+ id="",
482
+ choices=[choice_data],
483
+ created=int(time.time()),
484
+ object="chat.completion.chunk"
485
+ )
486
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
487
+
488
+ send_msg = delta_text if has_send_first_chunk else output
489
+ has_send_first_chunk = True
490
+ message = DeltaMessage(
491
+ content=send_msg,
492
+ role="assistant",
493
+ function_call=None,
494
+ )
495
+ choice_data = ChatCompletionResponseStreamChoice(
496
+ index=0,
497
+ delta=message,
498
+ finish_reason=finish_reason
499
+ )
500
+ chunk = ChatCompletionResponse(
501
+ model=model_id,
502
+ id="",
503
+ choices=[choice_data],
504
+ created=int(time.time()),
505
+ object="chat.completion.chunk"
506
+ )
507
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
508
+
509
+ if is_function_call:
510
+ yield output
511
+ else:
512
+ yield '[DONE]'
513
+
514
+
515
+ async def parse_output_text(model_id: str, value: str):
516
+ choice_data = ChatCompletionResponseStreamChoice(
517
+ index=0,
518
+ delta=DeltaMessage(role="assistant", content=value),
519
+ finish_reason=None
520
+ )
521
+ chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
522
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
523
+ choice_data = ChatCompletionResponseStreamChoice(
524
+ index=0,
525
+ delta=DeltaMessage(),
526
+ finish_reason="stop"
527
+ )
528
+ chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
529
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
530
+ yield '[DONE]'
531
+
532
+
533
+ if __name__ == "__main__":
534
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
535
+ engine_args = AsyncEngineArgs(
536
+ model=MODEL_PATH,
537
+ tokenizer=MODEL_PATH,
538
+ tensor_parallel_size=1,
539
+ dtype="bfloat16",
540
+ trust_remote_code=True,
541
+ gpu_memory_utilization=0.9,
542
+ enforce_eager=True,
543
+ worker_use_ray=True,
544
+ engine_use_ray=False,
545
+ disable_log_requests=True,
546
+ max_model_len=MAX_MODEL_LENGTH,
547
+ )
548
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
549
+ uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
requirements.txt ADDED
@@ -0,0 +1,27 @@
1
+ # use vllm
2
+ # vllm>=0.4.3
3
+
4
+ torch>=2.3.0
5
+ torchvision>=0.18.0
6
+ transformers==4.40.0
7
+ huggingface-hub>=0.23.1
8
+ sentencepiece>=0.2.0
9
+ pydantic>=2.7.1
10
+ timm>=0.9.16
11
+ tiktoken>=0.7.0
12
+ accelerate>=0.30.1
13
+ sentence_transformers>=2.7.0
14
+
15
+ # web demo
16
+ gradio>=4.33.0
17
+
18
+ # openai demo
19
+ openai>=1.31.1
20
+ einops>=0.7.0
21
+ sse-starlette>=2.1.0
22
+
23
+ # INT4
24
+ bitsandbytes>=0.43.1
25
+
26
+ # PEFT model; not needed if you don't use a PEFT fine-tuned model.
27
+ peft>=0.11.0
trans_batch_demo.py ADDED
@@ -0,0 +1,90 @@
1
+ """
2
+
3
+ Here is an example of using batch requests with glm-4-9b;
4
+ here you need to build the conversation format yourself and then call the batch function to make batch requests.
5
+ Please note that memory consumption in this demo is significantly higher.
6
+
7
+ """
8
+
9
+ from typing import Optional, Union
10
+ from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
11
+
12
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained(
15
+ MODEL_PATH,
16
+ trust_remote_code=True,
17
+ encode_special_tokens=True)
18
+ model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
19
+
20
+
21
+ def process_model_outputs(inputs, outputs, tokenizer):
22
+ responses = []
23
+ for input_ids, output_ids in zip(inputs.input_ids, outputs):
24
+ response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
25
+ responses.append(response)
26
+ return responses
27
+
28
+
29
+ def batch(
30
+ model,
31
+ tokenizer,
32
+ messages: Union[str, list[str]],
33
+ max_input_tokens: int = 8192,
34
+ max_new_tokens: int = 8192,
35
+ num_beams: int = 1,
36
+ do_sample: bool = True,
37
+ top_p: float = 0.8,
38
+ temperature: float = 0.8,
39
+ logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
40
+ ):
41
+ messages = [messages] if isinstance(messages, str) else messages
42
+ batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
43
+ max_length=max_input_tokens).to(model.device)
44
+
45
+ gen_kwargs = {
46
+ "max_new_tokens": max_new_tokens,
47
+ "num_beams": num_beams,
48
+ "do_sample": do_sample,
49
+ "top_p": top_p,
50
+ "temperature": temperature,
51
+ "logits_processor": logits_processor,
52
+ "eos_token_id": model.config.eos_token_id
53
+ }
54
+ batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
55
+ batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
56
+ return batched_response
57
+
58
+
59
+ if __name__ == "__main__":
60
+
61
+ batch_message = [
62
+ [
63
+ {"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
64
+ {"role": "assistant", "content": "因为他们结婚时你还没有出生"},
65
+ {"role": "user", "content": "我刚才的提问是"}
66
+ ],
67
+ [
68
+ {"role": "user", "content": "你好,你是谁"}
69
+ ]
70
+ ]
71
+
72
+ batch_inputs = []
73
+ max_input_tokens = 1024
74
+ for i, messages in enumerate(batch_message):
75
+ new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
76
+ max_input_tokens = max(max_input_tokens, len(new_batch_input))
77
+ batch_inputs.append(new_batch_input)
78
+ gen_kwargs = {
79
+ "max_input_tokens": max_input_tokens,
80
+ "max_new_tokens": 8192,
81
+ "do_sample": True,
82
+ "top_p": 0.8,
83
+ "temperature": 0.8,
84
+ "num_beams": 1,
85
+ }
86
+
87
+ batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
88
+ for response in batch_responses:
89
+ print("=" * 10)
90
+ print(response)
trans_cli_demo.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ This script creates a CLI demo with transformers backend for the glm-4-9b model,
3
+ allowing users to interact with the model through a command-line interface.
4
+
5
+ Usage:
6
+ - Run the script to start the CLI demo.
7
+ - Interact with the model by typing questions and receiving responses.
8
+
9
+ Note: The script includes a modification to handle markdown to plain text conversion,
10
+ ensuring that the CLI interface displays formatted text correctly.
11
+ """
12
+
13
+ import os
14
+ import torch
15
+ from threading import Thread
16
+ from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoModel
17
+
18
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
19
+
20
+ ## If using a PEFT model:
21
+ # def load_model_and_tokenizer(model_dir, trust_remote_code: bool = True):
22
+ # if (model_dir / 'adapter_config.json').exists():
23
+ # model = AutoModel.from_pretrained(
24
+ # model_dir, trust_remote_code=trust_remote_code, device_map='auto'
25
+ # )
26
+ # tokenizer_dir = model.peft_config['default'].base_model_name_or_path
27
+ # else:
28
+ # model = AutoModel.from_pretrained(
29
+ # model_dir, trust_remote_code=trust_remote_code, device_map='auto'
30
+ # )
31
+ # tokenizer_dir = model_dir
32
+ # tokenizer = AutoTokenizer.from_pretrained(
33
+ # tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
34
+ # )
35
+ # return model, tokenizer
36
+
37
+
38
+ tokenizer = AutoTokenizer.from_pretrained(
39
+ MODEL_PATH,
40
+ trust_remote_code=True,
41
+ encode_special_tokens=True
42
+ )
43
+ model = AutoModel.from_pretrained(
44
+ MODEL_PATH,
45
+ trust_remote_code=True,
46
+ device_map="auto").eval()
47
+
48
+
49
+ class StopOnTokens(StoppingCriteria):
50
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
51
+ stop_ids = model.config.eos_token_id
52
+ for stop_id in stop_ids:
53
+ if input_ids[0][-1] == stop_id:
54
+ return True
55
+ return False
56
+
57
+
58
+ if __name__ == "__main__":
59
+ history = []
60
+ max_length = 8192
61
+ top_p = 0.8
62
+ temperature = 0.6
63
+ stop = StopOnTokens()
64
+
65
+ print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
66
+ while True:
67
+ user_input = input("\nYou: ")
68
+ if user_input.lower() in ["exit", "quit"]:
69
+ break
70
+ history.append([user_input, ""])
71
+
72
+ messages = []
73
+ for idx, (user_msg, model_msg) in enumerate(history):
74
+ if idx == len(history) - 1 and not model_msg:
75
+ messages.append({"role": "user", "content": user_msg})
76
+ break
77
+ if user_msg:
78
+ messages.append({"role": "user", "content": user_msg})
79
+ if model_msg:
80
+ messages.append({"role": "assistant", "content": model_msg})
81
+ model_inputs = tokenizer.apply_chat_template(
82
+ messages,
83
+ add_generation_prompt=True,
84
+ tokenize=True,
85
+ return_tensors="pt"
86
+ ).to(model.device)
87
+ streamer = TextIteratorStreamer(
88
+ tokenizer=tokenizer,
89
+ timeout=60,
90
+ skip_prompt=True,
91
+ skip_special_tokens=True
92
+ )
93
+ generate_kwargs = {
94
+ "input_ids": model_inputs,
95
+ "streamer": streamer,
96
+ "max_new_tokens": max_length,
97
+ "do_sample": True,
98
+ "top_p": top_p,
99
+ "temperature": temperature,
100
+ "stopping_criteria": StoppingCriteriaList([stop]),
101
+ "repetition_penalty": 1.2,
102
+ "eos_token_id": model.config.eos_token_id,
103
+ }
104
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
105
+ t.start()
106
+ print("GLM-4:", end="", flush=True)
107
+ for new_token in streamer:
108
+ if new_token:
109
+ print(new_token, end="", flush=True)
110
+ history[-1][1] += new_token
111
+
112
+ history[-1][1] = history[-1][1].strip()
trans_cli_vision_demo.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ This script creates a CLI demo with transformers backend for the glm-4v-9b model,
3
+ allowing users to interact with the model through a command-line interface.
4
+
5
+ Usage:
6
+ - Run the script to start the CLI demo.
7
+ - Interact with the model by typing questions and receiving responses.
8
+
9
+ Note: The script includes a modification to handle markdown to plain text conversion,
10
+ ensuring that the CLI interface displays formatted text correctly.
11
+ """
12
+
13
+ import os
14
+ import torch
15
+ from threading import Thread
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ StoppingCriteria,
19
+ StoppingCriteriaList,
20
+ TextIteratorStreamer, AutoModel, BitsAndBytesConfig
21
+ )
22
+
23
+ from PIL import Image
24
+
25
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(
28
+ MODEL_PATH,
29
+ trust_remote_code=True,
30
+ encode_special_tokens=True
31
+ )
32
+ model = AutoModel.from_pretrained(
33
+ MODEL_PATH,
34
+ trust_remote_code=True,
35
+ device_map="auto",
36
+ torch_dtype=torch.bfloat16
37
+ ).eval()
38
+
39
+ ## For INT4 inference
40
+ # model = AutoModel.from_pretrained(
41
+ # MODEL_PATH,
42
+ # trust_remote_code=True,
43
+ # quantization_config=BitsAndBytesConfig(load_in_4bit=True),
44
+ # torch_dtype=torch.bfloat16,
45
+ # low_cpu_mem_usage=True
46
+ # ).eval()
47
+
48
+ class StopOnTokens(StoppingCriteria):
49
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
50
+ stop_ids = model.config.eos_token_id
51
+ for stop_id in stop_ids:
52
+ if input_ids[0][-1] == stop_id:
53
+ return True
54
+ return False
55
+
56
+
57
+ if __name__ == "__main__":
58
+ history = []
59
+ max_length = 1024
60
+ top_p = 0.8
61
+ temperature = 0.6
62
+ stop = StopOnTokens()
63
+ uploaded = False
64
+ image = None
65
+ print("Welcome to the GLM-4V-9B CLI chat. Type your messages below.")
66
+ image_path = input("Image Path:")
67
+ try:
68
+ image = Image.open(image_path).convert("RGB")
69
+ except:
70
+ print("Invalid image path. Continuing with text conversation.")
71
+ while True:
72
+ user_input = input("\nYou: ")
73
+ if user_input.lower() in ["exit", "quit"]:
74
+ break
75
+ history.append([user_input, ""])
76
+
77
+ messages = []
78
+ for idx, (user_msg, model_msg) in enumerate(history):
79
+ if idx == len(history) - 1 and not model_msg:
80
+ messages.append({"role": "user", "content": user_msg})
81
+ if image and not uploaded:
82
+ messages[-1].update({"image": image})
83
+ uploaded = True
84
+ break
85
+ if user_msg:
86
+ messages.append({"role": "user", "content": user_msg})
87
+ if model_msg:
88
+ messages.append({"role": "assistant", "content": model_msg})
89
+ model_inputs = tokenizer.apply_chat_template(
90
+ messages,
91
+ add_generation_prompt=True,
92
+ tokenize=True,
93
+ return_tensors="pt",
94
+ return_dict=True
95
+ ).to(next(model.parameters()).device)
96
+ streamer = TextIteratorStreamer(
97
+ tokenizer=tokenizer,
98
+ timeout=60,
99
+ skip_prompt=True,
100
+ skip_special_tokens=True
101
+ )
102
+ generate_kwargs = {
103
+ **model_inputs,
104
+ "streamer": streamer,
105
+ "max_new_tokens": max_length,
106
+ "do_sample": True,
107
+ "top_p": top_p,
108
+ "temperature": temperature,
109
+ "stopping_criteria": StoppingCriteriaList([stop]),
110
+ "repetition_penalty": 1.2,
111
+ "eos_token_id": [151329, 151336, 151338],
112
+ }
113
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
114
+ t.start()
115
+ print("GLM-4:", end="", flush=True)
116
+ for new_token in streamer:
117
+ if new_token:
118
+ print(new_token, end="", flush=True)
119
+ history[-1][1] += new_token
120
+
121
+ history[-1][1] = history[-1][1].strip()
trans_stress_test.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import time
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
4
+ import torch
5
+ from threading import Thread
6
+
7
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
8
+
9
+
10
+ def stress_test(token_len, n, num_gpu):
11
+ device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
12
+ tokenizer = AutoTokenizer.from_pretrained(
13
+ MODEL_PATH,
14
+ trust_remote_code=True,
15
+ padding_side="left"
16
+ )
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ MODEL_PATH,
19
+ trust_remote_code=True,
20
+ torch_dtype=torch.bfloat16
21
+ ).to(device).eval()
22
+
23
+ # Use INT4 weight infer
24
+ # model = AutoModelForCausalLM.from_pretrained(
25
+ # MODEL_PATH,
26
+ # trust_remote_code=True,
27
+ # quantization_config=BitsAndBytesConfig(load_in_4bit=True),
28
+ # low_cpu_mem_usage=True,
29
+ # ).eval()
30
+
31
+ times = []
32
+ decode_times = []
33
+
34
+ print("Warming up...")
35
+ vocab_size = tokenizer.vocab_size
36
+ warmup_token_len = 20
37
+ random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
38
+ start_tokens = [151331, 151333, 151336, 198]
39
+ end_tokens = [151337]
40
+ input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(
41
+ device)
42
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
43
+ position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
44
+ warmup_inputs = {
45
+ 'input_ids': input_ids,
46
+ 'attention_mask': attention_mask,
47
+ 'position_ids': position_ids
48
+ }
49
+ with torch.no_grad():
50
+ _ = model.generate(
51
+ input_ids=warmup_inputs['input_ids'],
52
+ attention_mask=warmup_inputs['attention_mask'],
53
+ max_new_tokens=2048,
54
+ do_sample=False,
55
+ repetition_penalty=1.0,
56
+ eos_token_id=[151329, 151336, 151338]
57
+ )
58
+ print("Warming up complete. Starting stress test...")
59
+
60
+ for i in range(n):
61
+ random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
62
+ input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(
63
+ 0).to(device)
64
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
65
+ position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
66
+ test_inputs = {
67
+ 'input_ids': input_ids,
68
+ 'attention_mask': attention_mask,
69
+ 'position_ids': position_ids
70
+ }
71
+
72
+ streamer = TextIteratorStreamer(
73
+ tokenizer=tokenizer,
74
+ timeout=36000,
75
+ skip_prompt=True,
76
+ skip_special_tokens=True
77
+ )
78
+
79
+ generate_kwargs = {
80
+ "input_ids": test_inputs['input_ids'],
81
+ "attention_mask": test_inputs['attention_mask'],
82
+ "max_new_tokens": 512,
83
+ "do_sample": False,
84
+ "repetition_penalty": 1.0,
85
+ "eos_token_id": [151329, 151336, 151338],
86
+ "streamer": streamer
87
+ }
88
+
89
+ start_time = time.time()
90
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
91
+ t.start()
92
+
93
+ first_token_time = None
94
+ all_token_times = []
95
+
96
+ for token in streamer:
97
+ current_time = time.time()
98
+ if first_token_time is None:
99
+ first_token_time = current_time
100
+ times.append(first_token_time - start_time)
101
+ all_token_times.append(current_time)
102
+
103
+ t.join()
104
+ end_time = time.time()
105
+
106
+ avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
107
+ decode_times.append(avg_decode_time_per_token)
108
+ print(
109
+ f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Time: {avg_decode_time_per_token:.4f} tokens/second")
110
+
111
+ torch.cuda.empty_cache()
112
+
113
+ avg_first_token_time = sum(times) / n
114
+ avg_decode_time = sum(decode_times) / n
115
+ print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
116
+ print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second")
117
+ return times, avg_first_token_time, decode_times, avg_decode_time
118
+
119
+
120
+ def main():
121
+ parser = argparse.ArgumentParser(description="Stress test for model inference")
122
+ parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens for each test')
123
+ parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
124
+ parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPUs to use for inference')
125
+ args = parser.parse_args()
126
+
127
+ token_len = args.token_len
128
+ n = args.n
129
+ num_gpu = args.num_gpu
130
+
131
+ stress_test(token_len, n, num_gpu)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ main()
trans_web_demo.py ADDED
@@ -0,0 +1,167 @@
1
+ """
2
+ This script creates an interactive web demo for the GLM-4-9B model using Gradio,
3
+ a Python library for building quick and easy UI components for machine learning models.
4
+ It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
5
+ allowing users to interact with the model through a chat-like interface.
6
+ """
7
+
8
+ import os
9
+ import gradio as gr
10
+ import torch
11
+ from threading import Thread
12
+
13
+ from typing import Union
14
+ from pathlib import Path
15
+ from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
16
+ from transformers import (
17
+ AutoModelForCausalLM,
18
+ AutoTokenizer,
19
+ PreTrainedModel,
20
+ PreTrainedTokenizer,
21
+ PreTrainedTokenizerFast,
22
+ StoppingCriteria,
23
+ StoppingCriteriaList,
24
+ TextIteratorStreamer
25
+ )
26
+
27
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
28
+
29
+ ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
30
+ TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
31
+
32
+ MODEL_PATH = os.environ.get('MODEL_PATH', '../models/glm-4-9b-chat')
33
+ TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
34
+
35
+
36
+ def _resolve_path(path: Union[str, Path]) -> Path:
37
+ return Path(path).expanduser().resolve()
38
+
39
+
40
+ def load_model_and_tokenizer(
41
+ model_dir: Union[str, Path], trust_remote_code: bool = True
42
+ ) -> tuple[ModelType, TokenizerType]:
43
+ model_dir = _resolve_path(model_dir)
44
+ if (model_dir / 'adapter_config.json').exists():
45
+ model = AutoPeftModelForCausalLM.from_pretrained(
46
+ model_dir, trust_remote_code=trust_remote_code, device_map='auto'
47
+ )
48
+ tokenizer_dir = model.peft_config['default'].base_model_name_or_path
49
+ else:
50
+ model = AutoModelForCausalLM.from_pretrained(
51
+ model_dir, trust_remote_code=trust_remote_code, device_map='auto'
52
+ ).to(DEVICE).eval()
53
+ tokenizer_dir = model_dir
54
+ tokenizer = AutoTokenizer.from_pretrained(
55
+ tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
56
+ )
57
+ return model, tokenizer
58
+
59
+
60
+ model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
61
+
62
+
63
+ class StopOnTokens(StoppingCriteria):
64
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
65
+ stop_ids = model.config.eos_token_id
66
+ for stop_id in stop_ids:
67
+ if input_ids[0][-1] == stop_id:
68
+ return True
69
+ return False
70
+
71
+
72
+ def parse_text(text):
73
+ lines = text.split("\n")
74
+ lines = [line for line in lines if line != ""]
75
+ count = 0
76
+ for i, line in enumerate(lines):
77
+ if "```" in line:
78
+ count += 1
79
+ items = line.split('`')
80
+ if count % 2 == 1:
81
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
82
+ else:
83
+ lines[i] = f'<br></code></pre>'
84
+ else:
85
+ if i > 0:
86
+ if count % 2 == 1:
87
+ line = line.replace("`", "\`")
88
+ line = line.replace("<", "&lt;")
89
+ line = line.replace(">", "&gt;")
90
+ line = line.replace(" ", "&nbsp;")
91
+ line = line.replace("*", "&ast;")
92
+ line = line.replace("_", "&lowbar;")
93
+ line = line.replace("-", "&#45;")
94
+ line = line.replace(".", "&#46;")
95
+ line = line.replace("!", "&#33;")
96
+ line = line.replace("(", "&#40;")
97
+ line = line.replace(")", "&#41;")
98
+ line = line.replace("$", "&#36;")
99
+ lines[i] = "<br>" + line
100
+ text = "".join(lines)
101
+ return text
102
+
103
+
104
+ def predict(history, max_length, top_p, temperature):
105
+ stop = StopOnTokens()
106
+ messages = []
107
+ for idx, (user_msg, model_msg) in enumerate(history):
108
+ if idx == len(history) - 1 and not model_msg:
109
+ messages.append({"role": "user", "content": user_msg})
110
+ break
111
+ if user_msg:
112
+ messages.append({"role": "user", "content": user_msg})
113
+ if model_msg:
114
+ messages.append({"role": "assistant", "content": model_msg})
115
+
116
+ model_inputs = tokenizer.apply_chat_template(messages,
117
+ add_generation_prompt=True,
118
+ tokenize=True,
119
+ return_tensors="pt").to(next(model.parameters()).device)
120
+ streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
121
+ generate_kwargs = {
122
+ "input_ids": model_inputs,
123
+ "streamer": streamer,
124
+ "max_new_tokens": max_length,
125
+ "do_sample": True,
126
+ "top_p": top_p,
127
+ "temperature": temperature,
128
+ "stopping_criteria": StoppingCriteriaList([stop]),
129
+ "repetition_penalty": 1.2,
130
+ "eos_token_id": model.config.eos_token_id,
131
+ }
132
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
133
+ t.start()
134
+ for new_token in streamer:
135
+ if new_token:
136
+ history[-1][1] += new_token
137
+ yield history
138
+
139
+
140
+ with gr.Blocks() as demo:
141
+ gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
142
+ chatbot = gr.Chatbot()
143
+
144
+ with gr.Row():
145
+ with gr.Column(scale=4):
146
+ with gr.Column(scale=12):
147
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
148
+ with gr.Column(min_width=32, scale=1):
149
+ submitBtn = gr.Button("Submit")
150
+ with gr.Column(scale=1):
151
+ emptyBtn = gr.Button("Clear History")
152
+ max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
153
+ top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
154
+ temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
155
+
156
+
157
+ def user(query, history):
158
+ return "", history + [[parse_text(query), ""]]
159
+
160
+
161
+ submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
162
+ predict, [chatbot, max_length, top_p, temperature], chatbot
163
+ )
164
+ emptyBtn.click(lambda: None, None, chatbot, queue=False)
165
+
166
+ demo.queue()
167
+ demo.launch(server_name="0.0.0.0", server_port=8501, inbrowser=False, share=True)
vllm_cli_demo.py ADDED
@@ -0,0 +1,111 @@
1
+ """
2
+ This script creates a CLI demo with the vLLM backend for the glm-4-9b model,
3
+ allowing users to interact with the model through a command-line interface.
4
+
5
+ Usage:
6
+ - Run the script to start the CLI demo.
7
+ - Interact with the model by typing questions and receiving responses.
8
+
9
+ Note: The script includes a modification to handle markdown to plain text conversion,
10
+ ensuring that the CLI interface displays formatted text correctly.
11
+ """
12
+ import time
13
+ import asyncio
14
+ from transformers import AutoTokenizer
15
+ from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
16
+ from typing import List, Dict
17
+
18
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
19
+
20
+
21
+ def load_model_and_tokenizer(model_dir: str):
22
+ engine_args = AsyncEngineArgs(
23
+ model=model_dir,
24
+ tokenizer=model_dir,
25
+ tensor_parallel_size=1,
26
+ dtype="bfloat16",
27
+ trust_remote_code=True,
28
+ gpu_memory_utilization=0.3,
29
+ enforce_eager=True,
30
+ worker_use_ray=True,
31
+ engine_use_ray=False,
32
+ disable_log_requests=True
33
+ # If you run into OOM, consider enabling the parameters below
34
+ # enable_chunked_prefill=True,
35
+ # max_num_batched_tokens=8192
36
+ )
37
+ tokenizer = AutoTokenizer.from_pretrained(
38
+ model_dir,
39
+ trust_remote_code=True,
40
+ encode_special_tokens=True
41
+ )
42
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
43
+ return engine, tokenizer
44
+
45
+
46
+ engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
47
+
48
+
49
+ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
50
+ inputs = tokenizer.apply_chat_template(
51
+ messages,
52
+ add_generation_prompt=True,
53
+ tokenize=False
54
+ )
55
+ params_dict = {
56
+ "n": 1,
57
+ "best_of": 1,
58
+ "presence_penalty": 1.0,
59
+ "frequency_penalty": 0.0,
60
+ "temperature": temperature,
61
+ "top_p": top_p,
62
+ "top_k": -1,
63
+ "use_beam_search": False,
64
+ "length_penalty": 1,
65
+ "early_stopping": False,
66
+ "stop_token_ids": [151329, 151336, 151338],
67
+ "ignore_eos": False,
68
+ "max_tokens": max_dec_len,
69
+ "logprobs": None,
70
+ "prompt_logprobs": None,
71
+ "skip_special_tokens": True,
72
+ }
73
+ sampling_params = SamplingParams(**params_dict)
74
+ async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
75
+ yield output.outputs[0].text
76
+
77
+
78
+ async def chat():
79
+ history = []
80
+ max_length = 8192
81
+ top_p = 0.8
82
+ temperature = 0.6
83
+
84
+ print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
85
+ while True:
86
+ user_input = input("\nYou: ")
87
+ if user_input.lower() in ["exit", "quit"]:
88
+ break
89
+ history.append([user_input, ""])
90
+
91
+ messages = []
92
+ for idx, (user_msg, model_msg) in enumerate(history):
93
+ if idx == len(history) - 1 and not model_msg:
94
+ messages.append({"role": "user", "content": user_msg})
95
+ break
96
+ if user_msg:
97
+ messages.append({"role": "user", "content": user_msg})
98
+ if model_msg:
99
+ messages.append({"role": "assistant", "content": model_msg})
100
+
101
+ print("\nGLM-4: ", end="")
102
+ current_length = 0
103
+ output = ""
104
+ async for output in vllm_gen(messages, top_p, temperature, max_length):
105
+ print(output[current_length:], end="", flush=True)
106
+ current_length = len(output)
107
+ history[-1][1] = output
108
+
109
+
110
+ if __name__ == "__main__":
111
+ asyncio.run(chat())