Yingxu He committed on
Commit
5e66ec0
·
verified ·
1 Parent(s): 1a93634

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -81
app.py CHANGED
@@ -1,109 +1,71 @@
1
- import argparse
2
- from pathlib import Path
3
-
4
  import chatglm_cpp
5
  import gradio as gr
6
 
7
- import urllib
8
-
9
  DEFAULT_MODEL_PATH = "chatglm3-6b.bin"
10
-
11
  urllib.request.urlretrieve(
12
- "https://huggingface.co/Braddy/chatglm3-6b-chitchat/resolve/main/q5_1_loss_304.bin?download=true",
13
  DEFAULT_MODEL_PATH
14
  )
15
 
16
- parser = argparse.ArgumentParser()
17
- parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
18
- parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
19
- parser.add_argument("-l", "--max_new_tokens", default=64, type=int, help="max total output tokens")
20
- parser.add_argument("-c", "--max_context_length", default=1024, type=int, help="max context length")
21
- parser.add_argument("--top_k", default=40, type=int, help="top-k sampling")
22
- parser.add_argument("--top_p", default=0.75, type=float, help="top-p sampling")
23
- parser.add_argument("--temp", default=0.5, type=float, help="temperature")
24
- parser.add_argument("--repeat_penalty", default=1.0, type=float, help="penalize repeat sequence of tokens")
25
- parser.add_argument("-t", "--threads", default=0, type=int, help="number of threads for inference")
26
- parser.add_argument("--plain", action="store_true", help="display in plain text without markdown support")
27
- args = parser.parse_args()
28
 
29
- pipeline = chatglm_cpp.Pipeline(args.model)
30
  system_message = chatglm_cpp.ChatMessage(role="system", content="请你现在扮演一个软件工程师,名字叫做贺英旭。你需要以这个身份和朋友们对话。")
31
 
32
 
33
- def postprocess(text):
34
- if args.plain:
35
- return f"<pre>{text}</pre>"
36
- return text
 
 
 
 
 
 
37
 
 
 
 
 
 
 
38
 
39
- def predict(input, chatbot, max_new_tokens, top_p, temperature, messages):
40
- chatbot.append((postprocess(input), ""))
41
- messages.append(chatglm_cpp.ChatMessage(role="user", content=input))
42
- full_messages = [system_message] + messages
43
 
44
  generation_kwargs = dict(
45
  max_new_tokens=max_new_tokens,
46
- max_context_length=args.max_context_length,
47
  do_sample=temperature > 0,
48
- top_k=args.top_k,
49
  top_p=top_p,
50
  temperature=temperature,
51
- repetition_penalty=args.repeat_penalty,
52
- num_threads=args.threads,
53
  stream=True,
54
  )
55
 
56
  response = ""
57
- chunks = []
58
- for chunk in pipeline.chat(full_messages, **generation_kwargs):
59
  response += chunk.content
60
- chunks.append(chunk)
61
- chatbot[-1] = (chatbot[-1][0], postprocess(response))
62
- yield chatbot, messages
63
- messages.append(pipeline.merge_streaming_messages(chunks))
64
-
65
- yield chatbot, messages
66
-
67
-
68
- def reset_user_input():
69
- return gr.update(value="")
70
-
71
 
72
- def reset_state():
73
- return [], []
74
 
75
- title = """
76
- <div style="text-align: center">
77
- <h1>Chichat</h1>
78
- <p style="text-align: center;">Free feel to talk about anything :)</p>
79
- </div>
80
  """
81
-
82
-
83
- with gr.Blocks() as demo:
84
- gr.HTML(title)
85
-
86
- chatbot = gr.Chatbot(height=300, label="Check this out!")
87
- with gr.Row():
88
- with gr.Column(scale=4):
89
- user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=8)
90
- submitBtn = gr.Button("Submit", variant="primary")
91
- with gr.Column(scale=1):
92
- max_new_tokens = gr.Slider(0, 512, value=args.max_new_tokens, step=1.0, label="Maximum output tokens", interactive=True)
93
- top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
94
- temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
95
- emptyBtn = gr.Button("Clear History")
96
-
97
- messages = gr.State([])
98
-
99
- submitBtn.click(
100
- predict,
101
- [user_input, chatbot, max_new_tokens, top_p, temperature, messages],
102
- [chatbot, messages],
103
- show_progress=True,
104
- )
105
- submitBtn.click(reset_user_input, [], [user_input])
106
-
107
- emptyBtn.click(reset_state, outputs=[chatbot, messages], show_progress=True)
108
-
109
- demo.queue().launch(share=False, inbrowser=True)
 
1
+ import urllib
 
 
2
  import chatglm_cpp
3
  import gradio as gr
4
 
 
 
5
  DEFAULT_MODEL_PATH = "chatglm3-6b.bin"
 
6
  urllib.request.urlretrieve(
7
+ "https://huggingface.co/Braddy/chatglm3-6b-chitchat/resolve/main/q5_1.bin?download=true",
8
  DEFAULT_MODEL_PATH
9
  )
10
 
 
 
 
 
 
 
 
 
 
 
 
 
# Shared inference pipeline, loaded a single time at startup.
pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)

# Persona prompt (Chinese): "Act as a software engineer named He Yingxu
# and chat with friends in that role."
# NOTE(review): this module-level message is not passed to the pipeline
# anywhere visible here — `respond` builds its own system message.
system_message = chatglm_cpp.ChatMessage(role="system", content="请你现在扮演一个软件工程师,名字叫做贺英旭。你需要以这个身份和朋友们对话。")
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message* given the prior *history*.

    Parameters mirror the Gradio ``ChatInterface`` contract: *history* is a
    list of ``(user, assistant)`` string pairs, the remaining arguments come
    from the UI's additional inputs.  Yields the accumulated response text
    after every streamed chunk so Gradio renders incremental output.
    """
    # Rebuild the full conversation: system prompt first, then the
    # alternating user/assistant turns, then the new user message.
    messages = [chatglm_cpp.ChatMessage(role="system", content=system_message)]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append(chatglm_cpp.ChatMessage(role="user", content=user_turn))
        if assistant_turn:
            messages.append(chatglm_cpp.ChatMessage(role="assistant", content=assistant_turn))
    messages.append(chatglm_cpp.ChatMessage(role="user", content=message))

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        # Greedy decoding when temperature is zero; sample otherwise.
        do_sample=temperature > 0,
        top_p=top_p,
        temperature=temperature,
        stream=True,
    )

    response = ""
    for chunk in pipeline.chat(messages, **generation_kwargs):
        response += chunk.content
        yield response
 
 
 
 
 
 
 
 
 
 
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # Default the field to the persona configured at module level so the
        # intended system prompt actually takes effect (it was previously
        # shadowed by a generic "friendly Chatbot" placeholder).
        gr.Textbox(value=system_message.content, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()