Create app.py
app.py
ADDED
@@ -0,0 +1,525 @@
from gradio_client import Client
client = Client("https://svjack-entity-property-extractor-zh.hf.space")

import pandas as pd
import numpy as np
import os
import re

from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain import chains
from rapidfuzz import fuzz

from huggingface_hub import snapshot_download

# Fetch the book-chunk dataset, the prebuilt FAISS index and the GGUF model on first run.
if not os.path.exists("genshin_book_chunks_with_qa_sp"):
    path = snapshot_download(
        repo_id="svjack/genshin_book_chunks_with_qa_sp",
        repo_type="dataset",
        local_dir="genshin_book_chunks_with_qa_sp",
        local_dir_use_symlinks = False
    )

if not os.path.exists("bge_small_book_chunks_prebuld"):
    path = snapshot_download(
        repo_id="svjack/bge_small_book_chunks_prebuld",
        repo_type="dataset",
        local_dir="bge_small_book_chunks_prebuld",
        local_dir_use_symlinks = False
    )

if not os.path.exists("mistral-7b"):
    path = snapshot_download(
        repo_id="svjack/mistral-7b",
        repo_type="model",
        local_dir="mistral-7b",
        local_dir_use_symlinks = False
    )

'''
query = "警察是如何破获邪恶计划的?" ## 警 执律 盗
k = 10
uniform_recall_docs_to_pairwise_cos(
    query,
    docsearch_bge_loaded.similarity_search_with_score(query, k = k, ),
    bge_book_embeddings
)
'''
def uniform_recall_docs_to_pairwise_cos(query, doc_list, embeddings):
    assert type(doc_list) == type([])
    from langchain.evaluation import load_evaluator
    from langchain.evaluation import EmbeddingDistance
    hf_evaluator = load_evaluator("pairwise_embedding_distance", embeddings=embeddings,
        distance_metric = EmbeddingDistance.COSINE)
    return sorted(pd.Series(doc_list).map(lambda x: x[0].page_content).map(lambda x:
        (x, hf_evaluator.evaluate_string_pairs(prediction=query, prediction_b=x)["score"])
    ).values.tolist(), key = lambda t2: t2[1])

'''
sort_by_kw("深渊使徒", book_df)["content_chunks_formatted"].head(5).values.tolist() ### 深渊
'''
def sort_by_kw(kw, book_df):
    req = book_df.copy()
    req["sim_score"] = req.apply(
        lambda x:
        max(map(lambda y: fuzz.ratio(y, kw), eval(x["person"]) + eval(x["locate"]) + eval(x["locate"]))) if \
        eval(x["person"]) + eval(x["locate"]) + eval(x["locate"]) else 0
        , axis = 1
    )
    req = req.sort_values(by = "sim_score", ascending = False)
    return req

def recall_chuncks(query, docsearch, embedding, book_df,
    sparse_threshold = 30,
    dense_top_k = 10,
    rerank_by = "emb",
    ):
    sparse_output = sort_by_kw(query, book_df)[["content_chunks_formatted", "sim_score"]]
    sparse_output_list = sparse_output[
        sparse_output["sim_score"] >= sparse_threshold
    ]["content_chunks_formatted"].values.tolist()
    dense_output = uniform_recall_docs_to_pairwise_cos(
        query,
        docsearch.similarity_search_with_score(query, k = dense_top_k,),
        embedding
    )
    for chunck, score in dense_output:
        if chunck not in sparse_output_list:
            sparse_output_list.append(chunck)
    if rerank_by == "emb":
        from langchain.evaluation import load_evaluator
        from langchain.evaluation import EmbeddingDistance
        hf_evaluator = load_evaluator("pairwise_embedding_distance", embeddings=embedding,
            distance_metric = EmbeddingDistance.COSINE)
        return pd.Series(sorted(pd.Series(sparse_output_list).map(lambda x:
            (x, hf_evaluator.evaluate_string_pairs(prediction=query, prediction_b=x)["score"])
        ).values.tolist(), key = lambda t2: t2[1])).map(lambda x: x[0]).values.tolist()
    else:
        sparse_output_list = sorted(sparse_output_list, key = lambda x: fuzz.ratio(x, query), reverse = True)
        return sparse_output_list
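
# A minimal usage sketch of the hybrid recall above (not executed by the app):
# `sort_by_kw` supplies fuzzy keyword ("sparse") hits, the FAISS store supplies
# embedding ("dense") hits, and the union is re-ranked by cosine distance.
# It assumes `book_df`, `docsearch_bge_loaded` and `bge_book_embeddings` have
# already been built as they are further down in this file.
'''
chunks = recall_chuncks(
    "岩王帝君是一个什么样的人?",
    docsearch_bge_loaded,
    bge_book_embeddings,
    book_df,
    sparse_threshold = 30,
    dense_top_k = 10,
    rerank_by = "emb",
)
print(chunks[:3])
'''
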
def reduce_list_by_order(text_list, as_text = False):
    if not text_list:
        return
    df = pd.DataFrame(
        list(map(lambda x: (x.split("\n")[0], x.split("\n")[1], "\n".join(x.split("\n")[2:])), text_list))
    ).groupby([0, 1])[2].apply(list).map(lambda x: sorted(x, key = len, reverse=True)).map(
        "\n\n".join
    ).reset_index()
    d = dict(df.apply(lambda x: ((x.iloc[0], x.iloc[1]), x.iloc[2]), axis = 1).values.tolist())
    #return df
    order_list = []
    for x in text_list:
        a, b = x.split("\n")[0], x.split("\n")[1]
        if not order_list:
            order_list = [[a, [b]]]
        elif a in list(map(lambda t2: t2[0], order_list)):
            order_list[list(map(lambda t2: t2[0], order_list)).index(a)][1].append(b)
        elif a not in list(map(lambda t2: t2[0], order_list)):
            order_list.append([a, [b]])
    df = pd.DataFrame(pd.DataFrame(order_list).explode(1).dropna().apply(
        lambda x: (x.iloc[0], x.iloc[1], d[(x.iloc[0], x.iloc[1])]), axis = 1
    ).values.tolist()).drop_duplicates()
    if as_text:
        return "\n\n".join(
            df.apply(lambda x: "{}\n{}\n{}".format(x.iloc[0], x.iloc[1], x.iloc[2]), axis = 1).values.tolist()
        )
    return df
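
# Illustrative sketch (assumption: each chunk is "title\nsub_title\nbody", as produced
# by `content_chunks_formatted` below). `reduce_list_by_order` merges bodies that share
# the same title/sub_title pair while keeping the first-seen order of titles, roughly:
'''
reduce_list_by_order(
    ["书A\n章一\n内容1", "书A\n章一\n内容2", "书B\n章二\n内容3"],
    as_text = True
)
# -> "书A\n章一\n内容1\n\n内容2\n\n书B\n章二\n内容3"
'''
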
def build_gpt_prompt(query, docsearch, embedding, book_df, max_context_length = 4090):
    l = recall_chuncks(query, docsearch, embedding, book_df)
    context = reduce_list_by_order(l, as_text = True)
    context_l = []
    for ele in context.split("\n"):
        if sum(map(len, context_l)) >= max_context_length:
            break
        context_l.append(ele)
    context = "\n".join(context_l).strip()
    template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答案。尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。

{context}

问题: {question}
有用的回答:"""
    return template.format(
        **{
            "context": context,
            "question": query
        }
    )
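
# Hedged usage sketch (not executed by the app): `build_gpt_prompt` recalls chunks for
# the query, merges them with `reduce_list_by_order`, truncates the context to roughly
# `max_context_length` characters, and fills the Chinese QA template above.
'''
prompt = build_gpt_prompt("岩王帝君是一个什么样的人?",
    docsearch_bge_loaded, bge_book_embeddings, book_df)
print(prompt[:500])
'''
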
def collect_prompt_to_hist_list(prompt, add_assistant = False):
    l = pd.Series(prompt.split("\n\n")).map(lambda x: x.strip()).values.tolist()
    ll = []
    for ele in l:
        if not ll:
            ll.append(ele)
        else:
            if ele.startswith("文章标题:") or ele.startswith("问题:"):
                ll.append(ele)
            else:
                ll[-1] += ("\n\n" + ele)
    if add_assistant:
        ll_ = []
        for i in range(len(ll)):
            if i == 0:
                ll_.append((ll[i], "好的。"))
            elif i < len(ll) - 1:
                ll_.append((ll[i], "我读懂了。"))
            else:
                ll_.append((ll[i], ""))
        return ll_
    else:
        return ll
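
# Note on the splitting convention used above: the prompt built by `build_gpt_prompt`
# is cut at blank lines, and a new element starts whenever a block begins with
# "文章标题:" or "问题:", so the result is roughly
# [instruction, chunk_1, ..., chunk_n, question]; with add_assistant = True each
# element is paired with a filler assistant reply ("好的。" / "我读懂了。").
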
def row_to_content_ask(r):
    question = r["question"]
    content_list = r["content_list"]
    assert type(content_list) == type([])
    content_prompt_list = pd.Series(content_list).map(
        lambda x: '''
{}\n从上面的相关的叙述中抽取包含"{}"中词汇的相关语段。
'''.format(x, question).strip()
    ).values.tolist()
    return content_prompt_list

def entity_extractor_by_llm(query, llm, show_process = False, max_length = 512):
    import re
    hist = [['请从下面的句子中提取实体和属性。不需要进行进一步解释。', '好的。'],
        ['宁波在哪个省份?', '实体:宁波 属性:省份'],
        ['中国的货币是什么?', '实体:中国 属性:货币'],
        ['百慕大三角在什么地方?', '实体:百慕大三角 属性:地方'],
        ['谁是最可爱的人?', "实体:人 属性:可爱"],
        ['黄河的拐点在哪里?', "实体:黄河 属性:拐点"],
        #["玉米的引进时间是什么时候?", ""]
    ]

    re_hist = pd.DataFrame(
        pd.Series(hist).map(
            lambda t2: (
                {
                    "role": "user",
                    "content": t2[0]
                },
                {
                    "role": "assistant",
                    "content": t2[1]
                },
            )
        ).values.tolist()).values.reshape([-1]).tolist()

    out = llm.create_chat_completion(
        messages = re_hist + [
            {
                "role": "user",
                #"content": prompt + "如果没有提到相关内容,请回答不知道。使用中文进行回答,不要包含任何英文。"
                "content": query
            }
        ],
        stream=True
    )
    out_text = ""
    for chunk in out:
        delta = chunk['choices'][0]['delta']
        if "content" in delta:
            out_text += delta['content']
        from IPython.display import clear_output
        clear_output(wait=True)
        if show_process:
            print(out_text)
        if len(out_text) >= max_length:
            break
    e_list = re.findall(r"实体(.*?)属性", out_text.replace("\n", " "))
    if e_list:
        return re.findall(u"[\u4e00-\u9fa5]+", e_list[0])
    return None

def unzip_string(x, size = 2):
    if len(x) <= size:
        return [x]
    req = []
    for i in range(len(x) - size + 1):
        req.append(x[i: i + size])
    return req
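
# `unzip_string` slides a fixed-size window over a string; with size = 2 it yields the
# character bigrams that are later matched against the recalled chunks, e.g.
# unzip_string("深渊使徒", 2) -> ["深渊", "渊使", "使徒"].
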
def entity_extractor_by_adapter(x):
    import json
    result = client.predict(
        x,  # str in 'question' Textbox component
        api_name="/predict"
    )
    with open(result, "r") as f:
        req = json.load(f)
    req_list = req.get("E-TAG", [])
    req_ = []
    for ele in req_list:
        for x in unzip_string(ele, 2):
            if x not in req_:
                req_.append(x)
    return req_
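
# Hedged usage sketch: the extractor defers to the remote Space declared at the top of
# this file via `gradio_client`; the "E-TAG" key and the JSON file layout are whatever
# that Space emits, so treat the exact return value below as an assumption.
'''
entity_extractor_by_adapter("岩王帝君和归终是什么关系?")
# -> a list of character bigrams taken from the extracted entities
'''
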
def query_content_ask_func(question, content_list,
    llm, setfit_model, show_process = False, max_length = 1024):
    ask_list = row_to_content_ask(
        {
            "question": question,
            "content_list": content_list
        }
    )
    #return ask_list
    req = []
    for prompt in ask_list:
        out = llm.create_chat_completion(
            messages = [
                {
                    "role": "user",
                    "content": prompt + "如果没有提到相关内容,请回答不知道。使用中文进行回答,不要包含任何英文。"
                }
            ],
            stream=True
        )
        out_text = ""
        for chunk in out:
            delta = chunk['choices'][0]['delta']
            if "content" in delta:
                out_text += delta['content']
            from IPython.display import clear_output
            clear_output(wait=True)
            if show_process:
                print(out_text)
            if len(out_text) >= max_length:
                break
        req.append(out_text)
    d = {
        "question": question,
        "content_list": content_list
    }
    assert len(req) == len(ask_list)
    d["question_content_relate_list"] = req
    d["relate_prob_list"] = setfit_model.predict_proba(
        req
    ).numpy()[:, 1].tolist()
    return d
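
# Hedged sketch: for every recalled chunk, the model is asked (in Chinese) whether the
# chunk contains passages related to the question, and the SetFit classifier turns each
# free-text reply into a relevance probability (column 1 of predict_proba). Assumes the
# `llm` and `setfit_model` objects created further down in this file.
'''
d = query_content_ask_func("岩王帝君是一个什么样的人?",
    ["文章标题:...\n子标题:...\n内容:..."],
    llm, setfit_model)
# d["relate_prob_list"] -> one probability per chunk
'''
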
def build_relate_ask_list(query, docsearch_bge_loaded, bge_book_embeddings, book_df,
    llm, setfit_model, as_content_score_df = True,
    show_process = False, add_relate_entities = False,
    max_length = 1024):
    prompt = build_gpt_prompt(query, docsearch_bge_loaded, bge_book_embeddings, book_df)
    prompt_list = collect_prompt_to_hist_list(prompt)
    question = prompt_list[-1].split("\n")[0]
    content_list = prompt_list[1:-1]

    d = query_content_ask_func(question, content_list,
        llm, setfit_model, show_process = show_process)

    #entity_list = entity_extractor_by_llm(query, llm, show_process = show_process, max_length = max_length)
    entity_list = entity_extractor_by_adapter(query)
    if type(entity_list) != type([]):
        entity_list = []

    d["in_content_entity_list"] = list(map(lambda x:
        list(filter(lambda e: e in x, entity_list))
        , d["content_list"]))

    if add_relate_entities:
        relate_content_entity_list = [[]] * len(content_list)

        for entity in entity_list:
            entity_content_score_d = query_content_ask_func(entity, d["content_list"],
                llm, setfit_model, show_process = show_process)
            lookup_df = pd.DataFrame(
                list(zip(*[entity_content_score_d["content_list"],
                    entity_content_score_d["relate_prob_list"]]))
            )
            for ii, (i, r) in enumerate(lookup_df.iterrows()):
                if r.iloc[1] >= 0.5 and entity not in relate_content_entity_list[ii]:
                    #relate_content_entity_list[ii].append(entity)
                    relate_content_entity_list[ii] = relate_content_entity_list[ii] + [entity]

        d["relate_content_entity_list"] = relate_content_entity_list

    if as_content_score_df:
        if add_relate_entities:
            df = pd.concat(
                [
                    pd.Series(d["content_list"]).map(lambda x: x.strip()),
                    pd.Series(d["in_content_entity_list"]),
                    pd.Series(d["relate_content_entity_list"]),
                    pd.Series(d["question_content_relate_list"]).map(lambda x: x.strip()),
                    pd.Series(d["relate_prob_list"])
                ], axis = 1
            )
            df.columns = ["content", "entities", "relate_entities", "relate_eval_str", "score"]
        else:
            df = pd.concat(
                [
                    pd.Series(d["content_list"]).map(lambda x: x.strip()),
                    pd.Series(d["in_content_entity_list"]),
                    #pd.Series(d["relate_content_entity_list"]),
                    pd.Series(d["question_content_relate_list"]).map(lambda x: x.strip()),
                    pd.Series(d["relate_prob_list"])
                ], axis = 1
            )
            df.columns = ["content", "entities", "relate_eval_str", "score"]
        req = []
        entities_num_list = df["entities"].map(len).drop_duplicates().dropna().sort_values(ascending = False).\
            values.tolist()
        for e_num in entities_num_list:
            req.append(
                df[
                    df["entities"].map(lambda x: len(x) == e_num)
                ].sort_values(by = "score", ascending = False)
            )
        return pd.concat(req, axis = 0)
        #df = df.sort_values(by = "score", ascending = False)
        #return df
    return d

def mistral_predict(prompt, llm, show_process = True, max_length = 512):
    out = llm.create_chat_completion(
        messages = [] + [
            {
                "role": "user",
                #"content": prompt + "如果没有提到相关内容,请回答不知道。使用中文进行回答,不要包含任何英文。"
                "content": prompt
            }
        ],
        stream=True
    )
    from IPython.display import clear_output
    out_text = ""
    for chunk in out:
        delta = chunk['choices'][0]['delta']
        if "content" in delta:
            out_text += delta['content']
        if show_process:
            print(out_text)
        if len(out_text) >= max_length:
            break
        clear_output(wait=True)
    clear_output(wait=True)
    return out_text
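
# Minimal sketch of calling the local GGUF model directly (assumes `llm` is the
# llama-cpp-python instance created below); streaming stops once `max_length`
# characters have been produced.
'''
answer = mistral_predict("用一句话介绍璃月。", llm, show_process = False, max_length = 128)
'''
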
def run_all(query, docsearch_bge_loaded, bge_book_embeddings, book_df,
    llm, setfit_model, only_return_prompt = False):
    df = build_relate_ask_list(query, docsearch_bge_loaded, bge_book_embeddings, book_df,
        llm, setfit_model, show_process=False)
    info_list = df[
        df.apply(
            lambda x: x["score"] >= 0.5 and bool(x["entities"]), axis = 1
        )
    ].values.tolist()
    if not info_list:
        return df, info_list, "没有相关内容,谢谢你的提问。"
    prompt = '''
问题: {}
根据下面的内容回答上面的问题,如果无法根据内容确定答案,请回答不知道。
{}
'''.format(query, "\n\n".join(pd.Series(info_list).map(lambda x: x[0]).values.tolist()))
    if only_return_prompt:
        return df, info_list, prompt
    out = mistral_predict(prompt + "\n使用中文进行回答,不要包含任何英文。", llm)
    return df, info_list, out
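
# End-to-end sketch mirroring what the Gradio callback below does (assumes the globals
# built below): recall chunks, score them with the SetFit relevance classifier, keep
# chunks with score >= 0.5 that contain a query entity, then answer with Mistral-7B.
'''
df, info_list, answer = run_all("丘丘人有哪些生活习惯?",
    docsearch_bge_loaded, bge_book_embeddings, book_df,
    llm, setfit_model = setfit_model)
print(answer)
'''
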
#book_df = pd.read_csv("genshin_book_chunks_with_qa_sp.csv")
book_df = pd.read_csv("genshin_book_chunks_with_qa_sp/genshin_book_chunks_with_qa_sp.csv")
book_df["content_chunks"].dropna().drop_duplicates().shape

book_df["content_chunks_formatted"] = book_df.apply(
    lambda x: "文章标题:{}\n子标题:{}\n内容:{}".format(x["title"], x["sub_title"], x["content_chunks"]),
    axis = 1
)

texts = book_df["content_chunks_formatted"].dropna().drop_duplicates().values.tolist()

#embedding_path = "bge-small-book-qa/"
embedding_path = "svjack/bge-small-book-qa"
bge_book_embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
docsearch_bge_loaded = FAISS.load_local("bge_small_book_chunks_prebuld/", bge_book_embeddings)

from llama_cpp import Llama
#true_path = "mistral-7b-instruct-v0.2.Q4_0.gguf"
true_path = "mistral-7b/mistral-7b-instruct-v0.2.Q4_0.gguf"

#### 16g +
# Set gpu_layers to the number of layers to offload to GPU. Set to 0 if no GPU acceleration is available on your system.
llm = Llama(
    model_path=true_path,  # Download the model file first
    n_ctx=8000,  # The max sequence length to use - note that longer sequence lengths require much more resources
    n_threads=8,  # The number of CPU threads to use, tailor to your system and the resulting performance
    n_gpu_layers=-1,  # The number of layers to offload to GPU, if you have GPU acceleration available
    chat_format="llama-2"
)

from setfit import SetFitModel
#setfit_model = SetFitModel.from_pretrained("setfit_info_cls")
setfit_model = SetFitModel.from_pretrained("svjack/setfit_info_cls")

import gradio as gr

with gr.Blocks() as demo:
    title = gr.HTML(
        """<h1 align="center"> <font size="+3"> Genshin Impact Book QA Mistral-7B Demo ⚡ </font> </h1>""",
        elem_id="title",
    )

    with gr.Column():
        with gr.Row():
            query = gr.Text(label = "输入问题:", lines = 1, interactive = True, scale = 5.0)
            run_button = gr.Button("得到答案")
        output = gr.Text(label = "回答:", lines = 5, interactive = True)
        recall_items = gr.JSON(label = "召回相关内容", interactive = False)

        with gr.Row():
            gr.Examples(
                [
                    '丘丘人有哪些生活习惯?',
                    #'岩王帝君和归终是什么关系?',
                    '盐之魔神的下场是什么样的?',
                    #'归终是谁?',
                    '岩王帝君是一个什么样的人?',
                    '铳枪手的故事内容是什么样的?',
                    #'白夜国的子民遭遇了什么?',
                    '大蛇居住在哪里?',
                    '珊瑚宫有哪些传说?',
                    '灵光颂的内容是什么样的?',
                    #'连心珠讲了一件什么事情?',
                    #'梓心是谁?',
                    '枫丹有哪些故事?',
                    '璃月有哪些故事?',
                    '轻策庄有哪些故事?',
                    '瑶光滩有哪些故事?',
                    '稻妻有哪些故事?',
                    '海祇岛有哪些故事?',
                    #'须弥有哪些故事?',
                    '蒙德有哪些故事?',
                    '璃月有哪些奇珍异宝?',
                    '狸猫和天狗是什么关系?',
                    '岩王帝君和归终是什么关系?',
                ],
                inputs = query,
                label = "被书目内容包含的问题"
            )
        with gr.Row():
            gr.Examples(
                [
                    '爱丽丝女士是可莉的妈妈吗?',
                    '摘星崖是什么样的?',
                    '丘丘人使用的是什么文字?',
                    '深渊使徒哪里来的?',
                    '发条机关可以用来做什么?',
                    '那先朱那做了什么?',
                ],
                inputs = query,
                label = "没有被书目明确提到的问题"
            )

    run_button.click(lambda x:
        run_all(x, docsearch_bge_loaded, bge_book_embeddings, book_df, llm,
            setfit_model = setfit_model)[1:],
        query, [recall_items, output]
    )

demo.queue(max_size=4, concurrency_count=1).launch(debug=True, show_api=False, share = True)