JiangYH commited on
Commit
4ab98db
1 Parent(s): 6f179e7

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. app.py +6 -2
  2. run_gradio.sh +1 -1
  3. src/ChatWorld.py +0 -2
  4. src/DataBase/BaseDB.py +7 -1
app.py CHANGED
@@ -9,6 +9,10 @@ chatWorld = ChatWorld()
9
  role_name_list_global = None
10
  role_name_dict_global = None
11
 
 
 
 
 
12
 
13
  def getContent(input_file):
14
  # 读取文件内容
@@ -28,8 +32,8 @@ def getContent(input_file):
28
  role_name_dict_global = role_name_dict
29
 
30
  return (
31
- gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[0]),
32
- gr.Radio(choices=role_name_list, interactive=True, value=role_name_list[-1]),
33
  )
34
 
35
 
 
9
  role_name_list_global = None
10
  role_name_dict_global = None
11
 
12
+ Meta = {
13
+ "uuid":"111"
14
+ }
15
+
16
 
17
  def getContent(input_file):
18
  # 读取文件内容
 
32
  role_name_dict_global = role_name_dict
33
 
34
  return (
35
+ gr.Radio(choices=role_name_list, interactive=True),
36
+ gr.Radio(choices=role_name_list, interactive=True),
37
  )
38
 
39
 
run_gradio.sh CHANGED
@@ -1,4 +1,4 @@
1
- export CUDA_VISIBLE_DEVICES=3
2
  export HF_ENDPOINT="https://hf-mirror.com"
3
 
4
  # Start the gradio server
 
1
+ export CUDA_VISIBLE_DEVICES=1
2
  export HF_ENDPOINT="https://hf-mirror.com"
3
 
4
  # Start the gradio server
src/ChatWorld.py CHANGED
@@ -17,8 +17,6 @@ class ChatWorld:
17
  ) -> None:
18
  self.model_name = pretrained_model_name_or_path
19
 
20
- self.global_batch_size = global_batch_size
21
-
22
  self.client = GLM_api()
23
 
24
  if model_load:
 
17
  ) -> None:
18
  self.model_name = pretrained_model_name_or_path
19
 
 
 
20
  self.client = GLM_api()
21
 
22
  if model_load:
src/DataBase/BaseDB.py CHANGED
@@ -5,6 +5,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
5
  from transformers import AutoTokenizer
6
  from langchain.text_splitter import TokenTextSplitter
7
  from langchain_core.documents import Document
 
8
 
9
 
10
  class BaseDB(metaclass=ABCMeta):
@@ -21,7 +22,12 @@ class BaseDB(metaclass=ABCMeta):
21
  if not embedding_name:
22
  embedding_name = "BAAI/bge-small-zh-v1.5"
23
 
24
- self.embedding = HuggingFaceEmbeddings(model_name=embedding_name)
 
 
 
 
 
25
  self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)
26
 
27
  self.init_db()
 
5
  from transformers import AutoTokenizer
6
  from langchain.text_splitter import TokenTextSplitter
7
  from langchain_core.documents import Document
8
+ from torch.cuda import is_available
9
 
10
 
11
  class BaseDB(metaclass=ABCMeta):
 
22
  if not embedding_name:
23
  embedding_name = "BAAI/bge-small-zh-v1.5"
24
 
25
+ if is_available():
26
+ model_kwargs = {"device": "cuda"}
27
+ else:
28
+ model_kwargs = {"device": "cpu"}
29
+
30
+ self.embedding = HuggingFaceEmbeddings(model_name=embedding_name,model_kwargs=model_kwargs)
31
  self.tokenizer = AutoTokenizer.from_pretrained(embedding_name)
32
 
33
  self.init_db()