JiangYH commited on
Commit
0339b60
1 Parent(s): c6fc090

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
ChatWorld/ChatWorld.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from jinja2 import Template
2
+ import torch
3
+
4
+ from .models import qwen_model
5
+
6
+ from .NaiveDB import NaiveDB
7
+ from .utils import *
8
+
9
+
10
+ class ChatWorld:
11
+ def __init__(self, pretrained_model_name_or_path="silk-road/Haruhi-Zero-14B-0_5", embedding_model_name_or_path="BAAI/bge-small-zh-v1.5") -> None:
12
+ self.embedding = initEmbedding(embedding_model_name_or_path)
13
+ self.tokenizer = initTokenizer(embedding_model_name_or_path)
14
+ self.story_vec: list[dict] = None
15
+ self.storys = None
16
+ self.model_role_name = None
17
+ self.model_role_nickname = None
18
+ self.model_name = pretrained_model_name_or_path
19
+
20
+ self.history = []
21
+
22
+ self.client = None
23
+ self.model = qwen_model(pretrained_model_name_or_path)
24
+ self.db = NaiveDB()
25
+ self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
26
+ '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
27
+ '下文给定了一些聊天记录,位于##分隔号中。\n'
28
+ '如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n'
29
+ '如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n'
30
+ '如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n'
31
+ '请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n'
32
+ '请你永远只以{{model_role_name}}身份,进行任何的回复。\n'
33
+ ))
34
+
35
+ def getEmbeddingsFromStory(self, stories: list[str]):
36
+ if self.story_vec:
37
+ # 判断是否与当前的相同
38
+ if len(self.story_vec) == len(stories) and all([self.story_vec[i]["text"] == stories[i] for i in range(len(stories))]):
39
+ return [self.story_vec[i]["vec"] for i in range(len(stories))]
40
+
41
+ if self.embedding is None:
42
+ self.embedding = initEmbedding()
43
+
44
+ if self.tokenizer is None:
45
+ self.tokenizer = initTokenizer()
46
+
47
+ self.story_vec = []
48
+ for story in stories:
49
+ with torch.no_grad():
50
+ inputs = self.tokenizer(
51
+ story, return_tensors="pt", padding=True, truncation=True, max_length=512)
52
+ outputs = self.embedding(**inputs)[0][:, 0]
53
+ vec = torch.nn.functional.normalize(
54
+ outputs, p=2, dim=1).tolist()[0]
55
+
56
+ self.story_vec.append({"text": story, "vec": vec})
57
+
58
+ return [self.story_vec[i]["vec"] for i in range(len(stories))]
59
+
60
+ def initDB(self, storys: list[str]):
61
+ story_vecs = self.getEmbeddingsFromStory(storys)
62
+ self.db.build_db(storys, story_vecs)
63
+
64
+ def setRoleName(self, role_name, role_nick_name=None):
65
+ self.model_role_name = role_name
66
+ self.model_role_nickname = role_nick_name
67
+
68
+ def getSystemPrompt(self, role_name, role_nick_name):
69
+ assert self.model_role_name and self.model_role_nickname, "Please set model role name first"
70
+ return self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name)
71
+
72
+ def chat(self, user_role_name: str, text: str, user_role_nick_name: str = None, use_local_model=False):
73
+ message = [self.getSystemPrompt(
74
+ user_role_name, user_role_nick_name)] + self.history
75
+
76
+ if use_local_model:
77
+ response = self.model.get_response(message)
78
+ else:
79
+ response = self.client.chat(
80
+ user_role_name, text, user_role_nick_name)
81
+
82
+ self.history.append({"role": "user", "content": text})
83
+ self.history.append({"role": "model", "content": response})
84
+ return response
ChatWorld/NaiveDB.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from math import sqrt
3
+
4
+
5
+ class NaiveDB:
6
+ def __init__(self):
7
+ self.verbose = False
8
+ self.init_db()
9
+
10
+ def init_db(self):
11
+ if self.verbose:
12
+ print("call init_db")
13
+ self.stories = []
14
+ self.norms = []
15
+ self.vecs = []
16
+ self.flags = [] # 用于标记每个story是否可以被搜索
17
+ self.metas = [] # 用于存储每个story的meta信息
18
+ self.last_search_ids = [] # 用于存储上一次搜索的结果
19
+
20
+ def build_db(self, stories, vecs, flags=None, metas=None):
21
+ self.stories = stories
22
+ self.vecs = vecs
23
+ self.flags = flags if flags else [True for _ in self.stories]
24
+ self.metas = metas if metas else [{} for _ in self.stories]
25
+ self.recompute_norm()
26
+
27
+ def save(self, file_path):
28
+ print(
29
+ "warning! directly save folder from dbtype NaiveDB has not been implemented yet, try use role_from_hf to load role instead")
30
+
31
+ def load(self, file_path):
32
+ print(
33
+ "warning! directly load folder from dbtype NaiveDB has not been implemented yet, try use role_from_hf to load role instead")
34
+
35
+ def recompute_norm(self):
36
+ # 补全这部分代码,self.norms 分别存储每个vector的l2 norm
37
+ # 计算每个向量的L2范数
38
+ self.norms = [sqrt(sum([x ** 2 for x in vec])) for vec in self.vecs]
39
+
40
+ def get_stories_with_id(self, ids):
41
+ return [self.stories[i] for i in ids]
42
+
43
+ def clean_flag(self):
44
+ self.flags = [True for _ in self.stories]
45
+
46
+ def disable_story_with_ids(self, close_ids):
47
+ for id in close_ids:
48
+ self.flags[id] = False
49
+
50
+ def close_last_search(self):
51
+ for id in self.last_search_ids:
52
+ self.flags[id] = False
53
+
54
+ def search(self, query_vector, n_results):
55
+
56
+ if self.verbose:
57
+ print("call search")
58
+
59
+ if len(self.norms) != len(self.vecs):
60
+ self.recompute_norm()
61
+
62
+ # 计算查询向量的范数
63
+ query_norm = sqrt(sum([x ** 2 for x in query_vector]))
64
+
65
+ idxs = list(range(len(self.vecs)))
66
+
67
+ # 计算余弦相似度
68
+ similarities = []
69
+ for vec, norm, idx in zip(self.vecs, self.norms, idxs):
70
+ if len(self.flags) == len(self.vecs) and not self.flags[idx]:
71
+ continue
72
+
73
+ dot_product = sum(q * v for q, v in zip(query_vector, vec))
74
+ if query_norm < 1e-20:
75
+ similarities.append((random.random(), idx))
76
+ continue
77
+ cosine_similarity = dot_product / (query_norm * norm)
78
+ similarities.append((cosine_similarity, idx))
79
+
80
+ # 获取最相似的n_results个结果, 使用第0个字段进行排序
81
+ similarities.sort(key=lambda x: x[0], reverse=True)
82
+ self.last_search_ids = [x[1] for x in similarities[:n_results]]
83
+
84
+ top_indices = [x[1] for x in similarities[:n_results]]
85
+ return top_indices
ChatWorld/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .ChatWorld import ChatWorld
ChatWorld/models.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+
4
+ class qwen_model:
5
+ def __init__(self, model_name):
6
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
7
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True).eval()
8
+
9
+ def get_response(self, message):
10
+ self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
11
+ return "test"
ChatWorld/utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+
3
+
4
+ def initEmbedding(model_name="BAAI/bge-small-zh-v1.5", **model_wargs):
5
+ return AutoModel.from_pretrained(model_name, **model_wargs)
6
+
7
+
8
+ def initTokenizer(model_name="BAAI/bge-small-zh-v1.5", **model_wargs):
9
+ return AutoTokenizer.from_pretrained(model_name)
README.md CHANGED
@@ -1,13 +1,6 @@
1
  ---
