KumaTea committed
Commit e65d733 · 1 Parent(s): 7a2c4d5
Files changed (3):
  1. README.md +3 -3
  2. app.py +130 -0
  3. requirements.txt +22 -0
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: KumaGLM
-emoji: 👁
-colorFrom: pink
-colorTo: purple
+emoji: 🐻
+colorFrom: blue
+colorTo: green
 sdk: gradio
 sdk_version: 3.24.1
 app_file: app.py
app.py ADDED
@@ -0,0 +1,130 @@
+# Credit:
+# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py
+
+
+import torch
+import gradio as gr
+from peft import PeftModel
+from transformers import AutoTokenizer, GenerationConfig, AutoModel
+
+
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe").float()
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe")
+peft_path = 'KumaTea/twitter'
+model = PeftModel.from_pretrained(
+    model,
+    peft_path,
+    torch_dtype=torch.float,
+)
+
+# dump a log to ensure everything works well
+# print(model.peft_config)
+# We have to use full precision, as some tokens are >65535
+model.eval()
+# print(model)
+
+torch.set_default_tensor_type(torch.FloatTensor)
+
+
+def evaluate(context, temperature, top_p, top_k):
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        # repetition_penalty=1.1,
+        num_beams=1,
+        do_sample=True,
+    )
+    with torch.no_grad():
+        input_text = f"Context: {context}Answer: "
+        ids = tokenizer.encode(input_text)
+        input_ids = torch.LongTensor([ids]).to('cpu')
+        out = model.generate(
+            input_ids=input_ids,
+            max_length=160,
+            generation_config=generation_config
+        )
+        out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
+        return out_text
+
+
+def evaluate_stream(msg, history, temperature, top_p):
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        # repetition_penalty=1.1,
+        num_beams=1,
+        do_sample=True,
+    )
+
+    history.append([msg, None])
+
+    context = ""
+    if len(history) > 4:
+        history.pop(0)
+
+    for j in range(len(history)):
+        history[j][0] = history[j][0].replace("<br>", "")
+
+    # concatenate context
+    for h in history[:-1]:
+        context += h[0] + "||" + h[1] + "||"
+
+    context += history[-1][0]
+    context = context.replace(r'<br>', '')
+
+    # TODO: avoid letting the token sequence grow too long
+    CUTOFF = 224
+    while len(tokenizer.encode(context)) > CUTOFF:
+        # drop characters from the front to leave room for the answer
+        context = context[15:]
+
+    h = []
+    print("History:", history)
+    print("Context:", context)
+    for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):
+        history[-1][1] = response
+        yield history, ""
+
+    # return response
+
+
+title = """<h1 align="center">KumaGLM</h1>
+<h3 align='center'>This is an AI Kuma. You can chat with her, or simply press Enter in the text box.</h3>
+<p align='center'>Sampling range: 2020/06/13 - 2023/04/15</p>"""
+footer = """<p align='center'>
+This project is based on
+<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
+, using the model
+<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
+
+</p>
+<p align='center'>
+<em>The first thing to say after waking up every day!</em>
+</p>"""
+
+with gr.Blocks() as demo:
+    gr.HTML(title)
+    state = gr.State()
+    with gr.Row():
+        with gr.Column(scale=2):
+            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.8, label="Temperature",
+                                        info="Temperature: higher values give more varied output but may cause grammatical errors; lower values also help produce more relevant replies.")
+            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.975, label="Top-p",
+                                         info="Top-p: only text within the top-p probability mass is sampled. Larger values give more varied output but may cause grammatical errors; smaller values seem to keep replies more coherent with the context.")
+            # code = gr.Textbox(label="temp_output", info="Decoder output")
+            # top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
+            #                              info="Top-k: the next token is chosen from the k most likely candidates. Larger values give more varied output but may cause grammatical errors; smaller values seem to keep replies more coherent with the context.")
+
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(label="Chat", info="")
+            msg = gr.Textbox(label="Input", placeholder="How have you been lately?",
+                             info="Type your message and press [Enter] to send. You can also leave it empty to generate random output. Conversations should not run too long, or the bot starts repeating itself; clearing the chat is recommended.")
+            clear = gr.Button("Clear chat")
+
+    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
+    clear.click(lambda: None, None, chatbot, queue=False)
+    gr.HTML(footer)
+
+demo.queue()
+demo.launch(debug=False)
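For reference only (not part of the commit): a minimal way to exercise evaluate_stream outside the Gradio UI, assuming the definitions above have already been run in the same Python session (before demo.launch()), could look like the hypothetical snippet below.

history = []
# Drain the generator the same way msg.submit does; each step yields
# the updated history plus an empty string used to clear the textbox.
for history, _cleared_box in evaluate_stream("How have you been lately?", history, 0.8, 0.975):
    pass
print(history[-1][1])  # latest reply from the model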
requirements.txt ADDED
@@ -0,0 +1,22 @@
+# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/requirements.txt
+
+# int8
+bitsandbytes>=0.37.1
+accelerate>=0.17.1
+
+# chatglm
+protobuf>=3.19.5,<3.20.1
+transformers>=4.27.1
+icetk
+cpm_kernels>=1.0.11
+
+#
+datasets>=2.10.1
+git+https://github.com/huggingface/peft.git  # latest version >=0.3.0.dev0
+
+-f https://download.pytorch.org/whl/cpu
+torch
+-f https://download.pytorch.org/whl/cpu
+torchvision
+-f https://download.pytorch.org/whl/cpu
+torchaudio