KevinHuSh
committed
Commit: 22390c0
Parent: d13f144
fix disable and enable llm setting in dialog (#616)
### What problem does this PR solve?
#614
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/dialog_app.py +1 -7
- api/utils/file_utils.py +1 -1
- rag/app/qa.py +19 -6
- rag/llm/chat_model.py +6 -6
- rag/llm/embedding_model.py +2 -2
api/apps/dialog_app.py
CHANGED
@@ -35,13 +35,7 @@ def set_dialog():
     top_n = req.get("top_n", 6)
     similarity_threshold = req.get("similarity_threshold", 0.1)
     vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
-    llm_setting = req.get("llm_setting", {
-        "temperature": 0.1,
-        "top_p": 0.3,
-        "frequency_penalty": 0.7,
-        "presence_penalty": 0.4,
-        "max_tokens": 215
-    })
+    llm_setting = req.get("llm_setting", {})
     default_prompt = {
         "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
以下是知识库:
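The change drops the hard-coded LLM defaults, so a dialog saved without a given parameter no longer gets it silently re-enabled. A minimal sketch of the before/after behavior of the changed line, using only standard dict semantics (`req` stands in for the parsed request body):

```python
# Sketch of the before/after behavior of the changed line in set_dialog().
req_without_llm_setting = {"top_n": 6}

# Before: a request that omits llm_setting got these hard-coded values back,
# so a setting disabled in the UI would reappear.
old_default = {
    "temperature": 0.1,
    "top_p": 0.3,
    "frequency_penalty": 0.7,
    "presence_penalty": 0.4,
    "max_tokens": 215,
}
before = req_without_llm_setting.get("llm_setting", old_default)

# After: the default is an empty dict, so only the settings the client
# actually sends are stored with the dialog.
after = req_without_llm_setting.get("llm_setting", {})

print(before)  # {'temperature': 0.1, 'top_p': 0.3, ...}
print(after)   # {}
```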
api/utils/file_utils.py
CHANGED
@@ -67,7 +67,7 @@ def get_rag_python_directory(*args):
 
 
 def get_home_cache_dir():
-    dir = os.path.join(os.path.expanduser('~'), ".
+    dir = os.path.join(os.path.expanduser('~'), ".ragflow")
     try:
        os.mkdir(dir)
    except OSError as error:
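The removed line is truncated in this view; the new line points the cache at `~/.ragflow`. A small standalone sketch of the same pattern (hypothetical `home_cache_dir` helper, not the repo function):

```python
import os

def home_cache_dir() -> str:
    """Sketch of the pattern in get_home_cache_dir(): build ~/.ragflow,
    create it if missing, and ignore the error if it already exists."""
    path = os.path.join(os.path.expanduser("~"), ".ragflow")
    try:
        os.mkdir(path)
    except OSError:
        pass  # already exists (or not creatable); callers just use the path
    return path

print(home_cache_dir())  # e.g. /home/user/.ragflow
```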
rag/app/qa.py
CHANGED
@@ -116,18 +116,31 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                         break
                     txt += l
         lines = txt.split("\n")
-
+        comma, tab = 0, 0
+        for l in lines:
+            if len(l.split(",")) == 2: comma += 1
+            if len(l.split("\t")) == 2: tab += 1
+        delimiter = "\t" if tab >= comma else ","
+
         fails = []
-
-
+        question, answer = "", ""
+        i = 0
+        while i < len(lines):
+            arr = lines[i].split(delimiter)
             if len(arr) != 2:
-
-
-
+                if question: answer += "\n" + lines[i]
+                else:
+                    fails.append(str(i+1))
+            elif len(arr) == 2:
+                if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng))
+                question, answer = arr
+            i += 1
             if len(res) % 999 == 0:
                 callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
                     f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
 
+        if question: res.append(beAdoc(deepcopy(doc), question, answer, eng))
+
         callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
 
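The new parser first guesses whether the text file is comma- or tab-delimited, then treats every two-field line as the start of a new Q&A pair and folds any following non-matching lines into the current answer. A standalone sketch of that logic (hypothetical `parse_qa` helper, not the repo function), returning `(question, answer)` tuples instead of calling `beAdoc`:

```python
def parse_qa(lines):
    """Sketch of the delimiter guess + multi-line answer handling from the diff."""
    # Guess the delimiter: whichever splits more lines into exactly two fields wins.
    comma = sum(1 for l in lines if len(l.split(",")) == 2)
    tab = sum(1 for l in lines if len(l.split("\t")) == 2)
    delimiter = "\t" if tab >= comma else ","

    pairs, fails = [], []
    question, answer = "", ""
    for i, line in enumerate(lines):
        arr = line.split(delimiter)
        if len(arr) != 2:
            if question:
                # continuation line: append to the answer of the open pair
                answer += "\n" + line
            else:
                fails.append(str(i + 1))
        else:
            # a new pair starts; flush the previous one first
            if question and answer:
                pairs.append((question, answer))
            question, answer = arr
    if question:
        pairs.append((question, answer))
    return pairs, fails

pairs, fails = parse_qa(["Q1\tA1", "Q2\tA2 line1", "A2 line2"])
print(pairs)  # [('Q1', 'A1'), ('Q2', 'A2 line1\nA2 line2')]
print(fails)  # []
```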
rag/llm/chat_model.py
CHANGED
@@ -141,12 +141,12 @@ class OllamaChat(Base):
         if system:
             history.insert(0, {"role": "system", "content": system})
         try:
-            options = {
-
-
-
-
-
+            options = {}
+            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
             response = self.client.chat(
                 model=self.model_name,
                 messages=history,
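Instead of always sending a full options dict, the Ollama call now only forwards the generation parameters that are actually present in `gen_conf`, so a disabled setting is simply omitted and Ollama falls back to its own default. A minimal sketch of that mapping, assuming `gen_conf` is a plain dict (note that, as in the diff, `max_tokens` becomes `num_predict` and `top_p` is forwarded under the `top_k` key):

```python
# Sketch of the conditional option mapping from the diff: only keys present in
# gen_conf end up in options; absent keys are left to Ollama's defaults.
KEY_MAP = {
    "temperature": "temperature",
    "max_tokens": "num_predict",
    "top_p": "top_k",          # forwarded under top_k, as in the commit
    "presence_penalty": "presence_penalty",
    "frequency_penalty": "frequency_penalty",
}

def build_ollama_options(gen_conf: dict) -> dict:
    return {ollama_key: gen_conf[conf_key]
            for conf_key, ollama_key in KEY_MAP.items()
            if conf_key in gen_conf}

print(build_ollama_options({"temperature": 0.2}))           # {'temperature': 0.2}
print(build_ollama_options({}))                              # {} -> Ollama defaults apply
print(build_ollama_options({"max_tokens": 256, "top_p": 0.3}))
# {'num_predict': 256, 'top_k': 0.3}
```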
rag/llm/embedding_model.py
CHANGED
@@ -236,8 +236,8 @@ class YoudaoEmbed(Base):
         try:
             print("LOADING BCE...")
             YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
-
-                "
+                get_home_cache_dir(),
+                "bce-embedding-base_v1"))
         except Exception as e:
             YoudaoEmbed._client = qanthing(
                 model_name_or_path=model_name.replace(
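With this change the BCE weights are looked up under the shared home cache directory rather than the previous location (truncated in this view). A small sketch of how the resolved path looks, with a stand-in for `get_home_cache_dir()` behaving as in `api/utils/file_utils.py` above:

```python
import os

# Stand-in for api.utils.file_utils.get_home_cache_dir(), per the diff above.
def get_home_cache_dir():
    path = os.path.join(os.path.expanduser("~"), ".ragflow")
    os.makedirs(path, exist_ok=True)  # simplified; the real helper uses mkdir + except OSError
    return path

# Path now passed to the BCE embedding loader (qanthing) in YoudaoEmbed:
model_path = os.path.join(get_home_cache_dir(), "bce-embedding-base_v1")
print(model_path)  # e.g. /home/user/.ragflow/bce-embedding-base_v1
```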