JiangYH commited on
Commit
6f179e7
1 Parent(s): 87818fb

Upload folder using huggingface_hub

Browse files
.gitignore CHANGED
@@ -158,3 +158,9 @@ cython_debug/
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  .idea/
 
 
 
 
 
 
 
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  .idea/
161
+ .vscode
162
+ data/
163
+ uploads/
164
+
165
+ clash*
166
+ *.yml
README.md CHANGED
@@ -4,3 +4,10 @@ app_file: app.py
4
  sdk: gradio
5
  sdk_version: 3.50.2
6
  ---
 
 
 
 
 
 
 
 
4
  sdk: gradio
5
  sdk_version: 3.50.2
6
  ---
7
+
8
+ python 3.9.18
9
+
10
+ # TODO
11
+ - 对话流传输
12
+ - 持久化
13
+ - 多轮对话 历史
app.py CHANGED
@@ -1,57 +1,87 @@
1
  import logging
2
- import os
3
 
4
  import gradio as gr
5
 
6
- from ChatWorld import ChatWorld
7
-
8
- logging.basicConfig(level=logging.INFO, filename="demo.log", filemode="w",
9
- format="%(asctime)s - %(name)s - %(levelname)-9s - %(filename)-8s : %(lineno)s line - %(message)s",
10
- datefmt="%Y-%m-%d %H:%M:%S")
11
 
12
  chatWorld = ChatWorld()
13
 
14
  role_name_list_global = None
 
15
 
16
 
17
  def getContent(input_file):
18
  # 读取文件内容
19
- with open(input_file.name, 'r', encoding='utf-8') as f:
20
  logging.info(f"read file {input_file.name}")
21
  input_text = f.read()
22
  logging.info(f"file content: {input_text}")
23
 
24
- # 保存文件内容
25
- input_text_list = input_text.split("\n")
26
- chatWorld.initDB(input_text_list)
27
- role_name_set = set()
28
-
29
- # 读取角色名
30
- for line in input_text_list:
31
- role_name_set.add(line.split(":")[0])
32
 
33
- role_name_list = [i for i in role_name_set if i != ""]
34
- logging.info(f"role_name_list: {role_name_list}")
35
 
36
  global role_name_list_global
37
  role_name_list_global = role_name_list
