NewBreaker committed
Commit 47b54c6 (1 parent: 6647827)
Files changed (17)
  1. 1.py +14 -0
  2. 1wandb.py +7 -0
  3. 2 +1 -0
  4. PROJECT.md +18 -0
  5. README.md +0 -12
  6. api.py +60 -0
  7. api_use.py +14 -0
  8. app.py +35 -0
  9. cli_demo.py +57 -0
  10. demo_app.py +18 -0
  11. demo_mult_chats.py +67 -0
  12. demo_single_chat.py +52 -0
  13. requirements.txt +8 -0
  14. utils.py +54 -0
  15. web_demo.py +104 -0
  16. web_demo2.py +69 -0
  17. web_demo_old.py +45 -0
1.py ADDED
@@ -0,0 +1,14 @@
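+ # Smoke test for the locally stored int4-quantized ChatGLM-6B: bind the CUDA quantization kernels and run one chat turn.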
+ from transformers import AutoTokenizer, AutoModel
+
+ tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+ model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+ kernel_file = ".\\models\\chatglm-6b-int4\\quantization_kernels.so"
+
+ model = model.quantize(bits=4, kernel_file=kernel_file)
+ model = model.eval()
+
+
+ response, history = model.chat(tokenizer, "你好", history=[])
+ print(response)
+
+
1wandb.py ADDED
@@ -0,0 +1,7 @@
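+ # Scratch script: logs in to Weights & Biases with a hardcoded API key and starts a run in the "chatglm" project.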
+ import wandb
+
+ wandb.login(key="b88cdc5f017d4e8c7b6a07aec184f577942139de")
+ wandb.init(project="chatglm")
+ print(1111)
+ # import wandb
+ #
2 ADDED
@@ -0,0 +1 @@
+ 你好
PROJECT.md ADDED
@@ -0,0 +1,18 @@
+ # Related Links
+
+ Below are some open-source projects built on top of this repository:
+ * [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): a unified programming framework for Transformers; ChatGLM-6B is implemented in SAT and supports P-tuning fine-tuning.
+ * [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): an MNN-based C++ inference implementation of ChatGLM-6B that automatically splits the workload between GPU and CPU according to available GPU memory
+ * [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): LoRA-based fine-tuning of ChatGLM-6B. A similar project is [Humanable ChatGLM/GPT Fine-tuning | ChatGLM fine-tuning](https://github.com/hscspring/hcgf)
+ * [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): a ChatGLM application over local knowledge bases, built on LangChain
+ * [bibliothecarius](https://github.com/coderabbit214/bibliothecarius): quickly build services that integrate your local data with AI models; supports locally deployed models such as ChatGLM.
+ * [Wenda](https://github.com/l15y/wenda): a large-language-model invocation platform that implements ChatPDF-like functionality on top of ChatGLM-6B
+ * [JittorLLMs](https://github.com/Jittor/JittorLLMs): runs ChatGLM-6B FP16 with as little as 3 GB of GPU memory or with no GPU at all; supports deployment on Linux, Windows, and macOS
+ * [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning): fine-tunes ChatGLM-6B on concrete downstream tasks using Freeze, LoRA, P-tuning, and other methods, with experimental comparisons.
+ * [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM): instruction tuning of ChatGLM-6B; aggregates open-source Chinese and English instruction data, fine-tunes on it with LoRA, releases the LoRA weights fine-tuned on Alpaca and Belle, and fixes the web_demo repetition issue
+ * [ChatGLM-web](https://github.com/NCZkevin/chatglm-web): a ChatGLM demo website built with FastAPI and Vue3 (supports streaming output, adjusting model parameters from the front end, context selection, saving images, knowledge-base Q&A, and more)
+ * [glm-bot](https://github.com/initialencounter/glm-bot): connects ChatGLM to Koishi so it can be called from major chat platforms
+
+ Below are some tutorials/documents for this project:
+ * [Windows deployment guide](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
+ * [ChatGLM-6B deployment and fine-tuning tutorial @ModelWhale](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
README.md DELETED
@@ -1,12 +0,0 @@
- ---
- title: Chatglm 6b Int4
- emoji: 📉
- colorFrom: purple
- colorTo: yellow
- sdk: gradio
- sdk_version: 3.27.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py ADDED
@@ -0,0 +1,60 @@
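+ # HTTP API for ChatGLM-6B: POST a JSON body with "prompt" and optional
+ # "history", "max_length", "top_p" and "temperature"; returns the reply plus the updated history.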
+ from fastapi import FastAPI, Request
+ from transformers import AutoTokenizer, AutoModel
+ import uvicorn, datetime
+ import torch
+
+ DEVICE = "cuda"
+ DEVICE_ID = "0"
+ CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+ def torch_gc():
+     # free cached GPU memory between requests
+     if torch.cuda.is_available():
+         with torch.cuda.device(CUDA_DEVICE):
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+
+
+ app = FastAPI()
+
+
+ @app.post("/")
+ async def create_item(request: Request):
+     global model, tokenizer
+     json_post = await request.json()
+     prompt = json_post.get('prompt')
+     history = json_post.get('history')
+     max_length = json_post.get('max_length')
+     top_p = json_post.get('top_p')
+     temperature = json_post.get('temperature')
+     response, history = model.chat(tokenizer,
+                                    prompt,
+                                    history=history,
+                                    max_length=max_length if max_length else 2048,
+                                    top_p=top_p if top_p else 0.7,
+                                    temperature=temperature if temperature else 0.95)
+     now = datetime.datetime.now()
+     time = now.strftime("%Y-%m-%d %H:%M:%S")
+     answer = {
+         "response": response,
+         "history": history,
+         "status": 200,
+         "time": time
+     }
+     log = "[" + time + "] " + 'prompt: "' + prompt + '", response: "' + repr(response) + '"'
+     print(log)
+     torch_gc()
+     return answer
+
+
+ if __name__ == '__main__':
+     # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+     # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+     tokenizer = AutoTokenizer.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="")
+     model = AutoModel.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+
+     model.eval()
+     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
api_use.py ADDED
@@ -0,0 +1,14 @@
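+ # Example client for api.py: posts a prompt to the local server started by api.py and prints the JSON reply.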
+ import requests
+ import json
+
+ url = 'http://127.0.0.1:8000'
+ headers = {
+     'Content-Type': 'application/json'
+ }
+ data = {
+     'prompt': '你好',
+     'history': []
+ }
+
+ response = requests.post(url=url, headers=headers, data=json.dumps(data))
+ print(response.json())
app.py ADDED
@@ -0,0 +1,35 @@
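+ # Minimal Gradio text-in/text-out interface (e.g. for a Hugging Face Space); each call starts from an empty history.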
+ from transformers import AutoTokenizer, AutoModel
+ import gradio as gr
+
+ tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+ model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cpu()
+
+
+ # from transformers import AutoTokenizer, AutoModel
+ # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ # model = model.eval()
+
+
+ # kernel_file = "./models/chatglm-6b-int4/quantization_kernels.so"
+ # tokenizer = AutoTokenizer.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="")
+ # model = AutoModel.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+ # model = AutoModel.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="").half()
+
+
+ # model = model.quantize(bits=model_args.quantization_bit, kernel_file=kernel_file)
+
+ model = model.eval()
+
+
+ def chat(msg):
+     history = []
+     response, history = model.chat(tokenizer, msg, history=history)
+     print("response:", response)
+     return response
+
+
+ iface = gr.Interface(fn=chat, inputs="text", outputs="text")
+ iface.launch()
cli_demo.py ADDED
@@ -0,0 +1,57 @@
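+ # Interactive command-line chat: type "clear" to reset the history and "stop" to exit; responses stream to the terminal.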
+ import os
+ import platform
+ import signal
+ from transformers import AutoTokenizer, AutoModel
+
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+ model = model.eval()
+
+ os_name = platform.system()
+ clear_command = 'cls' if os_name == 'Windows' else 'clear'
+ stop_stream = False
+
+
+ def build_prompt(history):
+     prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+     for query, response in history:
+         prompt += f"\n\n用户:{query}"
+         prompt += f"\n\nChatGLM-6B:{response}"
+     return prompt
+
+
+ def signal_handler(signal, frame):
+     global stop_stream
+     stop_stream = True
+
+
+ def main():
+     history = []
+     global stop_stream
+     print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+     while True:
+         query = input("\n用户:")
+         if query.strip() == "stop":
+             break
+         if query.strip() == "clear":
+             history = []
+             os.system(clear_command)
+             print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+             continue
+         count = 0
+         for response, history in model.stream_chat(tokenizer, query, history=history):
+             if stop_stream:
+                 stop_stream = False
+                 break
+             else:
+                 count += 1
+                 if count % 8 == 0:
+                     os.system(clear_command)
+                     print(build_prompt(history), flush=True)
+                     signal.signal(signal.SIGINT, signal_handler)
+         os.system(clear_command)
+         print(build_prompt(history), flush=True)
+
+
+ if __name__ == "__main__":
+     main()
demo_app.py ADDED
@@ -0,0 +1,18 @@
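+ # Stateless Gradio demo of the local int4 model; duplicates app.py but keeps the model on the GPU.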
+ from transformers import AutoTokenizer, AutoModel
+ import gradio as gr
+
+ tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+ model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+ model = model.eval()
+
+
+ def chat(msg):
+     history = []
+     response, history = model.chat(tokenizer, msg, history=history)
+     print("response:", response)
+     return response
+
+
+ iface = gr.Interface(fn=chat, inputs="text", outputs="text")
+ iface.launch()
demo_mult_chats.py ADDED
@@ -0,0 +1,67 @@
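+ # Terminal multi-turn chat on the local int4 model; streams replies and truncates old turns once the history grows past five entries.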
+ from transformers import AutoTokenizer, AutoModel
+ tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+ model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+ model = model.eval()
+
+
+ def parse_text(text):
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split('`')
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = f'<br></code></pre>'
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", "\\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
+
+
+ def predict(input, chatbot, max_length, top_p, temperature, history):
+     chatbot.append((parse_text(input), ""))
+     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                                temperature=temperature):
+         chatbot[-1] = (parse_text(input), parse_text(response))
+
+         yield chatbot, history
+
+
+ response_new = ''
+ history = []
+
+ for i in range(3000):
+     length_history = len(history)
+     if length_history > 5:  # if the conversation gets too long, drop the two oldest exchanges
+         del history[0]
+         del history[0]
+     print('\033[1;31m{}\033[0m'.format('\nYou:'), end='')
+     msg = input()
+     print('\033[1;34m{}\033[0m'.format('ChatGLM:'), end='')
+
+     for chatbot, history in predict(input=msg, chatbot=[], max_length=10000, top_p=0.5, temperature=0.5, history=history):
+         response_old = response_new
+         response_new = chatbot[0][1]
+         new_single = response_new.replace(response_old, '')
+         print(new_single, end='')
demo_single_chat.py ADDED
@@ -0,0 +1,52 @@
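+ # One-shot streaming demo: loads the int4 model with its quantization kernels, then streams the reply to a single fixed prompt.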
+ from transformers import AutoTokenizer, AutoModel
+ tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+ model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+ kernel_file = ".\\models\\chatglm-6b-int4\\quantization_kernels.so"  # same kernel path as in 1.py
+ model = model.quantize(bits=4, kernel_file=kernel_file)
+ model = model.eval()
+
+ def parse_text(text):
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split('`')
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = f'<br></code></pre>'
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", "\\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
+
+ def predict(input, chatbot, max_length, top_p, temperature, history):
+     chatbot.append((parse_text(input), ""))
+     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                                temperature=temperature):
+         chatbot[-1] = (parse_text(input), parse_text(response))
+
+         yield chatbot, history
+ response_new = ''
+ history = []
+ for chatbot, history in predict('请写一篇1000字的散文', chatbot=[], max_length=10000, top_p=0.5, temperature=0.5, history=history):
+     response_old = response_new
+     response_new = chatbot[0][1]
+     new_single = response_new.replace(response_old, '')
+     print(new_single, end='')
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ protobuf
+ transformers==4.27.1
+ cpm_kernels
+ torch>=1.10
+ gradio
+ mdtex2html
+ sentencepiece
+ accelerate
utils.py ADDED
@@ -0,0 +1,54 @@
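+ # Helpers for sharding ChatGLM-6B across several GPUs, e.g. (hypothetical usage):
+ #     model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)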
+ import os
+ from typing import Dict, Tuple, Union, Optional
+
+ from torch.nn import Module
+ from transformers import AutoModel
+
+
+ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
+     # transformer.word_embeddings takes 1 layer
+     # transformer.final_layernorm and lm_head take 1 layer
+     # transformer.layers takes 28 layers
+     # distribute the 30 layers in total across num_gpus cards
+     num_trans_layers = 28
+     per_gpu_layers = 30 / num_gpus
+
+     # bugfix: on Linux, the weight and input passed to torch.embedding can end up on different devices, raising a RuntimeError
+     # on Windows, model.device is set to transformer.word_embeddings.device
+     # on Linux, model.device is set to lm_head.device
+     # when chat or stream_chat is called, input_ids is placed on model.device
+     # if transformer.word_embeddings.device differs from model.device, a RuntimeError follows
+     # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first card
+     device_map = {'transformer.word_embeddings': 0,
+                   'transformer.final_layernorm': 0, 'lm_head': 0}
+
+     used = 2
+     gpu_target = 0
+     for i in range(num_trans_layers):
+         if used >= per_gpu_layers:
+             gpu_target += 1
+             used = 0
+         assert gpu_target < num_gpus
+         device_map[f'transformer.layers.{i}'] = gpu_target
+         used += 1
+
+     return device_map
+
+
+ def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
+                        device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
+     if num_gpus < 2 and device_map is None:
+         model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
+     else:
+         from accelerate import dispatch_model
+
+         model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
+
+         if device_map is None:
+             device_map = auto_configure_device_map(num_gpus)
+
+         model = dispatch_model(model, device_map=device_map)
+
+     return model
@@ -0,0 +1,104 @@
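+ # Gradio Blocks chat UI with streaming output and sliders for maximum length, top-p and temperature.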
+ from transformers import AutoModel, AutoTokenizer
+ import gradio as gr
+ import mdtex2html
+
+ # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+ tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+ model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+
+ model = model.eval()
+
+ """Override Chatbot.postprocess"""
+
+
+ def postprocess(self, y):
+     if y is None:
+         return []
+     for i, (message, response) in enumerate(y):
+         y[i] = (
+             None if message is None else mdtex2html.convert(message),
+             None if response is None else mdtex2html.convert(response),
+         )
+     return y
+
+
+ gr.Chatbot.postprocess = postprocess
+
+
+ def parse_text(text):
+     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split('`')
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = f'<br></code></pre>'
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", "\\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
+
+
+ def predict(input, chatbot, max_length, top_p, temperature, history):
+     chatbot.append((parse_text(input), ""))
+     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                                temperature=temperature):
+         chatbot[-1] = (parse_text(input), parse_text(response))
+
+         yield chatbot, history
+
+
+ def reset_user_input():
+     return gr.update(value='')
+
+
+ def reset_state():
+     return [], []
+
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         with gr.Column(scale=4):
+             with gr.Column(scale=12):
+                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                     container=False)
+             with gr.Column(min_width=32, scale=1):
+                 submitBtn = gr.Button("Submit", variant="primary")
+         with gr.Column(scale=1):
+             emptyBtn = gr.Button("Clear History")
+             max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+
+     history = gr.State([])
+
+     submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+                     show_progress=True)
+     submitBtn.click(reset_user_input, [], [user_input])
+
+     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+ demo.queue().launch(share=False, inbrowser=True)
web_demo2.py ADDED
@@ -0,0 +1,69 @@
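+ # Streamlit chat UI; the model is cached across reruns with @st.cache_resource and replies stream into the page.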
+ from transformers import AutoModel, AutoTokenizer
+ import streamlit as st
+ from streamlit_chat import message
+
+
+ st.set_page_config(
+     page_title="ChatGLM-6b 演示",
+     page_icon=":robot:"
+ )
+
+
+ @st.cache_resource
+ def get_model():
+     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+     model = model.eval()
+     return tokenizer, model
+
+
+ MAX_TURNS = 20
+ MAX_BOXES = MAX_TURNS * 2
+
+
+ def predict(input, max_length, top_p, temperature, history=None):
+     tokenizer, model = get_model()
+     if history is None:
+         history = []
+
+     with container:
+         if len(history) > 0:
+             for i, (query, response) in enumerate(history):
+                 message(query, avatar_style="big-smile", key=str(i) + "_user")
+                 message(response, avatar_style="bottts", key=str(i))
+
+         message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
+         st.write("AI正在回复:")
+         with st.empty():
+             for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                                        temperature=temperature):
+                 query, response = history[-1]
+                 st.write(response)
+
+     return history
+
+
+ container = st.container()
+
+ # create a prompt text for the text generation
+ prompt_text = st.text_area(label="用户命令输入",
+                            height=100,
+                            placeholder="请在这儿输入您的命令")
+
+ max_length = st.sidebar.slider(
+     'max_length', 0, 4096, 2048, step=1
+ )
+ top_p = st.sidebar.slider(
+     'top_p', 0.0, 1.0, 0.6, step=0.01
+ )
+ temperature = st.sidebar.slider(
+     'temperature', 0.0, 1.0, 0.95, step=0.01
+ )
+
+ if 'state' not in st.session_state:
+     st.session_state['state'] = []
+
+ if st.button("发送", key="predict"):
+     with st.spinner("AI正在思考,请稍等........"):
+         # text generation
+         st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
web_demo_old.py ADDED
@@ -0,0 +1,45 @@
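+ # Older Gradio demo that renders the conversation into a fixed pool of MAX_BOXES markdown boxes.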
+ from transformers import AutoModel, AutoTokenizer
+ import gradio as gr
+
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+ model = model.eval()
+
+ MAX_TURNS = 20
+ MAX_BOXES = MAX_TURNS * 2
+
+
+ def predict(input, max_length, top_p, temperature, history=None):
+     if history is None:
+         history = []
+     for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                                temperature=temperature):
+         updates = []
+         for query, response in history:
+             updates.append(gr.update(visible=True, value="用户:" + query))
+             updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
+         if len(updates) < MAX_BOXES:
+             updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
+         yield [history] + updates
+
+
+ with gr.Blocks() as demo:
+     state = gr.State([])
+     text_boxes = []
+     for i in range(MAX_BOXES):
+         if i % 2 == 0:
+             text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+         else:
+             text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                 container=False)
+         with gr.Column(scale=1):
+             max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+             button = gr.Button("Generate")
+     button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+ demo.queue().launch(share=False, inbrowser=True)