thomas-yanxin committed on
Commit 5d583ec
1 Parent(s): d246a39

Add Jina Embedding inference

Files changed (2):
  1. app.py +50 -43
  2. requirements.txt +2 -1
app.py CHANGED
@@ -8,6 +8,7 @@ from duckduckgo_search import ddg
 from duckduckgo_search.utils import SESSION
 from langchain.chains import RetrievalQA
 from langchain.document_loaders import UnstructuredFileLoader
+from langchain.embeddings import JinaEmbeddings
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.prompts.prompt import PromptTemplate
@@ -16,16 +17,13 @@ from langchain.vectorstores import FAISS
 from chatllm import ChatLLM
 from chinese_text_splitter import ChineseTextSplitter
 
-# os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
-
-
-
 nltk.data.path.append('./nltk_data')
 
 embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec-base": "GanymedeNil/text2vec-base-chinese"
+    "text2vec-base": "GanymedeNil/text2vec-base-chinese",
+    "ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k'
 }
 
 llm_model_dict = {
@@ -35,22 +33,23 @@ llm_model_dict = {
     "Minimax": "Minimax"
 }
 
-
 DEVICE = "cuda" if torch.cuda.is_available(
 ) else "mps" if torch.backends.mps.is_available() else "cpu"
 
+
 def search_web(query):
 
-    SESSION.proxies = {
-        "http": f"socks5h://localhost:7890",
-        "https": f"socks5h://localhost:7890"
-    }
-    results = ddg(query)
-    web_content = ''
-    if results:
-        for result in results:
-            web_content += result['body']
-    return web_content
+    SESSION.proxies = {
+        "http": f"socks5h://localhost:7890",
+        "https": f"socks5h://localhost:7890"
+    }
+    results = ddg(query)
+    web_content = ''
+    if results:
+        for result in results:
+            web_content += result['body']
+    return web_content
+
 
 def load_file(filepath):
     if filepath.lower().endswith(".pdf"):
@@ -64,12 +63,17 @@ def load_file(filepath):
     return docs
 
 
-
 def init_knowledge_vector_store(embedding_model, filepath):
-    embeddings = HuggingFaceEmbeddings(
-        model_name=embedding_model_dict[embedding_model], )
-    embeddings.client = sentence_transformers.SentenceTransformer(
-        embeddings.model_name, device=DEVICE)
+    if embedding_model == "ViT-B-32":
+        jina_auth_token = os.getenv('jina_auth_token')
+        embeddings = JinaEmbeddings(
+            jina_auth_token=jina_auth_token,
+            model_name=embedding_model_dict[embedding_model])
+    else:
+        embeddings = HuggingFaceEmbeddings(
+            model_name=embedding_model_dict[embedding_model], )
+        embeddings.client = sentence_transformers.SentenceTransformer(
+            embeddings.model_name, device=DEVICE)
 
     docs = load_file(filepath)
 
@@ -110,7 +114,8 @@ def get_knowledge_based_answer(query,
     if large_language_model == "Minimax":
         chatLLM.model = 'Minimax'
     else:
-        chatLLM.load_model(model_name_or_path=llm_model_dict[large_language_model])
+        chatLLM.load_model(
+            model_name_or_path=llm_model_dict[large_language_model])
     chatLLM.temperature = temperature
     chatLLM.top_p = top_p
 
@@ -185,26 +190,28 @@ if __name__ == "__main__":
                     label="large language model",
                     value="ChatGLM-6B-int4")
 
-                embedding_model = gr.Dropdown(list(embedding_model_dict.keys()),
-                                              label="Embedding model",
-                                              value="text2vec-base")
+                embedding_model = gr.Dropdown(list(
+                    embedding_model_dict.keys()),
+                                              label="Embedding model",
+                                              value="text2vec-base")
 
                 file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式',
                                file_types=['.txt', '.md', '.docx'])
-
-                use_web = gr.Radio(["True", "False"], label="Web Search",
-                                   value="False"
-                                   )
+
+                use_web = gr.Radio(["True", "False"],
+                                   label="Web Search",
+                                   value="False")
                 model_argument = gr.Accordion("模型参数配置")
 
                 with model_argument:
 
-                    VECTOR_SEARCH_TOP_K = gr.Slider(1,
-                                                    10,
-                                                    value=6,
-                                                    step=1,
-                                                    label="vector search top k",
-                                                    interactive=True)
+                    VECTOR_SEARCH_TOP_K = gr.Slider(
+                        1,
+                        10,
+                        value=6,
+                        step=1,
+                        label="vector search top k",
+                        interactive=True)
 
                     HISTORY_LEN = gr.Slider(0,
                                             3,
@@ -220,12 +227,11 @@ if __name__ == "__main__":
                                            label="temperature",
                                            interactive=True)
                     top_p = gr.Slider(0,
-                                      1,
-                                      value=0.9,
-                                      step=0.1,
-                                      label="top_p",
-                                      interactive=True)
-
+                                      1,
+                                      value=0.9,
+                                      step=0.1,
+                                      label="top_p",
+                                      interactive=True)
 
             with gr.Column(scale=4):
                 chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
@@ -240,7 +246,8 @@ if __name__ == "__main__":
                           inputs=[
                               message, large_language_model,
                               embedding_model, file, VECTOR_SEARCH_TOP_K,
-                              HISTORY_LEN, temperature, top_p, use_web,state
+                              HISTORY_LEN, temperature, top_p, use_web,
+                              state
                           ],
                           outputs=[message, chatbot, state])
         clear_history.click(fn=clear_session,
@@ -253,7 +260,7 @@ if __name__ == "__main__":
                          message, large_language_model,
                          embedding_model, file,
                          VECTOR_SEARCH_TOP_K, HISTORY_LEN,
-                         temperature, top_p, use_web,state
+                         temperature, top_p, use_web, state
                      ],
                      outputs=[message, chatbot, state])
     gr.Markdown("""提醒:<br>
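
Editor's note: the heart of the commit is the new embedding dispatch in init_knowledge_vector_store. Below is a self-contained sketch of that branch for reviewers: the "ViT-B-32" entry routes through Jina's hosted inference service, authenticated with a token read from the jina_auth_token environment variable exactly as the diff does, while every other entry keeps the existing local sentence-transformers path. The build_embeddings helper name is hypothetical; app.py inlines this logic.

import os

import sentence_transformers
import torch
from langchain.embeddings import JinaEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

embedding_model_dict = {
    "text2vec-base": "GanymedeNil/text2vec-base-chinese",
    "ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k',
}


def build_embeddings(embedding_model):
    # Hypothetical helper mirroring the new branch in init_knowledge_vector_store.
    if embedding_model == "ViT-B-32":
        # New path: vectors come from Jina's hosted CLIP model; the auth
        # token is expected in the environment, as in the diff above.
        jina_auth_token = os.getenv('jina_auth_token')
        return JinaEmbeddings(
            jina_auth_token=jina_auth_token,
            model_name=embedding_model_dict[embedding_model])
    # Existing path: a local sentence-transformers model on the chosen device.
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_dict[embedding_model])
    embeddings.client = sentence_transformers.SentenceTransformer(
        embeddings.model_name, device=DEVICE)
    return embeddings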
requirements.txt CHANGED
@@ -15,4 +15,5 @@ gradio
 nltk
 torch
 torchvision
-
+protobuf==3.19
+jina
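
The two new pins support the feature above: jina provides the client for the hosted embedding service, and protobuf is held at 3.19, presumably for compatibility with the jina client and the existing torch stack (the commit does not say). A minimal smoke test of the new path, assuming a valid token is exported as jina_auth_token; the sample text is illustrative only:

import os

from langchain.embeddings import JinaEmbeddings

emb = JinaEmbeddings(jina_auth_token=os.getenv('jina_auth_token'),
                     model_name='ViT-B-32::laion2b-s34b-b79k')
vectors = emb.embed_documents(["知识库测试文本"])  # one vector per input text
print(len(vectors), len(vectors[0]))  # document count, embedding dimensionality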