unsubscribe committed on
Commit
80dd0d6
1 Parent(s): adde727

try space for zeroGPU
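The change inlines lmdeploy's `turbomind_coupled` Gradio demo into `app.py` so that the GPU-touching callbacks can carry ZeroGPU's `@spaces.GPU` decorator. As background, a minimal sketch of that pattern, assuming only the `spaces` package available on Hugging Face ZeroGPU hardware (the `generate` function below is illustrative, not part of this Space):

```python
# Minimal ZeroGPU sketch: `import spaces` should precede any CUDA
# initialization, and `@spaces.GPU` attaches a GPU only while the
# decorated function runs.
import spaces
import torch

@spaces.GPU
def generate(prompt: str) -> str:
    # Hypothetical stand-in for real inference: on ZeroGPU hardware this
    # runs with a GPU attached; elsewhere the decorator is a no-op.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return f'[{device}] {prompt}'
```

The commit applies the same decorator to async generator callbacks such as `chat_stream_local`; how well a given `spaces` release handles generator and async functions varies, which is presumably why the commit message says "try".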

Files changed (1)
  1. app.py +230 -65
app.py CHANGED
@@ -1,71 +1,236 @@
- from lmdeploy.serve.gradio.turbomind_coupled import *
- from lmdeploy.messages import TurbomindEngineConfig
-
- backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05, model_format='awq')
- model_path = 'internlm/internlm2-chat-20b-4bits'
-
- InterFace.async_engine = AsyncEngine(
-     model_path=model_path,
-     backend='turbomind',
-     backend_config=backend_config,
-     tp=1)
-
- with gr.Blocks(css=CSS, theme=THEME) as demo:
-     state_chatbot = gr.State([])
-     state_session_id = gr.State(0)
-
-     with gr.Column(elem_id='container'):
-         gr.Markdown('## LMDeploy Playground')
-
-         chatbot = gr.Chatbot(
-             elem_id='chatbot',
-             label=InterFace.async_engine.engine.model_name)
-         instruction_txtbox = gr.Textbox(
-             placeholder='Please input the instruction',
-             label='Instruction')
-         with gr.Row():
-             cancel_btn = gr.Button(value='Cancel', interactive=False)
-             reset_btn = gr.Button(value='Reset')
-         with gr.Row():
-             request_output_len = gr.Slider(1,
-                                            2048,
-                                            value=512,
-                                            step=1,
-                                            label='Maximum new tokens')
-             top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
-             temperature = gr.Slider(0.01,
-                                     1.5,
-                                     value=0.7,
-                                     step=0.01,
-                                     label='Temperature')
-
-     send_event = instruction_txtbox.submit(chat_stream_local, [
-         instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
-         state_session_id, top_p, temperature, request_output_len
-     ], [state_chatbot, chatbot, cancel_btn, reset_btn])
-     instruction_txtbox.submit(
-         lambda: gr.Textbox.update(value=''),
-         [],
-         [instruction_txtbox],
-     )
-     cancel_btn.click(
-         cancel_local_func,
-         [state_chatbot, cancel_btn, reset_btn, state_session_id],
-         [state_chatbot, cancel_btn, reset_btn],
-         cancels=[send_event])
-
-     reset_btn.click(reset_local_func,
-                     [instruction_txtbox, state_chatbot, state_session_id],
-                     [state_chatbot, chatbot, instruction_txtbox],
-                     cancels=[send_event])
-
-     def init():
-         with InterFace.lock:
-             InterFace.global_session_id += 1
-             new_session_id = InterFace.global_session_id
-         return new_session_id
-
-     demo.load(init, inputs=None, outputs=[state_session_id])
-
- demo.queue(concurrency_count=InterFace.async_engine.instance_num,
-            max_size=100).launch()
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import random
+ import spaces
+ from threading import Lock
+ from typing import Literal, Optional, Sequence, Union
+
+ import gradio as gr
+
+ from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
+                                TurbomindEngineConfig)
+ from lmdeploy.model import ChatTemplateConfig
+ from lmdeploy.serve.async_engine import AsyncEngine
+ from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
+
+
+ class InterFace:
+     async_engine: AsyncEngine = None
+     global_session_id: int = 0
+     lock = Lock()
+
+ @spaces.GPU
+ async def chat_stream_local(instruction: str, state_chatbot: Sequence,
+                             cancel_btn: gr.Button, reset_btn: gr.Button,
+                             session_id: int, top_p: float, temperature: float,
+                             request_output_len: int):
+     """Chat with the AI assistant.
+
+     Args:
+         instruction (str): user's prompt
+         state_chatbot (Sequence): the chatting history
+         cancel_btn (gr.Button): the cancel button
+         reset_btn (gr.Button): the reset button
+         session_id (int): the session id
+     """
+     state_chatbot = state_chatbot + [(instruction, None)]
+
+     yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+     gen_config = GenerationConfig(max_new_tokens=request_output_len,
+                                   top_p=top_p,
+                                   top_k=40,
+                                   temperature=temperature,
+                                   random_seed=random.getrandbits(64)
+                                   if len(state_chatbot) == 1 else None)
+
+     async for outputs in InterFace.async_engine.generate(
+             instruction,
+             session_id,
+             gen_config=gen_config,
+             stream_response=True,
+             sequence_start=(len(state_chatbot) == 1),
+             sequence_end=False):
+         response = outputs.response
+         if outputs.finish_reason == 'length':
+             gr.Warning('WARNING: exceeded the session max length.'
+                        ' Please restart the session with the Reset button.')
+         if outputs.generate_token_len < 0:
+             gr.Warning('WARNING: running on an old session.'
+                        ' Please restart the session with the Reset button.')
+         if state_chatbot[-1][-1] is None:
+             state_chatbot[-1] = (state_chatbot[-1][0], response)
+         else:
+             state_chatbot[-1] = (state_chatbot[-1][0],
+                                  state_chatbot[-1][1] + response
+                                  )  # appended piece by piece
+         yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
+
+     yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+
+ @spaces.GPU
+ async def reset_local_func(instruction_txtbox: gr.Textbox,
+                            state_chatbot: Sequence, session_id: int):
+     """Reset the session.
+
+     Args:
+         instruction_txtbox (str): user's prompt
+         state_chatbot (Sequence): the chatting history
+         session_id (int): the session id
+     """
+     state_chatbot = []
+     # end the session
+     with InterFace.lock:
+         InterFace.global_session_id += 1
+         session_id = InterFace.global_session_id
+     return (state_chatbot, state_chatbot, gr.Textbox.update(value=''),
+             session_id)
+
+ @spaces.GPU
+ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
+                             reset_btn: gr.Button, session_id: int):
+     """Stop the session.
+
+     Args:
+         state_chatbot (Sequence): the chatting history
+         cancel_btn (gr.Button): the cancel button
+         reset_btn (gr.Button): the reset button
+         session_id (int): the session id
+     """
+     yield (state_chatbot, disable_btn, disable_btn, session_id)
+     InterFace.async_engine.stop_session(session_id)
+     # the pytorch backend does not support resuming chat history yet
+     if InterFace.async_engine.backend == 'pytorch':
+         yield (state_chatbot, disable_btn, enable_btn, session_id)
+     else:
+         with InterFace.lock:
+             InterFace.global_session_id += 1
+             session_id = InterFace.global_session_id
+         messages = []
+         for qa in state_chatbot:
+             messages.append(dict(role='user', content=qa[0]))
+             if qa[1] is not None:
+                 messages.append(dict(role='assistant', content=qa[1]))
+         gen_config = GenerationConfig(max_new_tokens=0)
+         async for out in InterFace.async_engine.generate(messages,
+                                                          session_id,
+                                                          gen_config=gen_config,
+                                                          stream_response=True,
+                                                          sequence_start=True,
+                                                          sequence_end=False):
+             pass
+         yield (state_chatbot, disable_btn, enable_btn, session_id)
+
+ @spaces.GPU
+ def run_local(model_path: str,
+               model_name: Optional[str] = None,
+               backend: Literal['turbomind', 'pytorch'] = 'turbomind',
+               backend_config: Optional[Union[PytorchEngineConfig,
+                                              TurbomindEngineConfig]] = None,
+               chat_template_config: Optional[ChatTemplateConfig] = None,
+               server_name: str = 'localhost',
+               server_port: int = 6006,
+               tp: int = 1,
+               **kwargs):
+     """Chat with the AI assistant through a web UI.
+
+     Args:
+         model_path (str): the path of a model. It could be one of the
+             following options:
+                 - i) A local directory path of a turbomind model which is
+                     converted by the `lmdeploy convert` command or
+                     downloaded from ii) and iii).
+                 - ii) The model_id of an lmdeploy-quantized model hosted
+                     inside a model repo on huggingface.co, such as
+                     "InternLM/internlm-chat-20b-4bit",
+                     "lmdeploy/llama2-chat-70b-4bit", etc.
+                 - iii) The model_id of a model hosted inside a model repo
+                     on huggingface.co, such as "internlm/internlm-chat-7b",
+                     "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
+                     and so on.
+         model_name (str): needed when model_path is a pytorch model on
+             huggingface.co, such as "internlm/internlm-chat-7b",
+             "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" and so on.
+         backend (str): either `turbomind` or `pytorch` backend. Defaults to
+             the `turbomind` backend.
+         backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
+             config instance. Defaults to None.
+         chat_template_config (ChatTemplateConfig): chat template
+             configuration. Defaults to None.
+         server_name (str): the ip address of the gradio server
+         server_port (int): the port of the gradio server
+         tp (int): tensor parallelism for TurboMind
+     """
+     InterFace.async_engine = AsyncEngine(
+         model_path=model_path,
+         backend=backend,
+         backend_config=backend_config,
+         chat_template_config=chat_template_config,
+         model_name=model_name,
+         tp=tp,
+         **kwargs)
+
+     with gr.Blocks(css=CSS, theme=THEME) as demo:
+         state_chatbot = gr.State([])
+         state_session_id = gr.State(0)
+
+         with gr.Column(elem_id='container'):
+             gr.Markdown('## LMDeploy Playground')
+
+             chatbot = gr.Chatbot(
+                 elem_id='chatbot',
+                 label=InterFace.async_engine.engine.model_name)
+             instruction_txtbox = gr.Textbox(
+                 placeholder='Please input the instruction',
+                 label='Instruction')
+             with gr.Row():
+                 cancel_btn = gr.Button(value='Cancel', interactive=False)
+                 reset_btn = gr.Button(value='Reset')
+             with gr.Row():
+                 request_output_len = gr.Slider(1,
+                                                2048,
+                                                value=512,
+                                                step=1,
+                                                label='Maximum new tokens')
+                 top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
+                 temperature = gr.Slider(0.01,
+                                         1.5,
+                                         value=0.7,
+                                         step=0.01,
+                                         label='Temperature')
+
+         send_event = instruction_txtbox.submit(chat_stream_local, [
+             instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
+             state_session_id, top_p, temperature, request_output_len
+         ], [state_chatbot, chatbot, cancel_btn, reset_btn])
+         instruction_txtbox.submit(
+             lambda: gr.Textbox.update(value=''),
+             [],
+             [instruction_txtbox],
+         )
+         cancel_btn.click(
+             cancel_local_func,
+             [state_chatbot, cancel_btn, reset_btn, state_session_id],
+             [state_chatbot, cancel_btn, reset_btn, state_session_id],
+             cancels=[send_event])
+
+         reset_btn.click(reset_local_func,
+                         [instruction_txtbox, state_chatbot, state_session_id],
+                         [state_chatbot, chatbot, instruction_txtbox,
+                          state_session_id],
+                         cancels=[send_event])
+
+         def init():
+             with InterFace.lock:
+                 InterFace.global_session_id += 1
+                 new_session_id = InterFace.global_session_id
+             return new_session_id
+
+         demo.load(init, inputs=None, outputs=[state_session_id])
+
+     demo.queue(concurrency_count=InterFace.async_engine.instance_num,
+                max_size=100).launch()
+
+
+ backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05, model_format='awq')
+ model_path = 'internlm/internlm2-chat-20b-4bits'
+
+ run_local(model_path, backend_config=backend_config)
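
The trailing configuration is sized for the small, time-sliced GPU a ZeroGPU Space provides: `max_batch_size=1` serves one request at a time, `cache_max_entry_count=0.05` caps the k/v cache at a small fraction of GPU memory, and `model_format='awq'` loads the 4-bit AWQ checkpoint. For contrast, a hedged sketch of how the same entry point might be configured on a dedicated GPU (values are illustrative, not tested recommendations):

```python
# Illustrative TurboMind settings for a dedicated GPU: larger batch and
# k/v-cache budgets trade memory headroom for throughput.
from lmdeploy.messages import TurbomindEngineConfig

dedicated_config = TurbomindEngineConfig(
    max_batch_size=8,           # serve several sessions concurrently
    cache_max_entry_count=0.5,  # let the k/v cache use ~half of GPU memory
    model_format='awq')         # same 4-bit AWQ weights

# `run_local` is the entry point defined in app.py above.
run_local('internlm/internlm2-chat-20b-4bits',
          backend_config=dedicated_config)
```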