ffreemt committed
Commit 68482b0
Parent: cd90503

Fix show_progress=full
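The commit switches the show_progress argument on Gradio event listeners from a boolean to a string mode. A minimal sketch of the change, assuming a Gradio release where show_progress accepts "full", "minimal", or "hidden" (the boolean form is the older API); the inputs and outputs names below are placeholders for the component lists used in app.py:

    # before: boolean flag (older Gradio releases)
    submitBtn.click(predict, inputs, outputs, show_progress=True)

    # after: explicit mode string; "full", "minimal", and "hidden" are accepted
    submitBtn.click(predict, inputs, outputs, show_progress="full")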

Files changed (1)
  1. app.py +169 -94
app.py CHANGED
@@ -1,3 +1,4 @@
+# pylint: disable=broad-exception-caught, redefined-outer-name, missing-function-docstring, missing-module-docstring, too-many-arguments, line-too-long, invalid-name, redefined-builtin
 # import gradio as gr

 # model_name = "models/THUDM/chatglm2-6b-int4"
@@ -6,17 +7,19 @@
 # %%writefile demo-4bit.py

 from textwrap import dedent
+import torch

-# credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py
-# while mistakes are mine
-from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 import mdtex2html
-
 from loguru import logger

-model_name = "THUDM/chatglm2-6b"
+# credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py
+# while mistakes are mine
+from transformers import AutoModel, AutoTokenizer
+
+model_name = "THUDM/chatglm2-6b"
 model_name = "THUDM/chatglm2-6b-int4"
+RETRY_FLAG = False

 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

@@ -25,20 +28,23 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 # 4/8 bit
 # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()

-import torch
-
 has_cuda = torch.cuda.is_available()
 # has_cuda = False # force cpu

 if has_cuda:
-    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()  # 3.92G
+    model = AutoModel.from_pretrained(
+        model_name, trust_remote_code=True
+    ).cuda()  # 3.92G
 else:
-    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half()  # .float() .half().float()
+    model = AutoModel.from_pretrained(
+        model_name, trust_remote_code=True
+    ).half()  # .float() .half().float()

 model = model.eval()

 _ = """Override Chatbot.postprocess"""

+
 def postprocess(self, y):
     if y is None:
         return []
@@ -61,15 +67,15 @@ def parse_text(text):
     for i, line in enumerate(lines):
         if "```" in line:
             count += 1
-            items = line.split('`')
+            items = line.split("`")
             if count % 2 == 1:
                 lines[i] = f'<pre><code class="language-{items[-1]}">'
             else:
-                lines[i] = f'<br></code></pre>'
+                lines[i] = "<br></code></pre>"
         else:
             if i > 0:
                 if count % 2 == 1:
-                    line = line.replace("`", "\`")
+                    line = line.replace("`", r"\`")
                     line = line.replace("<", "&lt;")
                     line = line.replace(">", "&gt;")
                     line = line.replace(" ", "&nbsp;")
@@ -81,23 +87,31 @@ def parse_text(text):
                     line = line.replace("(", "&#40;")
                     line = line.replace(")", "&#41;")
                     line = line.replace("$", "&#36;")
-                lines[i] = "<br>"+line
+                lines[i] = "<br>" + line
     text = "".join(lines)
     return text


-def predict(RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values):
+def predict(
+    RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
+):
     try:
         chatbot.append((parse_text(input), ""))
     except Exception as exc:
         logger.error(exc)
         chatbot[-1] = (parse_text(input), str(exc))
         yield chatbot, history, past_key_values
-
-    for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
-                                                                return_past_key_values=True,
-                                                                max_length=max_length, top_p=top_p,
-                                                                temperature=temperature):
+
+    for response, history, past_key_values in model.stream_chat(
+        tokenizer,
+        input,
+        history,
+        past_key_values=past_key_values,
+        return_past_key_values=True,
+        max_length=max_length,
+        top_p=top_p,
+        temperature=temperature,
+    ):
         chatbot[-1] = (parse_text(input), parse_text(response))

         yield chatbot, history, past_key_values
@@ -112,9 +126,9 @@ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
         temperature = 0.01
     try:
         res, _ = model.chat(
-                tokenizer,
-                input,
-                history=[],
+            tokenizer,
+            input,
+            history=[],
             past_key_values=None,
             max_length=max_length,
             top_p=top_p,
@@ -126,15 +140,16 @@ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
         res = str(exc)

     return res
-
+

 def reset_user_input():
-    return gr.update(value='')
+    return gr.update(value="")


 def reset_state():
     return [], [], None

+
 # Delete last turn
 def delete_last_turn(chat, history):
     if chat and history:
@@ -145,53 +160,49 @@ def delete_last_turn(chat, history):

 # Regenerate response
 def retry_last_answer(
-        user_input,
-        chatbot,
-        max_length,
-        top_p,
-        temperature,
-        history,
-        past_key_values
-):
-
+    user_input, chatbot, max_length, top_p, temperature, history, past_key_values
+):
     if chatbot and history:
-        # Removing the previous conversation from chat
+        # Removing the previous conversation from chat
         chatbot.pop(-1)
-        # Setting up a flag to capture a retry
+        # Setting up a flag to capture a retry
         RETRY_FLAG = True
         # Getting last message from user
         user_input = history[-1][0]
-        # Removing bot response from the history
+        # Removing bot response from the history
         history.pop(-1)

         yield from predict(
             RETRY_FLAG,
-                user_input,
-                chatbot,
-                max_length,
-                top_p,
-                temperature,
-                history,
-                past_key_values
-        )
+            user_input,
+            chatbot,
+            max_length,
+            top_p,
+            temperature,
+            history,
+            past_key_values,
+        )
+

 with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
     # gr.HTML("""<h1 align="center">ChatGLM2-6B-int4</h1>""")
-    gr.HTML("""<center><a href="https://huggingface.co/spaces/mikeee/chatglm2-6b-4bit?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>To avoid the queue and for faster inference Duplicate this Space and upgrade to GPU</center>""")
+    gr.HTML(
+        """<center><a href="https://huggingface.co/spaces/mikeee/chatglm2-6b-4bit?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>To avoid the queue and for faster inference Duplicate this Space and upgrade to GPU</center>"""
+    )

     with gr.Accordion("Info", open=False):
         _ = """
         ## ChatGLM2-6B-int4
-
-        With a GPU, a query takes from a few seconds to a few tens of seconds, dependent on the number of words/characters
+
+        With a GPU, a query takes from a few seconds to a few tens of seconds, dependent on the number of words/characters
         the question and responses contain. The quality of the responses varies quite a bit it seems. Even the same
         question with the same parameters, asked at different times, can result in quite different responses.

         * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
-
+
         * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4

-        * Top P controls dynamic vocabulary selection based on context.
+        * Top P controls dynamic vocabulary selection based on context.

         For a table of example values for different scenarios, refer to [this](https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683)

@@ -204,8 +215,10 @@ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
     with gr.Row():
         with gr.Column(scale=4):
             with gr.Column(scale=12):
-                user_input = gr.Textbox(show_label=False, placeholder="Input...", ).style(
-                    container=False)
+                user_input = gr.Textbox(
+                    show_label=False,
+                    placeholder="Input...",
+                ).style(container=False)
                 RETRY_FLAG = gr.Checkbox(value=False, visible=False)
             with gr.Column(min_width=32, scale=1):
                 with gr.Row():
@@ -214,59 +227,109 @@ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
                     retryBtn = gr.Button("Regenerate", variant="secondary")
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 32768, value=8192/2, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.85, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+            max_length = gr.Slider(
+                0,
+                32768,
+                value=8192 / 2,
+                step=1.0,
+                label="Maximum length",
+                interactive=True,
+            )
+            top_p = gr.Slider(
+                0, 1, value=0.85, step=0.01, label="Top P", interactive=True
+            )
+            temperature = gr.Slider(
+                0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
+            )

     history = gr.State([])
     past_key_values = gr.State(None)

-    user_input.submit(predict, [RETRY_FLAG, user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
-                      [chatbot, history, past_key_values], show_progress=True)
-    submitBtn.click(predict, [RETRY_FLAG, user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
-                    [chatbot, history, past_key_values], show_progress=True, api_name="predict")
+    user_input.submit(
+        predict,
+        [
+            RETRY_FLAG,
+            user_input,
+            chatbot,
+            max_length,
+            top_p,
+            temperature,
+            history,
+            past_key_values,
+        ],
+        [chatbot, history, past_key_values],
+        show_progress="full",
+    )
+    submitBtn.click(
+        predict,
+        [
+            RETRY_FLAG,
+            user_input,
+            chatbot,
+            max_length,
+            top_p,
+            temperature,
+            history,
+            past_key_values,
+        ],
+        [chatbot, history, past_key_values],
+        show_progress="full",
+        api_name="predict",
+    )
     submitBtn.click(reset_user_input, [], [user_input])

-    emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
+    emptyBtn.click(
+        reset_state, outputs=[chatbot, history, past_key_values], show_progress="full"
+    )

     retryBtn.click(
-            retry_last_answer,
-            inputs = [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
-            #outputs = [chatbot, history, last_user_message, user_message]
-            outputs=[chatbot, history, past_key_values]
-    )
-    deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
+        retry_last_answer,
+        inputs=[
+            user_input,
+            chatbot,
+            max_length,
+            top_p,
+            temperature,
+            history,
+            past_key_values,
+        ],
+        # outputs = [chatbot, history, last_user_message, user_message]
+        outputs=[chatbot, history, past_key_values],
+    )
+    deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])

     with gr.Accordion("Example inputs", open=True):
         etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
         examples = gr.Examples(
             examples=[
-                ["Explain the plot of Cinderella in a sentence."],
-                ["How long does it take to become proficient in French, and what are the best methods for retaining information?"],
-                ["What are some common mistakes to avoid when writing code?"],
-                ["Build a prompt to generate a beautiful portrait of a horse"],
-                ["Suggest four metaphors to describe the benefits of AI"],
-                ["Write a pop song about leaving home for the sandy beaches."],
-                ["Write a summary demonstrating my ability to tame lions"],
-                ["鲁迅和周树人什么关系"],
-                ["从前有一头牛,这头牛后面有什么?"],
-                ["正无穷大加一大于正无穷大吗?"],
-                ["正无穷大加正无穷大大于正无穷大吗?"],
-                ["-2的平方根等于什么"],
-                ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
-                ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
-                ["鲁迅和周树人什么关系 用英文回答"],
-                ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
-                [f"{etext} 翻成中文,列出3个版本"],
-                [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
-                ["js 判断一个数是不是质数"],
-                ["js 实现python 的 range(10)"],
-                ["js 实现python 的 [*(range(10)]"],
-                ["假定 1 + 2 = 4, 试求 7 + 8"],
-                ["Erkläre die Handlung von Cinderella in einem Satz."],
-                ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
+                ["Explain the plot of Cinderella in a sentence."],
+                [
+                    "How long does it take to become proficient in French, and what are the best methods for retaining information?"
+                ],
+                ["What are some common mistakes to avoid when writing code?"],
+                ["Build a prompt to generate a beautiful portrait of a horse"],
+                ["Suggest four metaphors to describe the benefits of AI"],
+                ["Write a pop song about leaving home for the sandy beaches."],
+                ["Write a summary demonstrating my ability to tame lions"],
+                ["鲁迅和周树人什么关系"],
+                ["从前有一头牛,这头牛后面有什么?"],
+                ["正无穷大加一大于正无穷大吗?"],
+                ["正无穷大加正无穷大大于正无穷大吗?"],
+                ["-2的平方根等于什么"],
+                ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
+                ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
+                ["鲁迅和周树人什么关系 用英文回答"],
+                ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
+                [f"{etext} 翻成中文,列出3个版本"],
+                [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
+                ["js 判断一个数是不是质数"],
+                ["js 实现python range(10)"],
+                ["js 实现python [*(range(10)]"],
+                ["假定 1 + 2 = 4, 试求 7 + 8"],
+                ["Erkläre die Handlung von Cinderella in einem Satz."],
+                ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
             ],
-            inputs = [user_input],
+            inputs=[user_input],
             examples_per_page=30,
         )

@@ -274,10 +337,22 @@ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
         input_text = gr.Text()
         tr_btn = gr.Button("Go", variant="primary")
         out_text = gr.Text()
-    tr_btn.click(trans_api, [input_text, max_length, top_p, temperature], out_text, show_progress=True, api_name="tr")
-    input_text.submit(trans_api, [input_text, max_length, top_p, temperature], out_text, show_progress=True, api_name="tr")
-
+        tr_btn.click(
+            trans_api,
+            [input_text, max_length, top_p, temperature],
+            out_text,
+            show_progress="full",
+            api_name="tr",
+        )
+        input_text.submit(
+            trans_api,
+            [input_text, max_length, top_p, temperature],
+            out_text,
+            show_progress="full",
+            api_name="tr",
+        )
+
 # demo.queue().launch(share=False, inbrowser=True)
 # demo.queue().launch(share=True, inbrowser=True, debug=True)

-demo.queue().launch(debug=True)
+demo.queue().launch(debug=True)
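Because predict and trans_api are registered with api_name="predict" and api_name="tr", the Space also exposes callable API endpoints. A minimal sketch of calling the translation endpoint with gradio_client, assuming the Space (mikeee/chatglm2-6b-4bit, taken from the Duplicate link above) is running and the client version matches the server:

    from gradio_client import Client

    client = Client("mikeee/chatglm2-6b-4bit")
    # positional args mirror the endpoint's inputs: [input_text, max_length, top_p, temperature]
    result = client.predict("Hello world", 4096, 0.8, 0.2, api_name="/tr")
    print(result)

A low temperature (around 0.2, as the Info panel suggests for translation) keeps the output close to the source text.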