qingxu99 committed
Commit 9d3b01a
1 Parent(s): 61ad51c

Attempt to add a local jittor model

.gitignore CHANGED
@@ -146,3 +146,4 @@ debug*
 private*
 crazy_functions/test_project/pdf_and_word
 crazy_functions/test_samples
+request_llm/jittorllms
request_llm/bridge_jittorllms.py ADDED
@@ -0,0 +1,153 @@

from transformers import AutoModel, AutoTokenizer
import time
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe

load_message = "jittorllms尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,jittorllms消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"

#################################################################################
class GetGLMHandle(Process):
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.jittorllms_model = None
        self.info = ""
        self.success = True
        self.check_dependency()
        self.start()
        self.threadLock = threading.Lock()

    def check_dependency(self):
        try:
            import jittor
            from .jittorllms.models import get_model
            self.info = "依赖检测通过"
            self.success = True
        except:
            self.info = r"缺少jittorllms的依赖,如果要使用jittorllms,除了基础的pip依赖以外,您还需要运行`pip install -r request_llm/requirements_jittorllms.txt`"+\
                        r"和`git clone https://gitlink.org.cn/jittor/JittorLLMs.git --depth 1 request_llm/jittorllms`两个指令来安装jittorllms的依赖(在项目根目录运行这两个指令)。"
            self.success = False

    def ready(self):
        return self.jittorllms_model is not None

    def run(self):
        # executed in the child process
        # on the first run, load the model parameters
        def load_model():
            import types
            try:
                if self.jittorllms_model is None:
                    device, = get_conf('LOCAL_MODEL_DEVICE')
                    from .jittorllms.models import get_model
                    # available_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
                    args_dict = {'model': 'chatglm', 'RUN_DEVICE': 'cpu'}
                    self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
            except:
                self.child.send('[Local Message] Call jittorllms fail 不能正常加载jittorllms的参数。')
                raise RuntimeError("不能正常加载jittorllms的参数!")

        load_model()

        # enter the task-waiting loop
        while True:
            # wait for the next request
            kwargs = self.child.recv()
            # request received, start generating
            try:
                for response, history in self.jittorllms_model.run_web_demo(kwargs['query'], kwargs['history']):
                    self.child.send(response)
            except:
                self.child.send('[Local Message] Call jittorllms fail.')
            # request finished, start the next loop
            self.child.send('[Finish]')

    def stream_chat(self, **kwargs):
        # executed in the main process
        self.threadLock.acquire()
        self.parent.send(kwargs)
        while True:
            res = self.parent.recv()
            if res != '[Finish]':
                yield res
            else:
                break
        self.threadLock.release()

global glm_handle
glm_handle = None
#################################################################################
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
    """
    Multi-threaded entry point.
    See request_llm/bridge_all.py for the function's documentation.
    """
    global glm_handle
    if glm_handle is None:
        glm_handle = GetGLMHandle()
        if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + glm_handle.info
        if not glm_handle.success:
            error = glm_handle.info
            glm_handle = None
            raise RuntimeError(error)

    # jittorllms has no sys_prompt interface, so the prompt is folded into the history
    history_feedin = []
    history_feedin.append(["What can I do?", sys_prompt])
    for i in range(len(history)//2):
        history_feedin.append([history[2*i], history[2*i+1]])

    watch_dog_patience = 5  # watchdog patience; 5 seconds is enough
    response = ""
    for response in glm_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
        if len(observe_window) >= 1: observe_window[0] = response
        if len(observe_window) >= 2:
            if (time.time()-observe_window[1]) > watch_dog_patience:
                raise RuntimeError("程序终止。")
    return response



def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    """
    Single-threaded entry point.
    See request_llm/bridge_all.py for the function's documentation.
    """
    chatbot.append((inputs, ""))

    global glm_handle
    if glm_handle is None:
        glm_handle = GetGLMHandle()
        chatbot[-1] = (inputs, load_message + "\n\n" + glm_handle.info)
        yield from update_ui(chatbot=chatbot, history=[])
        if not glm_handle.success:
            glm_handle = None
            return

    if additional_fn is not None:
        import core_functional
        importlib.reload(core_functional)    # hot-reload the prompts
        core_functional = core_functional.get_core_functions()
        if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs)  # apply the pre-processing function (if any)
        inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]

    # assemble the conversation history
    history_feedin = []
    history_feedin.append(["What can I do?", system_prompt])
    for i in range(len(history)//2):
        history_feedin.append([history[2*i], history[2*i+1]])

    # start receiving the jittorllms reply
    response = "[Local Message]: 等待jittorllms响应中 ..."
    for response in glm_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
        chatbot[-1] = (inputs, response)
        yield from update_ui(chatbot=chatbot, history=history)

    # finalize the output
    if response == "[Local Message]: 等待jittorllms响应中 ...":
        response = "[Local Message]: jittorllms响应异常 ..."
    history.extend([inputs, response])
    yield from update_ui(chatbot=chatbot, history=history)
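
The heart of this bridge is a small inter-process protocol: GetGLMHandle runs jittorllms in a daemon subprocess, stream_chat sends a request dict across a Pipe from the main process, the child streams partial responses back one at a time, and a '[Finish]' sentinel marks the end of each request (stream_chat also holds a threading.Lock so concurrent callers cannot interleave messages on the single Pipe). Below is a minimal, runnable sketch of that pattern; the EchoHandle class and its word-by-word "streaming" are illustrative stand-ins, not part of this commit.

from multiprocessing import Process, Pipe

class EchoHandle(Process):
    # Illustrative stand-in for GetGLMHandle: same Pipe + '[Finish]' sentinel protocol,
    # but the "model" simply echoes the query back word by word.
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.start()

    def run(self):
        # child process: wait for a request, stream pieces back, then send the sentinel
        while True:
            query = self.child.recv()
            for token in query.split():
                self.child.send(token)
            self.child.send('[Finish]')

    def stream_chat(self, query):
        # main process: forward the request, yield partial results until '[Finish]'
        self.parent.send(query)
        while True:
            res = self.parent.recv()
            if res == '[Finish]':
                break
            yield res

if __name__ == '__main__':
    handle = EchoHandle()
    for piece in handle.stream_chat("hello from the main process"):
        print(piece)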
request_llm/requirements_jittorllms.txt ADDED
@@ -0,0 +1,4 @@
jittor >= 1.3.7.9
jtorch >= 0.1.3
torch
torchvision
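
The pip requirements above are only half of the setup. As the error message in check_dependency spells out, the JittorLLMs repository itself also has to be cloned into request_llm/jittorllms (the path just added to .gitignore), with both commands run from the project root:

pip install -r request_llm/requirements_jittorllms.txt
git clone https://gitlink.org.cn/jittor/JittorLLMs.git --depth 1 request_llm/jittorllms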
request_llm/test_llms.py ADDED
@@ -0,0 +1,26 @@
"""
Unit tests for the individual llm model bridges
"""
def validate_path():
    import os, sys
    dir_name = os.path.dirname(__file__)
    root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
    os.chdir(root_dir_assume)
    sys.path.append(root_dir_assume)

validate_path()  # validate path so you can run from base directory

from request_llm.bridge_jittorllms import predict_no_ui_long_connection

llm_kwargs = {
    'max_length': 512,
    'top_p': 1,
    'temperature': 1,
}

result = predict_no_ui_long_connection(inputs="你好",
                                       llm_kwargs=llm_kwargs,
                                       history=[],
                                       sys_prompt="")

print(result)
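
predict_no_ui_long_connection also accepts an observe_window list: index 0 is overwritten with the latest partial response on every streamed chunk, and if a second element holding a keep-alive timestamp is present, the 5-second watchdog aborts the call once the caller stops refreshing it. A hedged variation of the test above that captures the streamed text (a single-element window, so the watchdog stays disabled):

observe_window = [""]  # only the partial-response slot; no watchdog timestamp
result = predict_no_ui_long_connection(inputs="你好",
                                       llm_kwargs=llm_kwargs,
                                       history=[],
                                       sys_prompt="",
                                       observe_window=observe_window)
print(observe_window[0])  # mirrors the last streamed response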