Spaces: Runtime error
NewBreaker committed
Commit: 47b54c6
Parent(s): 6647827
first
Browse files
- 1.py +14 -0
- 1wandb.py +7 -0
- 2 +1 -0
- PROJECT.md +18 -0
- README.md +0 -12
- api.py +60 -0
- api_use.py +14 -0
- app.py +35 -0
- cli_demo.py +57 -0
- demo_app.py +18 -0
- demo_mult_chats.py +67 -0
- demo_single_chat.py +52 -0
- requirements.txt +8 -0
- utils.py +54 -0
- web_demo.py +104 -0
- web_demo2.py +69 -0
- web_demo_old.py +45 -0
1.py
ADDED
@@ -0,0 +1,14 @@
+from transformers import AutoTokenizer, AutoModel
+
+tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+kernel_file = ".\\models\\chatglm-6b-int4\\quantization_kernels.so"
+
+model = model.quantize(bits=4, kernel_file=kernel_file)
+model = model.eval()
+
+
+response, history = model.chat(tokenizer, "你好", history=[])
+print(response)
+
+
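Side note on 1.py: the `history` returned by `model.chat` can be fed back in for a follow-up turn, which is what demo_mult_chats.py below does in a loop. A minimal sketch (the second prompt is illustrative, not part of the commit):

response, history = model.chat(tokenizer, "Thanks, now please answer in English", history=history)
print(response)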
1wandb.py
ADDED
@@ -0,0 +1,7 @@
+import wandb
+
+wandb.login(key="b88cdc5f017d4e8c7b6a07aec184f577942139de")
+wandb.init(project="chatglm")
+print(1111)
+# import wandb
+#
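A minimal sketch of how the wandb run opened above could be used; the logged metric is illustrative, not something 1wandb.py actually records:

wandb.log({"demo_metric": 1})  # illustrative metric name
wandb.finish()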
2
ADDED
@@ -0,0 +1 @@
+你好
PROJECT.md
ADDED
@@ -0,0 +1,18 @@
+# Friendly Links
+
+The following are some open-source projects built on this repository:
+* [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): a unified Transformer programming framework; ChatGLM-6B is implemented in SAT and supports P-tuning fine-tuning.
+* [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): a C++ inference implementation of ChatGLM-6B based on MNN that automatically splits computation between GPU and CPU according to available GPU memory.
+* [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): fine-tuning ChatGLM-6B with LoRA. A similar project is [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf).
+* [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): a ChatGLM application over local knowledge bases, built on LangChain.
+* [bibliothecarius](https://github.com/coderabbit214/bibliothecarius): quickly build services that integrate your local data with AI models; supports locally deployed models such as ChatGLM.
+* [闻达 (Wenda)](https://github.com/l15y/wenda): a large-language-model invocation platform that implements ChatPDF-like features on top of ChatGLM-6B.
+* [JittorLLMs](https://github.com/Jittor/JittorLLMs): runs ChatGLM-6B FP16 with as little as 3 GB of GPU memory, or even without a GPU; supports deployment on Linux, Windows, and macOS.
+* [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning): fine-tunes ChatGLM-6B on specific downstream tasks using Freeze, LoRA, P-tuning, etc., with experimental comparisons.
+* [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM): instruction tuning for ChatGLM-6B; aggregates open-source Chinese and English instruction data, fine-tunes on it with LoRA, releases LoRA weights tuned on Alpaca and Belle, and fixes the web_demo repetition issue.
+* [ChatGLM-web](https://github.com/NCZkevin/chatglm-web): a ChatGLM demo website built with FastAPI and Vue3 (supports streaming output, adjusting model parameters from the front end, context selection, saving images, knowledge-base Q&A, and more).
+* [glm-bot](https://github.com/initialencounter/glm-bot): connects ChatGLM to Koishi so it can be called from major chat platforms.
+
+The following are some tutorials/documents for this project:
+* [Windows deployment guide](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
+* [ChatGLM-6B deployment and fine-tuning tutorial @ModelWhale platform](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
README.md
DELETED
@@ -1,12 +0,0 @@
----
-title: Chatglm 6b Int4
-emoji: 📉
-colorFrom: purple
-colorTo: yellow
-sdk: gradio
-sdk_version: 3.27.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py
ADDED
@@ -0,0 +1,60 @@
+from fastapi import FastAPI, Request
+from transformers import AutoTokenizer, AutoModel
+import uvicorn, json, datetime
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
+app = FastAPI()
+
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get('prompt')
+    history = json_post_list.get('history')
+    max_length = json_post_list.get('max_length')
+    top_p = json_post_list.get('top_p')
+    temperature = json_post_list.get('temperature')
+    response, history = model.chat(tokenizer,
+                                   prompt,
+                                   history=history,
+                                   max_length=max_length if max_length else 2048,
+                                   top_p=top_p if top_p else 0.7,
+                                   temperature=temperature if temperature else 0.95)
+    now = datetime.datetime.now()
+    time = now.strftime("%Y-%m-%d %H:%M:%S")
+    answer = {
+        "response": response,
+        "history": history,
+        "status": 200,
+        "time": time
+    }
+    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
+    print(log)
+    torch_gc()
+    return answer
+
+
+if __name__ == '__main__':
+    # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+    tokenizer = AutoTokenizer.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="")
+    model = AutoModel.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+
+
+    model.eval()
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
api_use.py
ADDED
@@ -0,0 +1,14 @@
+import requests
+import json
+
+url = 'http://127.0.0.1:8000'
+headers = {
+    'Content-Type': 'application/json'
+}
+data = {
+    'prompt': '你好',
+    'history': []
+}
+
+response = requests.post(url=url, headers=headers, data=json.dumps(data))
+print(response.json())
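For reference, the handler in api.py builds its reply from the `answer` dict, so the request above should come back as JSON of roughly this shape (values are illustrative, not captured output):

{"response": "<model reply>", "history": [["你好", "<model reply>"]], "status": 200, "time": "2023-04-20 12:00:00"}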
app.py
ADDED
@@ -0,0 +1,35 @@
+from transformers import AutoTokenizer, AutoModel
+import gradio as gr
+
+tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").float().cpu()  # CPU inference: .cuda() fails without a GPU, and half precision is not supported on CPU
+
+
+# from transformers import AutoTokenizer, AutoModel
+# tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+# model = model.eval()
+
+
+# kernel_file = "./models/chatglm-6b-int4/quantization_kernels.so"
+# tokenizer = AutoTokenizer.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="")
+# model = AutoModel.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+# model = AutoModel.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="").half()
+
+
+
+# model = model.quantize(bits=model_args.quantization_bit, kernel_file=kernel_file)
+
+model = model.eval()
+
+
+
+def chat(msg):
+    history = []
+    response, history = model.chat(tokenizer, msg, history=history)
+    print("response:", response)
+    return response
+
+
+iface = gr.Interface(fn=chat, inputs="text", outputs="text")
+iface.launch()
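If the Gradio interface in app.py needs to be reachable from outside the local machine (for example inside a container), `launch` also accepts a bind address and port; the values here are illustrative:

iface.launch(server_name="0.0.0.0", server_port=7860)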
cli_demo.py
ADDED
@@ -0,0 +1,57 @@
+import os
+import platform
+import signal
+from transformers import AutoTokenizer, AutoModel
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
+stop_stream = False
+
+
+def build_prompt(history):
+    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+    for query, response in history:
+        prompt += f"\n\n用户:{query}"
+        prompt += f"\n\nChatGLM-6B:{response}"
+    return prompt
+
+
+def signal_handler(signal, frame):
+    global stop_stream
+    stop_stream = True
+
+
+def main():
+    history = []
+    global stop_stream
+    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+    while True:
+        query = input("\n用户:")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            os.system(clear_command)
+            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+            continue
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            if stop_stream:
+                stop_stream = False
+                break
+            else:
+                count += 1
+                if count % 8 == 0:
+                    os.system(clear_command)
+                    print(build_prompt(history), flush=True)
+                    signal.signal(signal.SIGINT, signal_handler)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == "__main__":
+    main()
demo_app.py
ADDED
@@ -0,0 +1,18 @@
+from transformers import AutoTokenizer, AutoModel
+import gradio as gr
+
+tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+model = model.eval()
+
+
+
+def chat(msg):
+    history = []
+    response, history = model.chat(tokenizer, msg, history=history)
+    print("response:", response)
+    return response
+
+
+iface = gr.Interface(fn=chat, inputs="text", outputs="text")
+iface.launch()
demo_mult_chats.py
ADDED
@@ -0,0 +1,67 @@
+from transformers import AutoTokenizer, AutoModel
+tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+model = model.eval()
+
+
+
+
+def parse_text(text):
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+    text = "".join(lines)
+    return text
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+
+response_new = ''
+history = []
+
+for i in range(3000):
+    length_history = len(history)
+    if (length_history > 5):  # if the conversation history gets too long, forget the earliest exchange
+        del history[0]
+        del history[0]
+    print('\033[1;31m{}\033[0m'.format('\nYou:'), end='')
+    msg = input()
+    print('\033[1;34m{}\033[0m'.format('ChatGLM:'), end='')
+
+    for chatbot, history in predict(input=msg, chatbot=[], max_length=10000, top_p=0.5, temperature=0.5, history=history):
+        response_old = response_new
+        response_new = chatbot[0][1]
+        new_single = response_new.replace(response_old, '')
+        print(new_single, end='')
+
+
+
demo_single_chat.py
ADDED
@@ -0,0 +1,52 @@
+from transformers import AutoTokenizer, AutoModel
+tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+kernel_file = ".\\models\\chatglm-6b-int4\\quantization_kernels.so"  # same kernel path as in 1.py above
+model = model.quantize(bits=4, kernel_file=kernel_file)
+model = model.eval()
+
+def parse_text(text):
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+    text = "".join(lines)
+    return text
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+response_new = ''
+history = []
+for chatbot, history in predict('请写一篇1000字的散文', chatbot=[], max_length=10000, top_p=0.5, temperature=0.5, history=history):
+    response_old = response_new
+    response_new = chatbot[0][1]
+    new_single = response_new.replace(response_old, '')
+    print(new_single, end='')
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+protobuf
+transformers==4.27.1
+cpm_kernels
+torch>=1.10
+gradio
+mdtex2html
+sentencepiece
+accelerate
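The pinned dependencies above are installed the usual way:

pip install -r requirements.txt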
utils.py
ADDED
@@ -0,0 +1,54 @@
+import os
+from typing import Dict, Tuple, Union, Optional
+
+from torch.nn import Module
+from transformers import AutoModel
+
+
+def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
+    # transformer.word_embeddings takes up 1 layer
+    # transformer.final_layernorm and lm_head take up 1 layer
+    # transformer.layers take up 28 layers
+    # 30 layers in total, distributed across num_gpus cards
+    num_trans_layers = 28
+    per_gpu_layers = 30 / num_gpus
+
+    # bugfix: on Linux, the weight and input passed to torch.embedding may sit on different devices, causing a RuntimeError
+    # on Windows, model.device is set to transformer.word_embeddings.device
+    # on Linux, model.device is set to lm_head.device
+    # when chat or stream_chat is called, input_ids is placed on model.device
+    # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised
+    # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all kept on the first card
+    device_map = {'transformer.word_embeddings': 0,
+                  'transformer.final_layernorm': 0, 'lm_head': 0}
+
+    used = 2
+    gpu_target = 0
+    for i in range(num_trans_layers):
+        if used >= per_gpu_layers:
+            gpu_target += 1
+            used = 0
+        assert gpu_target < num_gpus
+        device_map[f'transformer.layers.{i}'] = gpu_target
+        used += 1
+
+    return device_map
+
+
+def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
+                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
+    if num_gpus < 2 and device_map is None:
+        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
+    else:
+        from accelerate import dispatch_model
+
+        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
+
+        if device_map is None:
+            device_map = auto_configure_device_map(num_gpus)
+
+        model = dispatch_model(model, device_map=device_map)
+
+    return model
+
+
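A small usage sketch for the helpers in utils.py (the checkpoint name is illustrative). With `num_gpus=2`, `auto_configure_device_map` keeps `transformer.word_embeddings`, `transformer.final_layernorm` and `lm_head` on GPU 0 together with transformer layers 0-12, and places layers 13-27 on GPU 1:

from utils import auto_configure_device_map, load_model_on_gpus

print(auto_configure_device_map(2))  # shows the 2-GPU split described above
# model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)  # illustrative checkpoint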
web_demo.py
ADDED
@@ -0,0 +1,104 @@
+from transformers import AutoModel, AutoTokenizer
+import gradio as gr
+import mdtex2html
+
+# tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
+model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
+
+model = model.eval()
+
+"""Override Chatbot.postprocess"""
+
+
+def postprocess(self, y):
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            None if message is None else mdtex2html.convert((message)),
+            None if response is None else mdtex2html.convert(response),
+        )
+    return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text):
+    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+    text = "".join(lines)
+    return text
+
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        chatbot[-1] = (parse_text(input), parse_text(response))
+
+        yield chatbot, history
+
+
+def reset_user_input():
+    return gr.update(value='')
+
+
+def reset_state():
+    return [], []
+
+
+with gr.Blocks() as demo:
+    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
+
+    chatbot = gr.Chatbot()
+    with gr.Row():
+        with gr.Column(scale=4):
+            with gr.Column(scale=12):
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                    container=False)
+            with gr.Column(min_width=32, scale=1):
+                submitBtn = gr.Button("Submit", variant="primary")
+        with gr.Column(scale=1):
+            emptyBtn = gr.Button("Clear History")
+            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+
+    history = gr.State([])
+
+    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+                    show_progress=True)
+    submitBtn.click(reset_user_input, [], [user_input])
+
+    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+demo.queue().launch(share=False, inbrowser=True)
web_demo2.py
ADDED
@@ -0,0 +1,69 @@
+from transformers import AutoModel, AutoTokenizer
+import streamlit as st
+from streamlit_chat import message
+
+
+st.set_page_config(
+    page_title="ChatGLM-6b 演示",
+    page_icon=":robot:"
+)
+
+
+@st.cache_resource
+def get_model():
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+    model = model.eval()
+    return tokenizer, model
+
+
+MAX_TURNS = 20
+MAX_BOXES = MAX_TURNS * 2
+
+
+def predict(input, max_length, top_p, temperature, history=None):
+    tokenizer, model = get_model()
+    if history is None:
+        history = []
+
+    with container:
+        if len(history) > 0:
+            for i, (query, response) in enumerate(history):
+                message(query, avatar_style="big-smile", key=str(i) + "_user")
+                message(response, avatar_style="bottts", key=str(i))
+
+        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
+        st.write("AI正在回复:")
+        with st.empty():
+            for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                                        temperature=temperature):
+                query, response = history[-1]
+                st.write(response)
+
+    return history
+
+
+container = st.container()
+
+# create a prompt text for the text generation
+prompt_text = st.text_area(label="用户命令输入",
+                           height=100,
+                           placeholder="请在这儿输入您的命令")
+
+max_length = st.sidebar.slider(
+    'max_length', 0, 4096, 2048, step=1
+)
+top_p = st.sidebar.slider(
+    'top_p', 0.0, 1.0, 0.6, step=0.01
+)
+temperature = st.sidebar.slider(
+    'temperature', 0.0, 1.0, 0.95, step=0.01
+)
+
+if 'state' not in st.session_state:
+    st.session_state['state'] = []
+
+if st.button("发送", key="predict"):
+    with st.spinner("AI正在思考,请稍等........"):
+        # text generation
+        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
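web_demo2.py is a Streamlit app rather than a Gradio one, so it is started with the Streamlit CLI:

streamlit run web_demo2.py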
web_demo_old.py
ADDED
@@ -0,0 +1,45 @@
+from transformers import AutoModel, AutoTokenizer
+import gradio as gr
+
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+model = model.eval()
+
+MAX_TURNS = 20
+MAX_BOXES = MAX_TURNS * 2
+
+
+def predict(input, max_length, top_p, temperature, history=None):
+    if history is None:
+        history = []
+    for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
+                                               temperature=temperature):
+        updates = []
+        for query, response in history:
+            updates.append(gr.update(visible=True, value="用户:" + query))
+            updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
+        if len(updates) < MAX_BOXES:
+            updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
+        yield [history] + updates
+
+
+with gr.Blocks() as demo:
+    state = gr.State([])
+    text_boxes = []
+    for i in range(MAX_BOXES):
+        if i % 2 == 0:
+            text_boxes.append(gr.Markdown(visible=False, label="提问:"))
+        else:
+            text_boxes.append(gr.Markdown(visible=False, label="回复:"))
+
+    with gr.Row():
+        with gr.Column(scale=4):
+            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", lines=11).style(
+                container=False)
+        with gr.Column(scale=1):
+            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+            button = gr.Button("Generate")
+    button.click(predict, [txt, max_length, top_p, temperature, state], [state] + text_boxes)
+demo.queue().launch(share=False, inbrowser=True)