Upload folder using huggingface_hub
Browse files- .gitignore +6 -0
- README.md +7 -0
- app.py +100 -39
- main.py +95 -0
- requirements.txt +4 -0
- run_fastapi.sh +5 -0
- run_gradio.sh +3 -3
- src/ChatWorld.py +157 -0
- src/DataBase/BaseDB.py +61 -0
- src/DataBase/ChromaDB.py +49 -0
- src/DataBase/__init__.py +3 -0
- src/Models/__init__.py +3 -0
- src/Models/models.py +63 -0
- src/Response.py +12 -0
- src/__init__.py +3 -0
- src/logging.py +16 -0
- src/user.py +23 -0
- src/utils.py +27 -0
.gitignore
CHANGED
@@ -158,3 +158,9 @@ cython_debug/
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
.idea/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
.idea/
|
161 |
+
.vscode
|
162 |
+
data/
|
163 |
+
uploads/
|
164 |
+
|
165 |
+
clash*
|
166 |
+
*.yml
|
README.md
CHANGED
@@ -4,3 +4,10 @@ app_file: app.py
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.50.2
|
6 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.50.2
|
6 |
---
|
7 |
+
|
8 |
+
python 3.9.18
|
9 |
+
|
10 |
+
# TODO
|
11 |
+
- 对话流传输
|
12 |
+
- 持久化
|
13 |
+
- 多轮对话 历史
|
app.py
CHANGED
@@ -1,57 +1,87 @@
|
|
1 |
import logging
|
2 |
-
import os
|
3 |
|
4 |
import gradio as gr
|
5 |
|
6 |
-
from
|
7 |
-
|
8 |
-
logging.basicConfig(level=logging.INFO, filename="demo.log", filemode="w",
|
9 |
-
format="%(asctime)s - %(name)s - %(levelname)-9s - %(filename)-8s : %(lineno)s line - %(message)s",
|
10 |
-
datefmt="%Y-%m-%d %H:%M:%S")
|
11 |
|
12 |
chatWorld = ChatWorld()
|
13 |
|
14 |
role_name_list_global = None
|
|
|
15 |
|
16 |
|
17 |
def getContent(input_file):
|
18 |
# 读取文件内容
|
19 |
-
with open(input_file.name,
|
20 |
logging.info(f"read file {input_file.name}")
|
21 |
input_text = f.read()
|
22 |
logging.info(f"file content: {input_text}")
|
23 |
|
24 |
-
|
25 |
-
input_text_list = input_text.split("\n")
|
26 |
-
chatWorld.initDB(input_text_list)
|
27 |
-
role_name_set = set()
|
28 |
-
|
29 |
-
# 读取角色名
|
30 |
-
for line in input_text_list:
|
31 |
-
role_name_set.add(line.split(":")[0])
|
32 |
|
33 |
-
|
34 |
-
|
35 |
|
36 |
global role_name_list_global
|
37 |
role_name_list_global = role_name_list
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
return response
|
48 |
|
49 |
|
50 |
-
def submit_message_api(
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return response
|
56 |
|
57 |
|
@@ -63,8 +93,13 @@ def get_role_list():
|
|
63 |
return []
|
64 |
|
65 |
|
66 |
-
|
|
|
|
|
|
|
67 |
|
|
|
|
|
68 |
upload_c = gr.File(label="上传文档文件")
|
69 |
|
70 |
with gr.Row():
|
@@ -75,15 +110,41 @@ with gr.Blocks() as demo:
|
|
75 |
role_name = gr.Radio(get_role_list(), label="角色名")
|
76 |
role_nickname = gr.Textbox(label="角色昵称")
|
77 |
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
with gr.Row():
|
82 |
chatBox_local = gr.ChatInterface(
|
83 |
-
submit_message,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
chatBox_api = gr.ChatInterface(
|
86 |
-
submit_message_api,
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
|
5 |
+
from src import ChatWorld
|
|
|
|
|
|
|
|
|
6 |
|
7 |
chatWorld = ChatWorld()
|
8 |
|
9 |
role_name_list_global = None
|
10 |
+
role_name_dict_global = None
|
11 |
|
12 |
|
13 |
def getContent(input_file):
    """Load an uploaded story file and refresh both role selectors.

    Reads the file at ``input_file.name`` as UTF-8, stores its text in the
    shared ``chatWorld`` instance, asks ``chatWorld`` for the role names found
    in the text, and caches the result in the module-level globals.

    Args:
        input_file: Object handed over by the upload component; only its
            ``name`` attribute (used as the path to open) is read.

    Returns:
        A tuple of two ``gr.Radio`` updates (model role, user role). The first
        preselects the first extracted role and the second the last one; both
        are left unselected when no role could be extracted.
    """
    global role_name_list_global, role_name_dict_global

    # Read the file content.
    with open(input_file.name, "r", encoding="utf-8") as f:
        logging.info(f"read file {input_file.name}")
        input_text = f.read()
        logging.info(f"file content: {input_text}")

    chatWorld.setStory(stories=input_text, metas=None)

    # Extract the role names and remember them for the other callbacks.
    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(input_text)
    role_name_list_global = role_name_list
    role_name_dict_global = role_name_dict

    # Role extraction yields an empty list when the model reply cannot be
    # parsed; indexing it unconditionally would raise IndexError.
    first_role = role_name_list[0] if role_name_list else None
    last_role = role_name_list[-1] if role_name_list else None

    return (
        gr.Radio(choices=role_name_list, interactive=True, value=first_role),
        gr.Radio(choices=role_name_list, interactive=True, value=last_role),
    )
|
34 |
+
|
35 |
+
|
36 |
+
def _submit_to_chat_world(
    message,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
    use_local_model,
):
    """Route one chat turn to ``chatWorld``, with or without role-play."""
    if withCharacter:
        return chatWorld.chatWithCharacter(
            text=message,
            role_name=role_name,
            role_nickname=role_nickname,
            model_role_name=model_role_name,
            model_role_nickname=model_role_nickname,
            use_local_model=use_local_model,
        )
    return chatWorld.chatWithoutCharacter(
        text=message,
        use_local_model=use_local_model,
    )


def submit_message(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """Chat callback that answers with the local model.

    Args:
        message: Text typed by the user.
        history: Chat history supplied by the chat interface; not used.
        model_role_name: Name of the role the model plays.
        role_name: Name of the role the user plays.
        model_role_nickname: Nickname of the model's role.
        role_nickname: Nickname of the user's role.
        withCharacter: When truthy, answer in character; otherwise forward
            the plain message.

    Returns:
        The response produced by ``chatWorld``.
    """
    return _submit_to_chat_world(
        message,
        model_role_name,
        role_name,
        model_role_nickname,
        role_nickname,
        withCharacter,
        use_local_model=True,
    )


def submit_message_api(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """Chat callback that answers through the remote API model.

    Takes the same arguments as ``submit_message``; the only difference is
    that ``chatWorld`` is called with ``use_local_model=False``.

    Returns:
        The response produced by ``chatWorld``.
    """
    return _submit_to_chat_world(
        message,
        model_role_name,
        role_name,
        model_role_nickname,
        role_nickname,
        withCharacter,
        use_local_model=False,
    )
|
86 |
|
87 |
|
|
|
93 |
return []
|
94 |
|
95 |
|
96 |
+
def change_role_list(name):
    """Return the nickname cached for role *name*.

    Args:
        name: Role name currently selected in a role radio.

    Returns:
        The nickname stored in ``role_name_dict_global`` for *name*, or an
        empty string when no role mapping has been loaded yet or the name is
        not in it.
    """
    # The mapping starts out as None until a file has been processed, so
    # subscripting it directly would raise TypeError.
    if not role_name_dict_global:
        return ""
    return role_name_dict_global.get(name, "")
|
100 |
|
101 |
+
|
102 |
+
with gr.Blocks() as demo:
|
103 |
upload_c = gr.File(label="上传文档文件")
|
104 |
|
105 |
with gr.Row():
|
|
|
110 |
role_name = gr.Radio(get_role_list(), label="角色名")
|
111 |
role_nickname = gr.Textbox(label="角色昵称")
|
112 |
|
113 |
+
model_role_name.change(
|
114 |
+
fn=change_role_list, inputs=[model_role_name], outputs=[model_role_nickname]
|
115 |
+
)
|
116 |
+
role_name.change(fn=change_role_list, inputs=[role_name], outputs=[role_nickname])
|
117 |
+
|
118 |
+
upload_c.upload(
|
119 |
+
fn=getContent, inputs=upload_c, outputs=[model_role_name, role_name]
|
120 |
+
)
|
121 |
+
|
122 |
+
withCharacter = gr.Radio([True, False], value=True, label="是否进行角色扮演")
|
123 |
|
124 |
with gr.Row():
|
125 |
chatBox_local = gr.ChatInterface(
|
126 |
+
submit_message,
|
127 |
+
chatbot=gr.Chatbot(height=400, label="本地模型", render=False),
|
128 |
+
additional_inputs=[
|
129 |
+
model_role_name,
|
130 |
+
role_name,
|
131 |
+
model_role_nickname,
|
132 |
+
role_nickname,
|
133 |
+
withCharacter,
|
134 |
+
],
|
135 |
+
)
|
136 |
|
137 |
chatBox_api = gr.ChatInterface(
|
138 |
+
submit_message_api,
|
139 |
+
chatbot=gr.Chatbot(height=400, label="API模型", render=False),
|
140 |
+
additional_inputs=[
|
141 |
+
model_role_name,
|
142 |
+
role_name,
|
143 |
+
model_role_nickname,
|
144 |
+
role_nickname,
|
145 |
+
withCharacter,
|
146 |
+
],
|
147 |
+
)
|
148 |
+
|
149 |
+
|
150 |
+
demo.launch(share=True, server_name="0.0.0.0")
|
main.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from fastapi import Body, FastAPI, File, Form, HTTPException, Response, UploadFile
|
5 |
+
from fastapi.responses import JSONResponse
|
6 |
+
import uvicorn
|
7 |
+
|
8 |
+
from src import ChatWorld
|
9 |
+
from src.Response import ChatResponse, FileResponse
|
10 |
+
from src.logging import logging_info
|
11 |
+
from src.user import UUID, Role, User
|
12 |
+
from src.utils import convertToUTF8
|
13 |
+
|
14 |
+
app = FastAPI()
|
15 |
+
chatWorld = ChatWorld()
|
16 |
+
|
17 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
SAVE_DIR = "uploads"
|
19 |
+
|
20 |
+
user_info: dict[UUID, User] = dict()
|
21 |
+
|
22 |
+
|
23 |
+
@app.post("/uploadFile", response_model=FileResponse)
def upload_file(uuid: str = Form(), file: UploadFile = File(...)):
    """Store an uploaded story file and index it for the given user.

    The raw bytes are written under ``BASE_DIR/SAVE_DIR``, decoded to text,
    stored in ``chatWorld`` tagged with the user's uuid, and a fresh ``User``
    record is registered for that uuid.

    Args:
        uuid: Client-supplied user identifier (form field).
        file: The uploaded file.

    Returns:
        A ``FileResponse`` with the stored file name, the extracted role
        names and nicknames, and the MD5 hex digest of the uploaded bytes.
    """
    save_dir = os.path.join(BASE_DIR, SAVE_DIR)
    # exist_ok avoids the check-then-create race between concurrent uploads.
    os.makedirs(save_dir, exist_ok=True)

    # Both values come straight from the client: strip directory components
    # and path separators so the file always lands directly inside save_dir.
    safe_uuid = uuid.replace("/", "_").replace("\\", "_")
    safe_name = os.path.basename((file.filename or "").replace("\\", "/")) or "upload"

    file_name = f"{time.time_ns()}_{safe_uuid}_{safe_name}"
    file_path = os.path.join(save_dir, file_name)

    file_content = file.file.read()

    with open(file_path, "wb") as f:
        f.write(file_content)

    file_content_utf8 = convertToUTF8(file_content)

    chatWorld.setStory(
        stories=file_content_utf8,
        metas={
            "uuid": uuid,
        },
    )

    user_info[uuid] = User()
    logging_info(f"user_info: {user_info}")

    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(file_content_utf8)
    return FileResponse(
        filename=file_name,
        role_name_list=role_name_list,
        role_name_dict=role_name_dict,
        # MD5 is used here only as a content checksum for the client.
        md5=hashlib.md5(file_content).hexdigest(),
    )
|
56 |
+
|
57 |
+
|
58 |
+
@app.post("/chatWithCharacter", response_model=ChatResponse)
def chatWithCharacter(
    uuid: str = Body(...),
    text: str = Body(...),
    use_local_model: bool = Body(False),
    top_k: int = Body(5),
    role_info: Role = Body(...),
):
    """Answer one in-character chat turn for a registered user.

    Args:
        uuid: Identifier of a user present in ``user_info``.
        text: The user's message.
        use_local_model: Forwarded to ``chatWorld.chatWithCharacter``.
        top_k: Forwarded to ``chatWorld.chatWithCharacter``.
        role_info: Role names and nicknames for both sides of the chat.

    Returns:
        A ``ChatResponse`` wrapping the reply.

    Raises:
        HTTPException: With status 400 when *uuid* is not in ``user_info``.
    """
    user = user_info.get(uuid)
    if not user:
        raise HTTPException(status_code=400, detail="User not found")

    role_fields = role_info.model_dump()

    # Remember the roles this user is currently chatting with.
    user_info[uuid] = user.update(role_fields)
    logging_info(f"user_info: {user_info}")

    reply = chatWorld.chatWithCharacter(
        text=text,
        use_local_model=use_local_model,
        top_k=top_k,
        metas={"uuid": uuid},
        **role_fields,
    )
    return ChatResponse(response=reply)
|
83 |
+
|
84 |
+
|
85 |
+
# @app.post("/chatWithoutCharacter")
|
86 |
+
# def chatWithoutCharacter(
|
87 |
+
# uuid: str = Body(...),
|
88 |
+
# text: str = Body(...),
|
89 |
+
# use_local_model: bool = Body(...),
|
90 |
+
# ):
|
91 |
+
# pass
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
uvicorn.run("main:app", host="0.0.0.0", port=8000)
|
requirements.txt
CHANGED
@@ -5,3 +5,7 @@ transformers==4.38.1
|
|
5 |
accelerate
|
6 |
zhipuai
|
7 |
sentencepiece
|
|
|
|
|
|
|
|
|
|
5 |
accelerate
|
6 |
zhipuai
|
7 |
sentencepiece
|
8 |
+
tiktoken
|
9 |
+
sentence-transformers
|
10 |
+
langchain
|
11 |
+
chromadb
|
run_fastapi.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export CUDA_VISIBLE_DEVICES=0
|
2 |
+
export HF_ENDPOINT="https://hf-mirror.com"
|
3 |
+
|
4 |
+
# Start the FastAPI server
|
5 |
+
python main.py
|
run_gradio.sh
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
export CUDA_VISIBLE_DEVICES=
|
2 |
-
export
|
3 |
|
4 |
# Start the gradio server
|
5 |
-
|
|
|
1 |
+
export CUDA_VISIBLE_DEVICES=3
|
2 |
+
export HF_ENDPOINT="https://hf-mirror.com"
|
3 |
|
4 |
# Start the gradio server
|
5 |
+
python app.py
|
src/ChatWorld.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from jinja2 import Template
|
3 |
+
from .DataBase import ChromaDB
|
4 |
+
|
5 |
+
from .Models import GLM, GLM_api
|
6 |
+
|
7 |
+
from .utils import *
|
8 |
+
|
9 |
+
|
10 |
+
class ChatWorld:
    """Role-play chat facade over a story vector store and two GLM back ends.

    ``self.client`` is the API-backed model, ``self.model`` the optional local
    model, and ``self.db`` the vector store that holds the story chunks used
    to build the system prompt.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        """Create the API client, the optional local model and the store.

        Args:
            pretrained_model_name_or_path: Name or path of the local model.
            embedding_model_name_or_path: Name or path of the embedding model
                handed to the vector store.
            global_batch_size: Stored on the instance; not used in this class.
            model_load: When false, the local model is not loaded and
                ``self.model`` is ``None``.
        """
        self.model_name = pretrained_model_name_or_path

        self.global_batch_size = global_batch_size

        self.client = GLM_api()

        # Pass the requested name through (it used to be ignored) and always
        # define the attribute so a missing local model can be detected.
        self.model = GLM(pretrained_model_name_or_path) if model_load else None

        self.db = ChromaDB(embedding_model_name_or_path)
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def setStory(self, **stories_kargs):
        """Replace the stored story for the given metadata.

        Stories matching ``metas`` are deleted first, then the new stories
        are added. All keyword arguments are forwarded to
        ``self.db.addStories``; ``metas`` may be omitted.
        """
        # .get: callers that pass no metas used to hit a KeyError here.
        self.db.deleteStoriesByMeta(metas=stories_kargs.get("metas"))
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Build the system message from the prompt template.

        The ``top_k`` story chunks most similar to *text* (filtered by
        *metas*) are rendered into the template together with *role_info*.
        """
        rag = self.db.searchBySim(text, top_k, metas)

        return {
            "role": "system",
            "content": self.prompt.render(
                **role_info,
                RAG=rag,
            ),
        }

    def _local_model(self):
        """Return the local model, or fail clearly when it was not loaded."""
        if self.model is None:
            raise RuntimeError("local model is not loaded (model_load=False)")
        return self.model

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Answer *text* in character.

        Args:
            text: The user's message.
            system_prompt: Ready-made system message; built from the prompt
                template and the vector store when omitted.
            use_local_model: Use the local model instead of the API client.
            top_k: Number of story chunks to retrieve for the prompt.
            metas: Metadata filter for the retrieval.
            **role_info: Template variables; ``role_name`` is required.

        Returns:
            The model's reply.

        Raises:
            ValueError: If ``role_name`` is missing or empty.
            RuntimeError: If the local model is requested but was not loaded.
        """
        # Validate first so a bad request does not pay for a vector search.
        user_role_name = role_info.get("role_name")
        if not user_role_name:
            raise ValueError("role_name is required")

        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, **role_info, top_k=top_k, metas=metas
            )

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]

        logging_info(f"message: {message}")

        if use_local_model:
            return self._local_model().get_response(message)
        return self.client.chat(message)

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
    ):
        """Answer *text* without any role-play framing.

        Args:
            text: The user's message.
            system_prompt: Optional system message placed before the user
                message (it used to be accepted but ignored).
            use_local_model: Use the local model instead of the API client.

        Returns:
            The model's reply.

        Raises:
            RuntimeError: If the local model is requested but was not loaded.
        """
        logging_info(f"text: {text}")

        message = [
            {"role": "user", "content": f"{text}"},
        ]
        if system_prompt:
            message.insert(0, system_prompt)

        if use_local_model:
            # Without a system prompt keep the plain-string call unchanged.
            query = message if system_prompt else text
            return self._local_model().get_response(query)
        return self.client.chat(message)

    def getRoleNameFromFile(self, input_file: str):
        """Ask the API model for the people mentioned in a story text.

        Args:
            input_file: The story text itself (not a path).

        Returns:
            ``(role_name_list, role_name_dict)``: the distinct names in first-
            seen order, and a mapping from name to nickname (``""`` when the
            model gave none). Both are empty when the reply cannot be parsed.
        """
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )

        response = self.chatWithoutCharacter(prompt, use_local_model=False)

        role_name_list, role_name_dict = self._parse_roles(response)
        if not role_name_list:
            logging_info(f"no roles extracted from model response: {response}")

        return role_name_list, role_name_dict

    @staticmethod
    def _parse_roles(response):
        """Parse a model reply into ``(names, {name: nickname})``.

        Accepts either a fenced ```json block or bare JSON. Entries without a
        usable string ``name`` are skipped; anything unparsable yields empty
        results instead of raising.
        """
        if not isinstance(response, str):
            return [], {}

        fence = "```json"
        start = response.find(fence)
        if start == -1:
            # No fence at all: the reply may still be bare JSON.
            payload = response
        else:
            start += len(fence)
            end = response.find("```", start)
            payload = response[start:] if end == -1 else response[start:end]

        try:
            entries = json.loads(payload)
        except json.JSONDecodeError:
            return [], {}

        if not isinstance(entries, list):
            return [], {}

        role_name_dict = {}
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = entry.get("name")
            if not isinstance(name, str) or not name:
                continue
            nickname = entry.get("nickname")
            role_name_dict[name] = nickname if isinstance(nickname, str) else ""

        # dict keys keep insertion order and drop duplicate names.
        return list(role_name_dict), role_name_dict
|
src/DataBase/BaseDB.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABCMeta, abstractmethod
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
from langchain.text_splitter import TokenTextSplitter
|
7 |
+
from langchain_core.documents import Document
|
8 |
+
|
9 |
+
|
10 |
+
class BaseDB(metaclass=ABCMeta):
    """Abstract base for story stores backed by an embedding model.

    The constructor loads the embedding model and its tokenizer, then calls
    ``init_db`` so the subclass can create its client.
    """

    def __init__(self, embedding_name: str = None, persist_dir=None) -> None:
        """Load the embedding model and tokenizer, then initialise the store.

        Args:
            embedding_name: Embedding model name; defaults to
                ``"BAAI/bge-small-zh-v1.5"`` when empty.
            persist_dir: Directory for the store; defaults to ``"data"``.
        """
        super().__init__()

        self.client = None
        self.persist_dir = persist_dir or "data"

        embedding_name = embedding_name or "BAAI/bge-small-zh-v1.5"

        self.embedding = HuggingFaceEmbeddings(model_name=embedding_name)
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)

        self.init_db()

    @abstractmethod
    def init_db(self):
        """Create the backing client and assign it to ``self.client``."""
        pass

    def text_splitter(
        self, text: Union[str, Document], chunk_size=300, chunk_overlap=10
    ):
        """Split *text* into token-bounded chunks.

        Args:
            text: A plain string or a single ``Document``.
            chunk_size: Chunk size passed to the token splitter.
            chunk_overlap: Overlap passed to the token splitter.

        Returns:
            The result of ``split_text`` for a string, or of
            ``split_documents`` for a ``Document``.

        Raises:
            ValueError: If *text* is neither a ``str`` nor a ``Document``.
        """
        if not isinstance(text, (str, Document)):
            raise ValueError("text must be a str or Document")

        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

        if isinstance(text, Document):
            # split_documents expects an iterable of documents, not a single
            # Document, so wrap it.
            return splitter.split_documents([text])
        return splitter.split_text(text)

    @abstractmethod
    def addStories(self, stories, metas=None):
        """Add *stories* to the store, tagged with *metas*."""
        pass

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        """Delete the stored entries matching *metas*."""
        pass

    @abstractmethod
    def searchBySim(self, query, n_results=5, metas=None, only_return_document=True):
        """Return the *n_results* entries most similar to *query*."""
        pass

    @abstractmethod
    def searchByMeta(self, metas=None):
        """Return the stored entries matching *metas*."""
        pass
|
src/DataBase/ChromaDB.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from langchain_community.vectorstores.chroma import Chroma
|
3 |
+
|
4 |
+
from src.logging import logging_info
|
5 |
+
|
6 |
+
from .BaseDB import BaseDB
|
7 |
+
|
8 |
+
# TODO 数据库持久化 和 用户进入的加载。
|
9 |
+
|
10 |
+
|
11 |
+
class ChromaDB(BaseDB):
    """``BaseDB`` implementation backed by a LangChain ``Chroma`` store."""

    def __init__(self, embedding_name: str = None, persist_dir=None) -> None:
        """Forward both arguments to ``BaseDB.__init__``."""
        super().__init__(embedding_name, persist_dir)

    def init_db(self):
        """Create the ``Chroma`` client on ``self.persist_dir``."""
        self.client = Chroma(
            persist_directory=self.persist_dir, embedding_function=self.embedding
        )

    def addStories(self, stories: str, metas: dict = None):
        """Split *stories* into chunks and add them to the store.

        Args:
            stories: The story text.
            metas: Metadata attached to every chunk; omitted when empty.
        """
        # Split once (the text used to be tokenised twice).
        split_stories = self.text_splitter(stories)

        if not split_stories:
            # Nothing to index; indexing [-1] below would raise IndexError.
            logging_info("addStories: nothing to add (no chunks produced)")
            return

        logging_info(split_stories[-1])

        metadatas = [metas] * len(split_stories) if metas else None
        self.client.add_texts(texts=split_stories, metadatas=metadatas)

    def searchBySim(
        self, query, n_results=5, metas: dict = None, only_return_document=True
    ):
        """Return the chunks most similar to *query*.

        Args:
            query: Text to search for.
            n_results: Number of results requested.
            metas: Metadata filter passed to the similarity search.
            only_return_document: When true, return only each hit's
                ``page_content``; otherwise return the raw search results.
        """
        result = self.client.similarity_search(query, k=n_results, filter=metas)

        if only_return_document:
            return [i.page_content for i in result]

        return result

    def deleteStoriesByMeta(self, metas):
        """Delete every stored chunk matching *metas*.

        NOTE(review): ``metas=None`` is passed through as ``where=None``,
        which presumably matches (and therefore deletes) every entry in the
        store - confirm this is intended for callers that pass no metadata.
        """
        ids = self.searchByMeta(metas=metas)["ids"]
        if ids:
            self.client.delete(ids)

    def searchByMeta(self, metas=None, include: list[str] = None) -> dict:
        """Return the result of ``self.client.get`` filtered by *metas*.

        Args:
            metas: Passed as the ``where`` filter.
            include: Passed as ``include``.
        """
        return self.client.get(where=metas, include=include)
|
src/DataBase/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .ChromaDB import ChromaDB
|
2 |
+
|
3 |
+
__all__ = ['ChromaDB']
|
src/Models/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .models import GLM,GLM_api
|
2 |
+
|
3 |
+
__all__ = ["GLM", "GLM_api"]
|
src/Models/models.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from string import Template
|
3 |
+
from typing import Dict, List, Union
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
from zhipuai import ZhipuAI
|
6 |
+
|
7 |
+
|
8 |
+
class GLM:
    """Locally loaded chat model with its tokenizer."""

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        """Load tokenizer and model for *model_name* and switch to eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )

        self.client = client.eval()

    def message2query(self, messages) -> str:
        """Flatten chat messages into one ``<|role|>``-tagged string.

        Each message dict (keys ``role`` and ``content``) becomes
        ``"<|role|>\\ncontent\\n"``, e.g.
        ``[{'role': 'user', 'content': 'Hello'}]`` -> ``"<|user|>\\nHello\\n"``.
        """
        template = Template("<|$role|>\n$content\n")

        return "".join([template.substitute(message) for message in messages])

    def get_response(
        self,
        message: Union[str, list[dict[str, str]]],
        history: List[Dict[str, str]] = None,
    ):
        """Generate a reply with the local model.

        Args:
            message: Either a plain query string, or a non-empty list of
                message dicts whose last entry is the query and whose earlier
                entries are passed as history.
            history: History used with a plain-string *message*; ignored when
                *message* is a list.

        Returns:
            The reply text.

        Raises:
            TypeError: If *message* is neither a ``str`` nor a ``list``.
            ValueError: If *message* is an empty list.
        """
        if isinstance(message, str):
            response, _ = self.client.chat(self.tokenizer, message, history=history)
        elif isinstance(message, list):
            if not message:
                raise ValueError("message list must not be empty")
            response, _ = self.client.chat(
                self.tokenizer, message[-1]["content"], history=message[:-1]
            )
        else:
            # Previously fell through to an UnboundLocalError on `response`.
            raise TypeError(
                f"message must be a str or a list of dicts, got {type(message).__name__}"
            )

        print(response)
        return response
|
45 |
+
|
46 |
+
|
47 |
+
class GLM_api:
    """Thin wrapper around the ZhipuAI chat-completions client."""

    def __init__(self, model_name="glm-4"):
        """Create the client using the ``ZHIPU_API_KEY`` environment variable.

        Args:
            model_name: Model identifier sent with every request.
        """
        API_KEY = os.environ.get("ZHIPU_API_KEY")

        self.client = ZhipuAI(api_key=API_KEY)
        self.model = model_name

    def chat(self, message):
        """Send *message* to the API and return the reply text.

        This call is best-effort: any failure is printed and the fixed string
        ``"模型连接失败"`` is returned instead of raising.

        Args:
            message: Chat messages passed as ``messages`` to the API.

        Returns:
            The content of the first choice, or the fallback string on error.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model, messages=message
            )
            # Read the reply inside the guarded region so an empty or
            # malformed response takes the same fallback path as a failed call.
            return response.choices[0].message.content
        except Exception as e:
            print(e)
            return "模型连接失败"
|
src/Response.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class ChatResponse(BaseModel):
    """Response body carrying a single chat reply."""

    # The reply text.
    response: str
|
6 |
+
|
7 |
+
|
8 |
+
class FileResponse(BaseModel):
    """Response body describing an uploaded file and the roles found in it."""

    # Name of the file.
    filename: str
    # Role names; empty by default.
    role_name_list: list[str] = []
    # Role name -> nickname; empty by default.
    role_name_dict: dict[str, str] = {}
    # MD5 value for the file; None by default.
    md5: str = None
|
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .ChatWorld import ChatWorld
|
2 |
+
|
3 |
+
__all__ = ['ChatWorld']
|
src/logging.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
|
4 |
+
logging.basicConfig(
|
5 |
+
level=logging.DEBUG,
|
6 |
+
filename="demo.log",
|
7 |
+
filemode="w",
|
8 |
+
format="%(asctime)s - %(name)s - %(levelname)-9s - %(filename)-8s : %(lineno)s line - %(message)s",
|
9 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
10 |
+
)
|
11 |
+
|
12 |
+
# Path: src/logging.py
|
13 |
+
|
14 |
+
|
15 |
+
def logging_info(text: str):
|
16 |
+
logging.info(text)
|
src/user.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
UUID = str
|
4 |
+
|
5 |
+
|
6 |
+
class Role(BaseModel):
    """Names and nicknames of both sides of a role-play chat."""

    # Required role name.
    role_name: str
    # Optional nickname for role_name.
    # NOTE(review): annotated `str` with a None default - an explicit null
    # from a client may be rejected by validation; confirm whether
    # Optional[str] is intended.
    role_nickname: str = None
    # Required model role name.
    model_role_name: str
    # Optional nickname for model_role_name (same review note as above).
    model_role_nickname: str = None
|
11 |
+
|
12 |
+
|
13 |
+
class User:
    """Per-user chat state: history plus the currently selected roles."""

    # Declared here for type checkers; the list itself is created per
    # instance in __init__ (a class-level [] would be shared by every User).
    history: list[str]
    role_name: str = None
    role_nickname: str = None
    model_role_name: str = None
    model_role_nickname: str = None

    def __init__(self) -> None:
        """Create a user with an empty, private history."""
        self.history = []

    def update(self, new_properties: dict) -> "User":
        """Set each key of *new_properties* as an attribute on this user.

        Args:
            new_properties: Mapping of attribute name to new value.

        Returns:
            This same instance, so calls can be chained or re-stored.
        """
        for k, v in new_properties.items():
            setattr(self, k, v)
        return self
|
src/utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from jinja2 import Template
|
2 |
+
from transformers import AutoModel, AutoTokenizer
|
3 |
+
|
4 |
+
from .logging import logging_info
|
5 |
+
|
6 |
+
|
7 |
+
def initEmbedding(model_name="BAAI/bge-small-zh-v1.5", **model_wargs):
    """Load a model with ``AutoModel.from_pretrained``.

    Args:
        model_name: Model name or path.
        **model_wargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The loaded model.
    """
    embedding_model = AutoModel.from_pretrained(model_name, **model_wargs)
    return embedding_model
|
9 |
+
|
10 |
+
|
11 |
+
def initTokenizer(model_name="BAAI/bge-small-zh-v1.5", **model_wargs):
    """Load a tokenizer with ``AutoTokenizer.from_pretrained``.

    Args:
        model_name: Tokenizer name or path.
        **model_wargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The loaded tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, **model_wargs)
    return tokenizer
|
13 |
+
|
14 |
+
|
15 |
+
def detectEncoding(b: bytes):
    """Return the encoding ``chardet`` reports for *b*.

    Args:
        b: Raw bytes to inspect.

    Returns:
        The value of the ``"encoding"`` key of ``chardet.detect(b)``.
    """
    import chardet

    # Run detection once and reuse the result for both the log and the return.
    detection = chardet.detect(b)

    logging_info(f"chardet.detect(b): {detection}")

    return detection["encoding"]
|
21 |
+
|
22 |
+
|
23 |
+
def convertToUTF8(b: bytes):
    """Decode *b* to text using its detected encoding, else UTF-8.

    Args:
        b: Raw bytes to decode.

    Returns:
        The decoded string.

    Raises:
        UnicodeDecodeError: If the UTF-8 fallback cannot decode *b* either.
    """
    # Detect once; the detector used to be invoked twice per call.
    encoding = detectEncoding(b)

    if encoding:
        try:
            return b.decode(encoding)
        except (LookupError, UnicodeDecodeError) as e:
            # Detection is a guess: an unknown codec name or a wrong guess
            # should fall back to UTF-8 rather than fail outright.
            logging_info(f"decoding as {encoding} failed ({e}); falling back to utf-8")

    return b.decode("utf-8")
|