ffreemt committed on
Commit 4bef18d
1 Parent(s): 103cf8f
Files changed (5)
  1. .gitignore +1 -0
  2. .ruff.toml +17 -0
  3. app.py +329 -31
  4. example_list.py +56 -0
  5. requirements.txt +1 -1
.gitignore ADDED
@@ -0,0 +1 @@
+ .ruff_cache
.ruff.toml ADDED
@@ -0,0 +1,17 @@
+ # Assume Python 3.10.
+ target-version = "py310"
+ # Increase the maximum line length to 300 characters.
+ line-length = 300
+
+ # pyflakes, pycodestyle, isort
+ # flake8 YTT, pydocstyle D, pylint PLC
+ select = ["F", "E", "W", "I001", "YTT", "D", "PLC"]
+ # select = ["ALL"]
+
+ # D103 Missing docstring in public function
+ # D101 Missing docstring in public class
+ # `multi-line-summary-first-line` (D212)
+ # `one-blank-line-before-class` (D203)
+ extend-ignore = ["D103", "D101", "D212", "D203"]
+
+ exclude = [".venv"]
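Note: with this config, a snippet like the one below would be flagged for the unused import (pyflakes F401 falls under the selected "F" rules) but not for the missing function docstring (D103 is in extend-ignore). The behaviour here is inferred from the selected/ignored codes above, not verified against an actual ruff run.

    """Module docstring (pydocstyle "D" is selected, so omitting this would trigger D100)."""
    import os  # flagged: F401, unused import ("F" is selected)


    def add(a, b):  # not flagged: D103 (missing docstring) is in extend-ignore
        return a + b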
app.py CHANGED
@@ -1,14 +1,30 @@
- """Run qwen 7b.

  transformers 4.31.0
  """
  import os
  import time
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
- from transformers.generation import GenerationConfig
- from transformers import BitsAndBytesConfig
  from loguru import logger

  os.environ["TZ"] = "Asia/Shanghai"
  try:
@@ -17,41 +33,323 @@ except Exception:
      # Windows
      logger.warning("Windows, cant run time.tzset()")

- device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
- # has_cuda = False  # force cpu
-
  model_name = "Qwen/Qwen-7B-Chat"
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

- # quantization configuration for NF4 (4 bits)
- quantization_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_quant_type='nf4',
-     bnb_4bit_compute_dtype=torch.bfloat16
- )

- # quantization configuration for Int8 (8 bits)
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)

- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     device_map=device_map,
-     quantization_config=quantization_config,
-     # max_memory=max_memory,
-     trust_remote_code=True,
- ).eval()

- # model = model.eval()

- # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()

- # Runs
- # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()

- # Different generation lengths, top_p and other hyperparameters can be specified here
- model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

- # response, history = model.chat(tokenizer, "你好", history=None)
- response, history = model.chat(tokenizer, "你好", history=[])
- print(response)
+ """
+ Run qwen 7b chat.

  transformers 4.31.0
+
+     import torch
+     torch.cuda.empty_cache()
+
  """
+ # pylint: disable=line-too-long, invalid-name, no-member, redefined-outer-name, missing-function-docstring, missing-class-docstring, broad-except,
+ import gc
  import os
  import time
+ from collections import deque
+ from dataclasses import asdict, dataclass
+ from types import SimpleNamespace
+
+ import gradio as gr
  import torch
  from loguru import logger
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+
+ from example_list import css, example_list
+
+ if not torch.cuda.is_available():
+     raise gr.Error("No cuda, cant continue...")

  os.environ["TZ"] = "Asia/Shanghai"
  try:
      time.tzset()
  except Exception:
      # Windows
      logger.warning("Windows, cant run time.tzset()")

  model_name = "Qwen/Qwen-7B-Chat"
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

+ n_gpus = torch.cuda.device_count()
+ try:
+     _ = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB"
+ except AssertionError:
+     _ = 0
+ max_memory = {i: _ for i in range(n_gpus)}
+
+
+ def gen_model(model_name: str):
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         device_map="auto",
+         load_in_4bit=True,
+         max_memory=max_memory,
+         fp16=True,
+         torch_dtype=torch.float16,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+     model = model.eval()
+     model.generation_config = GenerationConfig.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+     )
+     return model
+
+
+ def user_sub(message, chat_history):
+     """Append message to chat_history and clear the user textbox."""
+     logger.debug(f"{message=}")
+
+     # logger.remove()  # to turn on trace
+     # logger.add(sys.stderr, level="INFO")
+     logger.trace(f"{chat_history=}")
+
+     try:
+         chat_history.append([message, ""])
+     except Exception:
+         # start a fresh, bounded history if chat_history is not appendable
+         chat_history = deque([[message, ""]], maxlen=5)
+     return "", chat_history
+
+
+ def user(message, chat_history):
+     """Append message to chat_history, keep the textbox content."""
+     logger.debug(f"{message=}")
+     logger.trace(f"{chat_history=}")
+
+     try:
+         chat_history.append([message, ""])
+     except Exception:
+         chat_history = deque([[message, ""]], maxlen=5)
+     return message, chat_history
+
+
+ # for rerun in tests
+ model = None
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ model = gen_model(model_name)
+
+
+ def bot(chat_history, **kwargs):
+     try:
+         message = chat_history[-1][0]
+     except Exception as exc:
+         logger.error(f"{chat_history=}: {exc}")
+         return chat_history
+     logger.debug(f"{chat_history=}")
+     try:
+         _ = """
+         response, chat_history = model.chat(
+             tokenizer,
+             message,
+             history=chat_history,
+             temperature=0.7,
+             repetition_penalty=1.2,
+             # max_length=128,
+         )
+         """
+         logger.debug("run model.chat...")
+         response, chat_history = model.chat(
+             tokenizer,
+             message,
+             chat_history[:-1],
+             **kwargs,
+         )
+         del response
+         return chat_history
+     except Exception as exc:
+         logger.error(exc)
+         # record the error as the bot reply for the current turn
+         chat_history[-1] = [message, str(exc)]
+         return chat_history
+
+
+ def bot_stream(chat_history, **kwargs):
+     try:
+         message = chat_history[-1][0]
+     except Exception as exc:
+         logger.error(f"{chat_history=}: {exc}")
+         raise gr.Error(f"{chat_history=}")
+         # yield chat_history
+     # generation kwargs (temperature, top_p, ...) from the sliders; assumed to be
+     # accepted by model.chat_stream, just as bot() forwards them to model.chat
+     for elm in model.chat_stream(tokenizer, message, chat_history, **kwargs):
+         chat_history[-1] = [message, elm]
+         yield chat_history
+
+
+ SYSTEM_PROMPT = "You are a helpful assistant."
+ MAX_MAX_NEW_TOKENS = 1024
+ MAX_NEW_TOKENS = 128
+
+
+ @dataclass
+ class Config:
+     max_new_tokens: int = 64
+     repetition_penalty: float = 1.1
+     temperature: float = 1.0
+     top_k: int = 0
+     top_p: float = 0.9
+
+
+ stats_default = SimpleNamespace(llm=None, system_prompt=SYSTEM_PROMPT, config=Config())
+
+ theme = gr.themes.Soft(text_size="sm")
+ with gr.Blocks(
+     theme=theme,
+     title=model_name.lower(),
+     css=css,
+ ) as block:
+     stats = gr.State(stats_default)
+
+     def bot_stream_state(chat_history):
+         config = asdict(stats.value.config)
+         # yield from keeps this a generator function so gradio streams the updates
+         yield from bot_stream(chat_history, **config)
+
+     with gr.Accordion("🎈 Info", open=False):
+         gr.Markdown(
+             f"""<h5><center>{model_name.lower()}</center></h5>
+ Set `repetition_penalty` to 2.1 or higher for a chatty conversation. Lower it to 1.1 or smaller if more focused answers are desired (for example for translations or fact-oriented queries). A smaller `top_k` will probably result in smoother sentences.
+ Consult the `transformers` documentation for more details.
+
+ Most examples are meant for another model.
+ You probably should try
+ some related prompts.""",
+             elem_classes="xsmall",
+         )
+
+     chatbot = gr.Chatbot(height=500, value=deque([], maxlen=5))  # type: ignore
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
+                 show_label=False,
+                 # container=False,
+                 lines=4,
+                 max_lines=30,
+                 show_copy_button=True,
+                 # ).style(container=False)
+             )
+         with gr.Column(scale=1, min_width=50):
+             with gr.Row():
+                 submit = gr.Button("Submit", elem_classes="xsmall")
+                 stop = gr.Button("Stop", visible=True)
+                 clear = gr.Button("Clear History", visible=True)
+
+     msg_submit_event = msg.submit(
+         # fn=conversation.user_turn,
+         fn=user_sub,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot_stream_state, chatbot, chatbot, queue=True)
+     submit_click_event = submit.click(
+         # fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
+         fn=user,  # keeps msg in the textbox
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot_stream_state, chatbot, chatbot, queue=True)
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[msg_submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     with gr.Accordion(label="Advanced Options", open=False):
+         system_prompt = gr.Textbox(
+             label="System prompt",
+             value=stats_default.system_prompt,
+             lines=3,
+             visible=True,
+         )
+         max_new_tokens = gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=stats_default.config.max_new_tokens,
+         )
+         repetition_penalty = gr.Slider(
+             label="Repetition penalty",
+             minimum=0.1,
+             maximum=40.0,
+             step=0.1,
+             value=stats_default.config.repetition_penalty,
+         )
+         temperature = gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=40.0,
+             step=0.1,
+             value=stats_default.config.temperature,
+         )
+         top_p = gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=stats_default.config.top_p,
+         )
+         top_k = gr.Slider(
+             label="Top-k",
+             minimum=0,
+             maximum=1000,
+             step=1,
+             value=stats_default.config.top_k,
+         )
+
+         def system_prompt_fn(system_prompt):
+             stats.value.system_prompt = system_prompt
+             logger.debug(f"{stats.value.system_prompt=}")
+
+         def max_new_tokens_fn(max_new_tokens):
+             stats.value.config.max_new_tokens = max_new_tokens
+             logger.debug(f"{stats.value.config.max_new_tokens=}")
+
+         def repetition_penalty_fn(repetition_penalty):
+             stats.value.config.repetition_penalty = repetition_penalty
+             logger.debug(f"{stats.value=}")
+
+         def temperature_fn(temperature):
+             stats.value.config.temperature = temperature
+             logger.debug(f"{stats.value=}")
+
+         def top_p_fn(top_p):
+             stats.value.config.top_p = top_p
+             logger.debug(f"{stats.value=}")
+
+         def top_k_fn(top_k):
+             stats.value.config.top_k = top_k
+             logger.debug(f"{stats.value=}")
+
+         system_prompt.change(system_prompt_fn, system_prompt)
+         max_new_tokens.change(max_new_tokens_fn, max_new_tokens)
+         repetition_penalty.change(repetition_penalty_fn, repetition_penalty)
+         temperature.change(temperature_fn, temperature)
+         top_p.change(top_p_fn, top_p)
+         top_k.change(top_k_fn, top_k)
+
+         def reset_fn(stats_):
+             logger.debug("reset_fn")
+             stats_ = gr.State(stats_default)
+             logger.debug(f"{stats_.value=}")
+             return (
+                 stats_,
+                 stats_default.system_prompt,
+                 stats_default.config.max_new_tokens,
+                 stats_default.config.repetition_penalty,
+                 stats_default.config.temperature,
+                 stats_default.config.top_p,
+                 stats_default.config.top_k,
+             )
+
+         reset_btn = gr.Button("Reset")
+         reset_btn.click(
+             reset_fn,
+             stats,
+             [
+                 stats,
+                 system_prompt,
+                 max_new_tokens,
+                 repetition_penalty,
+                 temperature,
+                 top_p,
+                 top_k,
+             ],
+         )
+
+     with gr.Accordion("Example inputs", open=True):
+         etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+         examples = gr.Examples(
+             examples=example_list,
+             inputs=[msg],
+             examples_per_page=60,
+         )
+     with gr.Accordion("Disclaimer", open=False):
+         _ = model_name.lower()
+         gr.Markdown(
+             f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
+             f"factually accurate information. {_} was trained on various public datasets; while great efforts "
+             "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+             "biased, or otherwise offensive outputs.",
+             elem_classes=["disclaimer"],
+         )

+ if __name__ == "__main__":
+     block.queue(max_size=8).launch(debug=True)
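For reference, a minimal standalone sketch of the generation-parameter flow the UI above implements: the Config fields collected from the sliders are forwarded to Qwen's remote-code `model.chat` as keyword arguments, the same way `bot()` does. That the remote code accepts these kwargs is assumed here (as app.py itself assumes), not verified.

    # Minimal sketch, not part of the commit: load the model roughly as gen_model()
    # does and forward the advanced-option values to model.chat, mirroring bot().
    from dataclasses import asdict, dataclass

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig


    @dataclass
    class Config:  # same fields as app.py's Config
        max_new_tokens: int = 64
        repetition_penalty: float = 1.1
        temperature: float = 1.0
        top_k: int = 0
        top_p: float = 0.9


    model_name = "Qwen/Qwen-7B-Chat"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True, load_in_4bit=True
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)

    # forward the Config fields as generation kwargs, as bot()/bot_stream_state() do
    response, history = model.chat(
        tokenizer,
        "Explain the plot of Cinderella in a sentence.",
        [],
        **asdict(Config(temperature=0.7, repetition_penalty=1.1)),
    )
    print(response)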
example_list.py ADDED
@@ -0,0 +1,56 @@
+ """Define example_list and css for the app."""
+ # pylint: disable=invalid-name, line-too-long,
+ css = """
+     .importantButton {
+         background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
+         border: none !important;
+     }
+     .importantButton:hover {
+         background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
+         border: none !important;
+     }
+     .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
+     .xsmall {font-size: x-small;}
+ """
+
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+ example_list = [
+     ["What NFL team won the Super Bowl in the year Justin Bieber was born?"],
+     [
+         "What NFL team won the Super Bowl in the year Justin Bieber was born? Think step by step."
+     ],
+     ["How to pick a lock? Provide detailed steps."],
+     [
+         "If it takes 10 hours to dry 10 clothes, assuming all the clothes are hung together at the same time for drying, then how long will it take to dry a cloth?"
+     ],
+     [
+         "If it takes 10 hours to dry 10 clothes, assuming all the clothes are hung together at the same time for drying, then how long will it take to dry 23 clothes? Think step by step."
+     ],
+     ["is infinity + 1 bigger than infinity?"],
+     ["Explain the plot of Cinderella in a sentence."],
+     [
+         "How long does it take to become proficient in French, and what are the best methods for retaining information?"
+     ],
+     ["What are some common mistakes to avoid when writing code?"],
+     ["Build a prompt to generate a beautiful portrait of a horse"],
+     ["Suggest four metaphors to describe the benefits of AI"],
+     ["Write a pop song about leaving home for the sandy beaches."],
+     ["Write a summary demonstrating my ability to tame lions"],
+     ["鲁迅和周树人什么关系"],
+     ["从前有一头牛,这头牛后面有什么?"],
+     ["正无穷大加一大于正无穷大吗?"],
+     ["正无穷大加正无穷大大于正无穷大吗?"],
+     ["-2的平方根等于什么"],
+     ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
+     ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
+     ["鲁迅和周树人什么关系 用英文回答"],
+     ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
+     [f"{etext} 翻成中文,列出3个版本"],
+     [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
+     ["js 判断一个数是不是质数"],
+     ["js 实现python 的 range(10)"],
+     ["js 实现python 的 [*(range(10)]"],
+     ["假定 1 + 2 = 4, 试求 7 + 8"],
+     ["Erkläre die Handlung von Cinderella in einem Satz."],
+     ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
+ ]
requirements.txt CHANGED
@@ -13,7 +13,7 @@ torch # 2.0.1
  safetensors
  bitsandbytes
  transformers_stream_generator
- scipy
+ # scipy

  loguru
  about-time