38
-
39
- return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
40
-
41
-
42
- def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
43
- print(f"history: {history}")
44
- chatWorld.setRoleName(model_role_name, model_role_nickname)
45
- response = chatWorld.chat(message,
46
- role_name, role_nickname, use_local_model=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return response
48
 
49
 
50
- def submit_message_api(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
51
- print(f"history: {history}")
52
- chatWorld.setRoleName(model_role_name, model_role_nickname)
53
- response = chatWorld.chat(message,
54
- role_name, role_nickname, use_local_model=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return response
56
 
57
 
@@ -63,8 +93,13 @@ def get_role_list():
63
  return []
64
 
65
 
66
- with gr.Blocks() as demo:
 
 
 
67
 
 
 
68
  upload_c = gr.File(label="上传文档文件")
69
 
70
  with gr.Row():
@@ -75,15 +110,41 @@ with gr.Blocks() as demo:
75
  role_name = gr.Radio(get_role_list(), label="角色名")
76
  role_nickname = gr.Textbox(label="角色昵称")
77
 
78
- upload_c.upload(fn=getContent, inputs=upload_c,
79
- outputs=[model_role_name, role_name])
 
 
 
 
 
 
 
 
80
 
81
  with gr.Row():
82
  chatBox_local = gr.ChatInterface(
83
- submit_message, chatbot=gr.Chatbot(height=400, label="本地模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
 
 
 
 
 
 
 
 
 
84
 
85
  chatBox_api = gr.ChatInterface(
86
- submit_message_api, chatbot=gr.Chatbot(height=400, label="API模型", render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
87
-
88
-
89
- demo.launch(server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
1
  import logging
 
2
 
3
  import gradio as gr
4
 
5
+ from src import ChatWorld
 
 
 
 
6
 
7
  chatWorld = ChatWorld()
8
 
9
  role_name_list_global = None
10
+ role_name_dict_global = None
11
 
12
 
13
def getContent(input_file):
    """Load an uploaded story file, index it, and extract role names.

    Reads the uploaded file as UTF-8, stores its text in the vector DB via
    ``chatWorld.setStory``, asks the model for character names, refreshes the
    module-level caches, and returns two updated ``gr.Radio`` components
    (model role, user role).
    """
    with open(input_file.name, "r", encoding="utf-8") as f:
        logging.info(f"read file {input_file.name}")
        input_text = f.read()
        logging.info(f"file content: {input_text}")

    chatWorld.setStory(stories=input_text, metas=None)

    # Extract character names from the uploaded story.
    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(input_text)

    global role_name_list_global
    role_name_list_global = role_name_list
    global role_name_dict_global
    role_name_dict_global = role_name_dict

    # Extraction can legitimately return an empty list (model/parse failure);
    # return empty radios instead of crashing the upload with an IndexError.
    if not role_name_list:
        return (
            gr.Radio(choices=[], interactive=True),
            gr.Radio(choices=[], interactive=True),
        )

    return (
        gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]),
        gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1]),
    )
34
+
35
+
36
def submit_message(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """Chat callback for the local-model tab.

    Dispatches to role-play or plain chat based on *withCharacter*;
    *history* is supplied by gr.ChatInterface but unused here.
    """
    if not withCharacter:
        # Plain single-turn chat, no role-play prompt.
        return chatWorld.chatWithoutCharacter(text=message, use_local_model=True)

    return chatWorld.chatWithCharacter(
        text=message,
        role_name=role_name,
        role_nickname=role_nickname,
        model_role_name=model_role_name,
        model_role_nickname=model_role_nickname,
        use_local_model=True,
    )
60
 
61
 
62
def submit_message_api(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """Chat callback for the API-model tab (remote model, no local weights).

    Same dispatch as ``submit_message`` but with ``use_local_model=False``;
    *history* is supplied by gr.ChatInterface but unused here.
    """
    if not withCharacter:
        return chatWorld.chatWithoutCharacter(text=message, use_local_model=False)

    return chatWorld.chatWithCharacter(
        text=message,
        role_name=role_name,
        role_nickname=role_nickname,
        model_role_name=model_role_name,
        model_role_nickname=model_role_nickname,
        use_local_model=False,
    )
86
 
87
 
 
93
  return []
94
 
95
 
96
def change_role_list(name):
    """Return the cached nickname for *name*.

    Returns an empty string when no file has been uploaded yet (the cache is
    still None) or when the name is unknown, instead of raising
    TypeError/KeyError into the Gradio event handler.
    """
    global role_name_dict_global

    if not role_name_dict_global:
        return ""
    return role_name_dict_global.get(name, "")
100
 
101
+
102
+ with gr.Blocks() as demo:
103
  upload_c = gr.File(label="上传文档文件")
104
 
105
  with gr.Row():
 
110
  role_name = gr.Radio(get_role_list(), label="角色名")
111
  role_nickname = gr.Textbox(label="角色昵称")
112
 
113
+ model_role_name.change(
114
+ fn=change_role_list, inputs=[model_role_name], outputs=[model_role_nickname]
115
+ )
116
+ role_name.change(fn=change_role_list, inputs=[role_name], outputs=[role_nickname])
117
+
118
+ upload_c.upload(
119
+ fn=getContent, inputs=upload_c, outputs=[model_role_name, role_name]
120
+ )
121
+
122
+ withCharacter = gr.Radio([True, False], value=True, label="是否进行角色扮演")
123
 
124
  with gr.Row():
125
  chatBox_local = gr.ChatInterface(
126
+ submit_message,
127
+ chatbot=gr.Chatbot(height=400, label="本地模型", render=False),
128
+ additional_inputs=[
129
+ model_role_name,
130
+ role_name,
131
+ model_role_nickname,
132
+ role_nickname,
133
+ withCharacter,
134
+ ],
135
+ )
136
 
137
  chatBox_api = gr.ChatInterface(
138
+ submit_message_api,
139
+ chatbot=gr.Chatbot(height=400, label="API模型", render=False),
140
+ additional_inputs=[
141
+ model_role_name,
142
+ role_name,
143
+ model_role_nickname,
144
+ role_nickname,
145
+ withCharacter,
146
+ ],
147
+ )
148
+
149
+
150
+ demo.launch(share=True, server_name="0.0.0.0")
main.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import time
4
+ from fastapi import Body, FastAPI, File, Form, HTTPException, Response, UploadFile
5
+ from fastapi.responses import JSONResponse
6
+ import uvicorn
7
+
8
+ from src import ChatWorld
9
+ from src.Response import ChatResponse, FileResponse
10
+ from src.logging import logging_info
11
+ from src.user import UUID, Role, User
12
+ from src.utils import convertToUTF8
13
+
14
+ app = FastAPI()
15
+ chatWorld = ChatWorld()
16
+
17
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18
+ SAVE_DIR = "uploads"
19
+
20
+ user_info: dict[UUID, User] = dict()
21
+
22
+
23
@app.post("/uploadFile", response_model=FileResponse)
def upload_file(uuid: str = Form(), file: UploadFile = File(...)):
    """Receive a story file, persist it, index it, and extract role names.

    The file is saved under ``uploads/`` with a timestamp/uuid prefix, its
    UTF-8 text is indexed in the vector DB (scoped by *uuid*), and the model
    extracts the characters' (name, nickname) pairs for the client.
    """
    save_dir = os.path.join(BASE_DIR, SAVE_DIR)
    os.makedirs(save_dir, exist_ok=True)

    # basename() strips client-supplied directory components so a crafted
    # filename (e.g. "../../x") cannot escape the upload directory.
    safe_name = os.path.basename(file.filename or "upload")
    file_name = f"{time.time_ns()}_{uuid}_{safe_name}"
    file_path = os.path.join(save_dir, file_name)

    file_content = file.file.read()

    with open(file_path, "wb") as f:
        f.write(file_content)

    file_content_utf8 = convertToUTF8(file_content)

    chatWorld.setStory(
        stories=file_content_utf8,
        metas={
            "uuid": uuid,
        },
    )

    # Fresh per-session state; was followed by a bare debug print().
    user_info[uuid] = User()
    logging_info(f"user_info: {user_info}")

    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(file_content_utf8)
    return FileResponse(
        filename=file_name,
        role_name_list=role_name_list,
        role_name_dict=role_name_dict,
        md5=hashlib.md5(file_content).hexdigest(),
    )
56
+
57
+
58
@app.post("/chatWithCharacter", response_model=ChatResponse)
def chatWithCharacter(
    uuid: str = Body(...),
    text: str = Body(...),
    use_local_model: bool = Body(False),
    top_k: int = Body(5),
    role_info: Role = Body(...),
):
    """Role-play chat endpoint; requires a prior /uploadFile for *uuid*."""
    user = user_info.get(uuid)
    if not user:
        # No story has been uploaded for this session yet.
        raise HTTPException(status_code=400, detail="User not found")

    user_info[uuid] = user.update(role_info.model_dump())
    logging_info(f"user_info: {user_info}")

    reply = chatWorld.chatWithCharacter(
        text=text,
        use_local_model=use_local_model,
        top_k=top_k,
        **role_info.model_dump(),
        metas={"uuid": uuid},
    )

    return ChatResponse(response=reply)
83
+
84
+
85
+ # @app.post("/chatWithoutCharacter")
86
+ # def chatWithoutCharacter(
87
+ # uuid: str = Body(...),
88
+ # text: str = Body(...),
89
+ # use_local_model: bool = Body(...),
90
+ # ):
91
+ # pass
92
+
93
+
94
+ if __name__ == "__main__":
95
+ uvicorn.run("main:app", host="0.0.0.0", port=8000)
requirements.txt CHANGED
@@ -5,3 +5,7 @@ transformers==4.38.1
5
  accelerate
6
  zhipuai
7
  sentencepiece
 
 
 
 
 
5
  accelerate
6
  zhipuai
7
  sentencepiece
8
+ tiktoken
9
+ sentence-transformers
10
+ langchain
11
+ chromadb
run_fastapi.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT="https://hf-mirror.com"

# Start the FastAPI server (main.py runs uvicorn on 0.0.0.0:8000);
# the previous comment incorrectly said "gradio" (that is run_gradio.sh).
python main.py
run_gradio.sh CHANGED
@@ -1,5 +1,5 @@
1
- export CUDA_VISIBLE_DEVICES=0
2
- export HF_HOME="/workspace/jyh/.cache/huggingface"
3
 
4
  # Start the gradio server
5
- /workspace/jyh/miniconda3/envs/ChatWorld/bin/python /workspace/jyh/Zero-Haruhi/app.py
 
1
+ export CUDA_VISIBLE_DEVICES=3
2
+ export HF_ENDPOINT="https://hf-mirror.com"
3
 
4
  # Start the gradio server
5
+ python app.py
src/ChatWorld.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from jinja2 import Template
3
+ from .DataBase import ChromaDB
4
+
5
+ from .Models import GLM, GLM_api
6
+
7
+ from .utils import *
8
+
9
+
10
class ChatWorld:
    """Facade tying together the local GLM model, the ZhipuAI API client and
    the Chroma vector store for retrieval-augmented role-play chat."""

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        self.model_name = pretrained_model_name_or_path
        self.global_batch_size = global_batch_size

        # The remote (ZhipuAI) client is always created; the local model is
        # optional because loading it requires GPU memory and time.
        self.client = GLM_api()
        if model_load:
            self.model = GLM()

        self.db = ChromaDB(embedding_model_name_or_path)

        # Jinja2 system-prompt template; RAG chunks are interpolated between
        # "##" separators.
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def setStory(self, **stories_kargs):
        """Replace the stored story chunks for the given metas scope."""
        # .get() avoids a KeyError when the caller omits `metas`.
        self.db.deleteStoriesByMeta(metas=stories_kargs.get("metas"))
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Build the system message, retrieving top_k similar chunks as RAG context."""
        rag = self.db.searchBySim(text, top_k, metas)

        return {
            "role": "system",
            "content": self.prompt.render(
                **role_info,
                RAG=rag,
            ),
        }

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Role-play chat: RAG-augmented system prompt + quoted user line.

        *role_info* must contain ``role_name`` (the human speaker); the
        ``model_role_*`` fields are consumed by the prompt template.
        """
        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, **role_info, top_k=top_k, metas=metas
            )

        user_role_name = role_info.get("role_name")
        if not user_role_name:
            raise ValueError("role_name is required")

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]

        logging_info(f"message: {message}")

        if use_local_model:
            return self.model.get_response(message)
        return self.client.chat(message)

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
    ):
        """Plain single-turn chat without any role-play system prompt."""
        logging_info(f"text: {text}")

        message = [
            {"role": "user", "content": f"{text}"},
        ]

        if use_local_model:
            # GLM.get_response also accepts a plain string for single turns.
            return self.model.get_response(text)
        return self.client.chat(message)

    def getRoleNameFromFile(self, input_file: str):
        """Ask the API model to extract character (name, nickname) pairs.

        Returns ``(role_name_list, role_name_dict)``; both are empty on any
        extraction or JSON-parsing failure (best-effort, never raises).
        """
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )

        # Typo fixed: local variable was previously spelled "respense".
        response = self.chatWithoutCharacter(prompt, use_local_model=False)

        # Pull the payload out of the ```json ... ``` fence. If the fence is
        # missing, find() returns -1 and json.loads below fails, which is
        # handled as "no roles found".
        json_start_index = response.find("```json")
        json_end_index = response.find("```", json_start_index + 1)
        json_str = response[json_start_index + 7 : json_end_index]

        try:
            parsed = json.loads(json_str)
            role_name_list = [i["name"] for i in parsed]
            role_name_dict = {i["name"]: i["nickname"] for i in parsed}
        except Exception as e:
            # Best effort: log (was a bare print) and fall back to empty
            # results rather than failing the whole upload.
            logging_info(f"getRoleNameFromFile parse error: {e}")
            role_name_list = []
            role_name_dict = {}

        return role_name_list, role_name_dict
src/DataBase/BaseDB.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta, abstractmethod
2
+ from typing import Union
3
+
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from transformers import AutoTokenizer
6
+ from langchain.text_splitter import TokenTextSplitter
7
+ from langchain_core.documents import Document
8
+
9
+
10
class BaseDB(metaclass=ABCMeta):
    """Abstract vector-store wrapper: embedding model, tokenizer and the
    add/delete/search API concrete stores must implement."""

    def __init__(self, embedding_name: str = None, persist_dir=None) -> None:
        super().__init__()

        # Concrete subclasses assign the actual store in init_db().
        self.client = None

        # Default on-disk location for the vector store.
        self.persist_dir = persist_dir if persist_dir else "data"

        if not embedding_name:
            embedding_name = "BAAI/bge-small-zh-v1.5"

        self.embedding = HuggingFaceEmbeddings(model_name=embedding_name)
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)

        self.init_db()

    @abstractmethod
    def init_db(self):
        """Create/open the concrete vector store into self.client."""
        pass

    def text_splitter(
        self, text: Union[str, Document], chunk_size=300, chunk_overlap=10
    ):
        """Split *text* into token-sized chunks using the embedding tokenizer."""
        splitter = TokenTextSplitter.from_huggingface_tokenizer(
            self.tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        if isinstance(text, Document):
            # split_documents expects an iterable of Documents; the original
            # passed the single Document itself, which is not iterable.
            return splitter.split_documents([text])
        if isinstance(text, str):
            return splitter.split_text(text)
        raise ValueError("text must be a str or Document")

    @abstractmethod
    def addStories(self, stories, metas=None):
        pass

    @abstractmethod
    def deleteStoriesByMeta(self, metas):
        pass

    @abstractmethod
    def searchBySim(self, query, n_results, metas, only_return_document=True):
        pass

    @abstractmethod
    def searchByMeta(self, metas=None):
        pass
src/DataBase/ChromaDB.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from langchain_community.vectorstores.chroma import Chroma
3
+
4
+ from src.logging import logging_info
5
+
6
+ from .BaseDB import BaseDB
7
+
8
+ # TODO 数据库持久化 和 用户进入的加载。
9
+
10
+
11
class ChromaDB(BaseDB):
    """Chroma-backed vector store with per-meta (e.g. per-user) scoping."""

    # TODO: database persistence and reloading when a user re-enters.

    def __init__(self, embedding_name: str = None, persist_dir=None) -> None:
        super().__init__(embedding_name, persist_dir)

    def init_db(self):
        self.client = Chroma(
            persist_directory=self.persist_dir, embedding_function=self.embedding
        )

    def addStories(self, stories: str, metas: dict = None):
        """Split *stories* into chunks and add them, tagging each with *metas*."""
        # Split once and reuse (the original tokenized the text twice: once
        # for the debug log, once for the insert).
        split_stories = self.text_splitter(stories)
        if split_stories:
            logging_info(split_stories[-1])

        self.client.add_texts(
            texts=split_stories, metadatas=[metas] * len(split_stories)
        )

    def searchBySim(
        self, query, n_results=5, metas: dict = None, only_return_document=True
    ):
        """Similarity search restricted to chunks whose metadata matches *metas*."""
        result = self.client.similarity_search(query, k=n_results, filter=metas)

        if only_return_document:
            return [doc.page_content for doc in result]
        return result

    def deleteStoriesByMeta(self, metas):
        """Delete every chunk whose metadata matches *metas* (no-op if none)."""
        ids = self.searchByMeta(metas=metas)["ids"]
        if ids:
            self.client.delete(ids)

    def searchByMeta(self, metas=None, include: list[str] = None) -> dict:
        # Annotation fixed: `dict[str, any]` used the builtin any() function.
        return self.client.get(where=metas, include=include)
src/DataBase/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ChromaDB import ChromaDB
2
+
3
+ __all__ = ['ChromaDB']
src/Models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .models import GLM,GLM_api
2
+
3
+ __all__ = ["GLM", "GLM_api"]
src/Models/models.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from string import Template
3
+ from typing import Dict, List, Union
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from zhipuai import ZhipuAI
6
+
7
+
8
class GLM:
    """Local ChatGLM3-based role-play model loaded via transformers."""

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )

        # eval() disables dropout etc. for inference.
        self.client = client.eval()

    def message2query(self, messages) -> str:
        """Render OpenAI-style messages into ChatGLM3's <|role|> text format.

        Example: [{'role': 'user', 'content': 'Hello'}] -> "<|user|>\\nHello\\n".
        """
        template = Template("<|$role|>\n$content\n")
        return "".join(template.substitute(message) for message in messages)

    def get_response(
        self,
        message: Union[str, List[Dict[str, str]]],
        history: List[Dict[str, str]] = None,
    ):
        """Chat with the local model.

        *message* is either a plain string (single turn) or a message list
        whose last entry is the current user turn; earlier entries are passed
        as history. The *history* parameter is kept for interface
        compatibility but is unused (as in the original).
        """
        if isinstance(message, str):
            response, _ = self.client.chat(self.tokenizer, message)
        elif isinstance(message, list):
            response, _ = self.client.chat(
                self.tokenizer, message[-1]["content"], history=message[:-1]
            )
        else:
            # The original fell through here and raised a confusing
            # NameError on `response`; fail explicitly instead.
            raise TypeError("message must be a str or a list of messages")
        # Debug print(response) removed; callers log the reply themselves.
        return response
45
+
46
+
47
class GLM_api:
    """Thin wrapper around the ZhipuAI chat-completions API."""

    def __init__(self, model_name="glm-4"):
        # The key comes from the environment; a missing key surfaces as an
        # API error at call time, not here.
        api_key = os.environ.get("ZHIPU_API_KEY")

        self.client = ZhipuAI(api_key=api_key)
        self.model = model_name

    def chat(self, message):
        """Send *message* (OpenAI-style list); on any failure return a fixed
        Chinese error string instead of raising."""
        try:
            completion = self.client.chat.completions.create(
                model=self.model, messages=message
            )
        except Exception as e:
            print(e)
            return "模型连接失败"

        return completion.choices[0].message.content
src/Response.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
+ class ChatResponse(BaseModel):
5
+ response: str
6
+
7
+
8
+ class FileResponse(BaseModel):
9
+ filename: str
10
+ role_name_list: list[str] = []
11
+ role_name_dict: dict[str, str] = {}
12
+ md5: str = None
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ChatWorld import ChatWorld
2
+
3
+ __all__ = ['ChatWorld']
src/logging.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging

# Configure the root logger once at import time: everything goes to
# demo.log, truncated on each run.
logging.basicConfig(
    level=logging.DEBUG,
    filename="demo.log",
    filemode="w",
    format="%(asctime)s - %(name)s - %(levelname)-9s - %(filename)-8s : %(lineno)s line - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def logging_info(text: str):
    """Log *text* at INFO level through the shared root logger."""
    logging.info(text)
src/user.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ UUID = str
4
+
5
+
6
+ class Role(BaseModel):
7
+ role_name: str
8
+ role_nickname: str = None
9
+ model_role_name: str
10
+ model_role_nickname: str = None
11
+
12
+
13
class User:
    """Per-session state for one uploaded story / chat session."""

    def __init__(self) -> None:
        # Instance attributes. The original declared these as mutable CLASS
        # attributes, so `history` was one shared list across every User.
        self.history: list[str] = []
        self.role_name: str = None
        self.role_nickname: str = None
        self.model_role_name: str = None
        self.model_role_nickname: str = None

    def update(self, new_properties: dict) -> "User":
        """Set each key/value of *new_properties* as an attribute; fluent."""
        for key, value in new_properties.items():
            setattr(self, key, value)
        return self
src/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from jinja2 import Template
2
+ from transformers import AutoModel, AutoTokenizer
3
+
4
+ from .logging import logging_info
5
+
6
+
7
def initEmbedding(model_name="BAAI/bge-small-zh-v1.5", **model_kwargs):
    """Load the embedding model weights from the Hugging Face hub."""
    # kwargs name fixed from the "model_wargs" typo (var-kwargs, so safe).
    return AutoModel.from_pretrained(model_name, **model_kwargs)


def initTokenizer(model_name="BAAI/bge-small-zh-v1.5", **model_kwargs):
    """Load the tokenizer matching the embedding model."""
    return AutoTokenizer.from_pretrained(model_name, **model_kwargs)


def detectEncoding(b: bytes):
    """Best-effort charset detection; returns chardet's guess (may be None)."""
    import chardet

    result = chardet.detect(b)
    logging_info(f"chardet.detect(b): {result}")

    return result["encoding"]


def convertToUTF8(b: bytes):
    """Decode *b* using the detected encoding, falling back to UTF-8."""
    # Detect once instead of twice (the original called detectEncoding two
    # times per conversion, running chardet over the whole buffer twice).
    encoding = detectEncoding(b)
    if encoding:
        return b.decode(encoding)

    return b.decode("utf-8")