2
  title: ChatWorld
3
- emoji: 😻
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.18.0
8
  app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: ChatWorld
 
 
 
 
 
3
  app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 3.50.2
6
  ---
 
 
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import gradio as gr
5
+
6
+ from ChatWorld import ChatWorld
7
+
8
+ logging.basicConfig(level=logging.INFO, filename="demo.log", filemode="w",
9
+ format="%(asctime)s - %(name)s - %(levelname)-9s - %(filename)-8s : %(lineno)s line - %(message)s",
10
+ datefmt="%Y-%m-%d %H:%M:%S")
11
+
12
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
13
+
14
+ chatWorld = ChatWorld()
15
+
16
+
17
+ def getContent(input_file):
18
+ # 读取文件内容
19
+ with open(input_file.name, 'r', encoding='utf-8') as f:
20
+ logging.info(f"read file {input_file.name}")
21
+ input_text = f.read()
22
+ logging.info(f"file content: {input_text}")
23
+
24
+ # 保存文件内容
25
+ input_text_list = input_text.split("\n")
26
+ chatWorld.initDB(input_text_list)
27
+ role_name_set = set()
28
+
29
+ # 读取角色名
30
+ for line in input_text_list:
31
+ role_name_set.add(line.split(":")[0])
32
+
33
+ role_name_list = [i for i in role_name_set if i != ""]
34
+ logging.info(f"role_name_list: {role_name_list}")
35
+
36
+ return gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]), gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1])
37
+
38
+
39
+ def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
40
+ chatWorld.setRoleName(model_role_name, model_role_nickname)
41
+ response = chatWorld.chat(
42
+ role_name, message, role_nickname, use_local_model=True)
43
+ return response
44
+
45
+
46
+ with gr.Blocks() as demo:
47
+
48
+ upload_c = gr.File(label="上传文档文件")
49
+
50
+ with gr.Row():
51
+ model_role_name = gr.Radio([], label="模型角色名")
52
+ model_role_nickname = gr.Textbox(label="模型角色昵称")
53
+
54
+ with gr.Row():
55
+ role_name = gr.Radio([], label="角色名")
56
+ role_nickname = gr.Textbox(label="角色昵称")
57
+
58
+ upload_c.upload(fn=getContent, inputs=upload_c,
59
+ outputs=[model_role_name, role_name])
60
+
61
+ chatBox = gr.ChatInterface(
62
+ submit_message, chatbot=gr.Chatbot(height=400, render=False), additional_inputs=[model_role_name, role_name, model_role_nickname, role_nickname])
63
+
64
+
65
+ demo.launch(share=True, debug=True, server_name="0.0.0.0")
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==4.19.2
2
+ Jinja2==3.1.3
3
+ torch==2.2.0
4
+ transformers==4.38.1
run_gradio.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=0
2
+ export HF_HOME="/workspace/jyh/.cache/huggingface"
3
+
4
+ # Start the gradio server
5
+ /workspace/jyh/miniconda3/envs/ChatWorld/bin/python /workspace/jyh/Zero-Haruhi/app.py