KevinHuSh committed on
Commit
cfd888e
·
1 Parent(s): bbbfe3a

deal with stop reason being length problem (#109)

Browse files
api/apps/conversation_app.py CHANGED
@@ -176,7 +176,7 @@ def chat(dialog, messages, **kwargs):
176
  if not llm:
177
  raise LookupError("LLM(%s) not found" % dialog.llm_id)
178
  llm = llm[0]
179
- question = messages[-1]["content"]
180
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
181
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
182
 
@@ -184,7 +184,7 @@ def chat(dialog, messages, **kwargs):
184
  ## try to use sql if field mapping is good to go
185
  if field_map:
186
  stat_logger.info("Use SQL to retrieval.")
187
- markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
188
  if markdown_tbl:
189
  return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
190
 
@@ -195,7 +195,9 @@ def chat(dialog, messages, **kwargs):
195
  if p["key"] not in kwargs:
196
  prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
197
 
198
- kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
 
 
199
  dialog.similarity_threshold,
200
  dialog.vector_similarity_weight, top=1024, aggs=False)
201
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
@@ -224,13 +226,14 @@ def chat(dialog, messages, **kwargs):
224
 
225
 
226
  def use_sql(question, field_map, tenant_id, chat_mdl):
227
- sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
228
  user_promt = """
229
  表名:{};
230
  数据库表字段说明如下:
231
  {}
232
 
233
- 问题:{}
 
234
  请写出SQL,且只要SQL,不要有其他说明及文字。
235
  """.format(
236
  index_name(tenant_id),
 
176
  if not llm:
177
  raise LookupError("LLM(%s) not found" % dialog.llm_id)
178
  llm = llm[0]
179
+ questions = [m["content"] for m in messages if m["role"] == "user"]
180
  embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
181
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
182
 
 
184
  ## try to use sql if field mapping is good to go
185
  if field_map:
186
  stat_logger.info("Use SQL to retrieval.")
187
+ markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
188
  if markdown_tbl:
189
  return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
190
 
 
195
  if p["key"] not in kwargs:
196
  prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
197
 
198
+ for _ in range(len(questions)//2):
199
+ questions.append(questions[-1])
200
+ kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
201
  dialog.similarity_threshold,
202
  dialog.vector_similarity_weight, top=1024, aggs=False)
203
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
 
226
 
227
 
228
  def use_sql(question, field_map, tenant_id, chat_mdl):
229
+ sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
230
  user_promt = """
231
  表名:{};
232
  数据库表字段说明如下:
233
  {}
234
 
235
+ 问题如下:
236
+ {}
237
  请写出SQL,且只要SQL,不要有其他说明及文字。
238
  """.format(
239
  index_name(tenant_id),
api/apps/user_app.py CHANGED
@@ -100,12 +100,14 @@ def github_callback():
100
  if len(users) > 1: raise Exception('Same E-mail exist!')
101
  user = users[0]
102
  login_user(user)
 
103
  except Exception as e:
104
  rollback_user_registration(user_id)
105
  stat_logger.exception(e)
106
  return redirect("/?error=%s"%str(e))
107
-
108
- return redirect("/?auth=%s"%user_id)
 
109
 
110
 
111
  def user_info_from_github(access_token):
 
100
  if len(users) > 1: raise Exception('Same E-mail exist!')
101
  user = users[0]
102
  login_user(user)
103
+ return redirect("/?auth=%s"%user.get_id())
104
  except Exception as e:
105
  rollback_user_registration(user_id)
106
  stat_logger.exception(e)
107
  return redirect("/?error=%s"%str(e))
108
+ user = users[0]
109
+ login_user(user)
110
+ return redirect("/?auth=%s" % user.get_id())
111
 
112
 
113
  def user_info_from_github(access_token):
deepdoc/vision/t_recognizer.py CHANGED
@@ -28,7 +28,7 @@ def main(args):
28
  images, outputs = init_in_out(args)
29
  if args.mode.lower() == "layout":
30
  labels = LayoutRecognizer.labels
31
- detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
32
  if args.mode.lower() == "tsr":
33
  labels = TableStructureRecognizer.labels
34
  detr = TableStructureRecognizer()
 
28
  images, outputs = init_in_out(args)
29
  if args.mode.lower() == "layout":
30
  labels = LayoutRecognizer.labels
31
+ detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
32
  if args.mode.lower() == "tsr":
33
  labels = TableStructureRecognizer.labels
34
  detr = TableStructureRecognizer()
rag/app/presentation.py CHANGED
@@ -73,12 +73,13 @@ class Pdf(PdfParser):
73
  return res
74
 
75
 
76
- def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
77
  """
78
  The supported file formats are pdf, pptx.
79
  Every page will be treated as a chunk. And the thumbnail of every page will be stored.
80
  PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
81
  """
 
82
  doc = {
83
  "docnm_kwd": filename,
84
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
@@ -98,8 +99,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
98
  for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
99
  d = copy.deepcopy(doc)
100
  d["image"] = img
101
- d["page_num_obj"] = [pn+1]
102
- tokenize(d, txt, pdf_parser.is_english)
 
 
103
  res.append(d)
104
  return res
105
 
 
73
  return res
74
 
75
 
76
+ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
77
  """
78
  The supported file formats are pdf, pptx.
79
  Every page will be treated as a chunk. And the thumbnail of every page will be stored.
80
  PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
81
  """
82
+ eng = lang.lower() == "english"
83
  doc = {
84
  "docnm_kwd": filename,
85
  "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
 
99
  for pn, (txt,img) in enumerate(pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)):
100
  d = copy.deepcopy(doc)
101
  d["image"] = img
102
+ d["page_num_int"] = [pn+1]
103
+ d["top_int"] = [0]
104
+ d["position_int"].append((pn + 1, 0, img.size[0], 0, img.size[1]))
105
+ tokenize(d, txt, eng)
106
  res.append(d)
107
  return res
108
 
rag/llm/chat_model.py CHANGED
@@ -14,9 +14,13 @@
14
  # limitations under the License.
15
  #
16
  from abc import ABC
 
 
17
  from openai import OpenAI
18
  import openai
19
 
 
 
20
 
21
  class Base(ABC):
22
  def __init__(self, key, model_name):
@@ -34,13 +38,17 @@ class GptTurbo(Base):
34
  def chat(self, system, history, gen_conf):
35
  if system: history.insert(0, {"role": "system", "content": system})
36
  try:
37
- res = self.client.chat.completions.create(
38
  model=self.model_name,
39
  messages=history,
40
  **gen_conf)
41
- return res.choices[0].message.content.strip(), res.usage.completion_tokens
 
 
 
 
42
  except openai.APIError as e:
43
- return "ERROR: "+str(e), 0
44
 
45
 
46
  from dashscope import Generation
@@ -59,9 +67,16 @@ class QWenChat(Base):
59
  result_format='message',
60
  **gen_conf
61
  )
 
 
62
  if response.status_code == HTTPStatus.OK:
63
- return response.output.choices[0]['message']['content'], response.usage.output_tokens
64
- return "ERROR: " + response.message, 0
 
 
 
 
 
65
 
66
 
67
  from zhipuai import ZhipuAI
@@ -73,11 +88,16 @@ class ZhipuChat(Base):
73
  def chat(self, system, history, gen_conf):
74
  from http import HTTPStatus
75
  if system: history.insert(0, {"role": "system", "content": system})
76
- response = self.client.chat.completions.create(
77
- self.model_name,
78
- messages=history,
79
- **gen_conf
80
- )
81
- if response.status_code == HTTPStatus.OK:
82
- return response.output.choices[0]['message']['content'], response.usage.completion_tokens
83
- return "ERROR: " + response.message, 0
 
 
 
 
 
 
14
  # limitations under the License.
15
  #
16
  from abc import ABC
17
+ from copy import deepcopy
18
+
19
  from openai import OpenAI
20
  import openai
21
 
22
+ from rag.nlp import is_english
23
+
24
 
25
  class Base(ABC):
26
  def __init__(self, key, model_name):
 
38
  def chat(self, system, history, gen_conf):
39
  if system: history.insert(0, {"role": "system", "content": system})
40
  try:
41
+ response = self.client.chat.completions.create(
42
  model=self.model_name,
43
  messages=history,
44
  **gen_conf)
45
+ ans = response.output.choices[0]['message']['content'].strip()
46
+ if response.output.choices[0].get("finish_reason", "") == "length":
47
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
48
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
49
+ return ans, response.usage.completion_tokens
50
  except openai.APIError as e:
51
+ return "**ERROR**: "+str(e), 0
52
 
53
 
54
  from dashscope import Generation
 
67
  result_format='message',
68
  **gen_conf
69
  )
70
+ ans = ""
71
+ tk_count = 0
72
  if response.status_code == HTTPStatus.OK:
73
+ ans += response.output.choices[0]['message']['content']
74
+ tk_count += response.usage.output_tokens
75
+ if response.output.choices[0].get("finish_reason", "") == "length":
76
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
77
+ return ans, tk_count
78
+
79
+ return "**ERROR**: " + response.message, tk_count
80
 
81
 
82
  from zhipuai import ZhipuAI
 
88
  def chat(self, system, history, gen_conf):
89
  from http import HTTPStatus
90
  if system: history.insert(0, {"role": "system", "content": system})
91
+ try:
92
+ response = self.client.chat.completions.create(
93
+ self.model_name,
94
+ messages=history,
95
+ **gen_conf
96
+ )
97
+ ans = response.output.choices[0]['message']['content'].strip()
98
+ if response.output.choices[0].get("finish_reason", "") == "length":
99
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
100
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
101
+ return ans, response.usage.completion_tokens
102
+ except Exception as e:
103
+ return "**ERROR**: " + str(e), 0
rag/nlp/search.py CHANGED
@@ -224,12 +224,13 @@ class Dealer:
224
  chunks_tks,
225
  tkweight, vtweight)
