Spaces: Runtime error
yuxj committed
Commit 4f65819 • 1 Parent(s): 9fb44c2
add files
Browse files:
- app.py +101 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,101 @@
import torch
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg
import time
import gradio as gr

def best_device():
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

device = best_device()
embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese', model_kwargs={'device': device})
# Note: this path is a Kaggle dataset mount; on a Hugging Face Space the directory
# does not exist, which is a likely cause of the Space's "Runtime error" status.
local_db = FAISS.load_local('/kaggle/input/text2vec', embeddings)

model_name = 'THUDM/chatglm-6b-int4'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if device == 'cuda':
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda().eval()
elif device == 'mps':
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().to('mps').eval()
else:
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float().eval()

def local_query(text, top_k=3):
    # Retrieve chunks from the local FAISS index; a lower score means a closer match.
    docs_and_scores = local_db.similarity_search_with_score(text)
    docs_and_scores.sort(key=lambda pair: pair[1])
    local_content = ''
    for doc, _score in docs_and_scores[:top_k]:
        # Drop stray spaces left over in the chunked Chinese text.
        local_content += doc.page_content.replace(' ', '') + '\n'
    return local_content

def web_search(text, limit=3):
    # Fetch snippet bodies from DuckDuckGo; fall back to an empty string on failure.
    web_content = ''
    try:
        results = ddg(text)
        if results:
            for result in results[:limit]:
                web_content += result['body'] + '\n'
    except Exception:
        print(f'网络检索异常:{text}')  # web search failed for this query
    return web_content

def ask_question(question, local_content='', web_content=''):
    # Build the prompt. With retrieved context the model is told to answer only
    # from the given information, to say it cannot answer when the context is
    # insufficient, not to fabricate, and to answer in Chinese; without context
    # it is told to answer concisely and admit when it does not know.
    # The original code wrapped `question` in the no-context template first and
    # then embedded that wrapped text inside the context templates, nesting two
    # sets of instructions; the branches below use the raw question instead.
    instruction = ('基于以下已知信息,简洁和专业的来回答我的问题。\n'
                   '如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",'
                   '不允许在答案中添加编造成分,答案请使用中文。\n')
    if web_content and local_content:
        prompt = f'{instruction}已知信息是:\n{web_content}\n{local_content}\n我的问题是:\n{question}'
    elif web_content:
        prompt = f'{instruction}已知信息是:\n{web_content}\n我的问题是:\n{question}'
    elif local_content:
        prompt = f'{instruction}已知信息是:\n{local_content}\n我的问题是:\n{question}'
    else:
        prompt = f'简洁和专业的来回答我的问题。\n如果你不知道答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n我的问题是:\n{question}'
    response, _history = model.chat(tokenizer, prompt, history=[], max_length=10000, temperature=0.1)
    return response

def on_click(question, kb_types):
    if best_device() == 'cuda':
        torch.cuda.empty_cache()

    print(f'问题 [{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}]: \n{question}\n\n')  # log the question
    local_content = ''
    if '结合本地数据' in kb_types:  # "use local data" checkbox
        local_content = local_query(question, 2)
    web_content = ''
    if '结合网络检索' in kb_types:  # "use web search" checkbox
        web_content = web_search(question, 3)

    # Log which knowledge sources were actually used for this answer.
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    if local_content and web_content:
        print(f'结合本地数据和网络检索 [{timestamp}]: ')  # local data + web search
    elif local_content:
        print(f'结合本地数据 [{timestamp}]: ')  # local data only
    elif web_content:
        print(f'结合网络检索 [{timestamp}]: ')  # web search only
    else:
        print(f'仅用模型数据 [{timestamp}]: ')  # model knowledge only
    result = ask_question(question, local_content, web_content)
    print(f'{result}\n\n----------------------------')

    if best_device() == 'cuda':
        torch.cuda.empty_cache()
    return result

with gr.Blocks() as block:
    gr.Markdown('<center><h1>LLM问答机器人测试</h1></center>')  # "LLM Q&A bot test"
    # The checkbox values double as the flags checked in on_click, so they must stay in sync.
    cg_type = gr.CheckboxGroup(['结合本地数据', '结合网络检索'], label='知识库类型(不勾选则仅用模型数据):')
    tb_input = gr.Textbox(label='输入问题(本地数据只有中国历史知识):')  # question input; local data covers Chinese history only
    btn = gr.Button('测试', variant='primary')  # "Test"
    tb_output = gr.Textbox(label='AI回答:')  # "AI answer"
    btn.click(fn=on_click, inputs=[tb_input, cg_type], outputs=tb_output)
block.queue(concurrency_count=1)  # Gradio 3.x queue API
block.launch()
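
Note: the ddg helper imported at the top of app.py was deprecated in later duckduckgo_search releases and eventually removed, so with unpinned requirements this import alone can crash the Space at startup. Below is a minimal sketch of a replacement for web_search, assuming the DDGS class API of newer duckduckgo_search releases; treat it as an untested substitute, not part of this commit:

# Hypothetical rewrite of web_search() for duckduckgo_search versions
# where ddg() is gone and DDGS().text() is the search entry point.
from duckduckgo_search import DDGS

def web_search(text, limit=3):
    web_content = ''
    try:
        with DDGS() as ddgs:
            # Each result is a dict with 'title', 'href' and 'body' keys.
            for result in ddgs.text(text, max_results=limit):
                web_content += result['body'] + '\n'
    except Exception:
        print(f'网络检索异常:{text}')  # web search failed for this query
    return web_content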
requirements.txt
ADDED
@@ -0,0 +1,10 @@
transformers
sentencepiece
cpm_kernels
accelerate
langchain
unstructured
sentence_transformers
duckduckgo_search
gradio
faiss-cpu
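
Since the Space's status is "Runtime error" and several of these unpinned dependencies have since made breaking changes (the langchain import paths used in app.py were removed, duckduckgo_search dropped ddg(), and Gradio 4 removed queue(concurrency_count=...)), pinning upper bounds is one plausible fix. The bounds below are assumptions about which major versions still carry the old APIs, not tested values:

transformers
sentencepiece
cpm_kernels
accelerate
langchain<0.1        # keeps the langchain.embeddings / langchain.vectorstores import paths
unstructured
sentence_transformers
duckduckgo_search<3  # keeps ddg() as the primary search helper
gradio<4             # keeps queue(concurrency_count=...)
faiss-cpu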