226
  mx = np.max(sim) * 0.99
227
- if mx < 0.35:
228
  continue
229
  cites[idx[i]] = list(
230
  set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
231
 
232
  res = ""
 
233
  for i, p in enumerate(pieces):
234
  res += p
235
  if i not in idx:
@@ -237,7 +238,10 @@ class Dealer:
237
  if i not in cites:
238
  continue
239
  for c in cites[i]: assert int(c) < len(chunk_v)
240
- for c in cites[i]: res += f" ##{c}$$"
 
 
 
241
 
242
  return res
243
 
@@ -318,7 +322,7 @@ class Dealer:
318
  if dnm not in ranks["doc_aggs"]:
319
  ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
320
  ranks["doc_aggs"][dnm]["count"] += 1
321
- ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
322
 
323
  return ranks
324
 
 
224
  chunks_tks,
225
  tkweight, vtweight)
226
  mx = np.max(sim) * 0.99
227
+ if mx < 0.66:
228
  continue
229
  cites[idx[i]] = list(
230
  set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
231
 
232
  res = ""
233
+ seted = set([])
234
  for i, p in enumerate(pieces):
235
  res += p
236
  if i not in idx:
 
238
  if i not in cites:
239
  continue
240
  for c in cites[i]: assert int(c) < len(chunk_v)
241
+ for c in cites[i]:
242
+ if c in seted:continue
243
+ res += f" ##{c}$$"
244
+ seted.add(c)
245
 
246
  return res
247
 
 
322
  if dnm not in ranks["doc_aggs"]:
323
  ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
324
  ranks["doc_aggs"][dnm]["count"] += 1
325
+ ranks["doc_aggs"] = []#[{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
326
 
327
  return ranks
328