KevinHuSh committed · Commit e32ef75 · Parent: 34b2ab3

Test chat API and refine ppt chunker (#42)

api/apps/conversation_app.py CHANGED
@@ -17,7 +17,7 @@ from flask import request
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
-from api.db.services.llm_service import LLMService, TenantLLMService
+from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs):
         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
 
-    model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
-    if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
-
     question = messages[-1]["content"]
-    embd_mdl = TenantLLMService.model_instance(
-        dialog.tenant_id, LLMType.EMBEDDING.value)
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
     kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
@@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs):
     used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
-    mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
-    answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
+    answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
 
     answer = retrievaler.insert_citations(answer,
                                           [ck["content_ltks"] for ck in kbinfos["chunks"]],
api/db/db_models.py CHANGED
@@ -524,6 +524,7 @@ class Dialog(DataBaseModel):
     similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
     top_n = IntegerField(default=6)
+    do_refer = CharField(max_length=1, null=False, help_text="whether to insert reference markers into the answer (0: no, 1: yes)", default="1")
 
     kb_ids = JSONField(null=False, default=[])
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
api/db/services/llm_service.py CHANGED
@@ -14,12 +14,12 @@
 # limitations under the License.
 #
 from api.db.services.user_service import TenantService
-from rag.llm import EmbeddingModel, CvModel
+from api.settings import database_logger
+from rag.llm import EmbeddingModel, CvModel, ChatModel
 from api.db import LLMType
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
-from api.db import StatusEnum
 
 
 class LLMFactoriesService(CommonService):
@@ -37,13 +37,19 @@ class TenantLLMService(CommonService):
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
         objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
-        if not objs: return
+        if not objs:
+            return
         return objs[0]
 
     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
+        fields = [
+            cls.model.llm_factory,
+            LLMFactories.logo,
+            LLMFactories.tags,
+            cls.model.model_type,
+            cls.model.llm_name]
         objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
             cls.model.tenant_id == tenant_id).dicts()
 
@@ -51,23 +57,96 @@ class TenantLLMService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type):
-        e,tenant = TenantService.get_by_id(tenant_id)
-        if not e: raise LookupError("Tenant not found")
+    def model_instance(cls, tenant_id, llm_type, llm_name=None):
+        e, tenant = TenantService.get_by_id(tenant_id)
+        if not e:
+            raise LookupError("Tenant not found")
 
-        if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
-        elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
-        elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
-        elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
-        else: assert False, "LLM type error"
+        if llm_type == LLMType.EMBEDDING.value:
+            mdlnm = tenant.embd_id
+        elif llm_type == LLMType.SPEECH2TEXT.value:
+            mdlnm = tenant.asr_id
+        elif llm_type == LLMType.IMAGE2TEXT.value:
+            mdlnm = tenant.img2txt_id
+        elif llm_type == LLMType.CHAT.value:
+            mdlnm = tenant.llm_id if not llm_name else llm_name
+        else:
+            assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
-        if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
+        if not model_config:
+            raise LookupError("Model({}) not found".format(mdlnm))
         model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
-            if model_config["llm_factory"] not in EmbeddingModel: return
-            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
+            if model_config["llm_factory"] not in EmbeddingModel:
+                return
+            return EmbeddingModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
 
         if llm_type == LLMType.IMAGE2TEXT.value:
-            if model_config["llm_factory"] not in CvModel: return
-            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
+            if model_config["llm_factory"] not in CvModel:
+                return
+            return CvModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
+
+        if llm_type == LLMType.CHAT.value:
+            if model_config["llm_factory"] not in ChatModel:
+                return
+            return ChatModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
+
+    @classmethod
+    @DB.connection_context()
+    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
+        e, tenant = TenantService.get_by_id(tenant_id)
+        if not e:
+            raise LookupError("Tenant not found")
+
+        if llm_type == LLMType.EMBEDDING.value:
+            mdlnm = tenant.embd_id
+        elif llm_type == LLMType.SPEECH2TEXT.value:
+            mdlnm = tenant.asr_id
+        elif llm_type == LLMType.IMAGE2TEXT.value:
+            mdlnm = tenant.img2txt_id
+        elif llm_type == LLMType.CHAT.value:
+            mdlnm = tenant.llm_id if not llm_name else llm_name
+        else:
+            assert False, "LLM type error"
+
+        num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
+            .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
+            .execute()
+        return num
+
+
+class LLMBundle(object):
+    def __init__(self, tenant_id, llm_type, llm_name=None):
+        self.tenant_id = tenant_id
+        self.llm_type = llm_type
+        self.llm_name = llm_name
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
+        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
+
+    def encode(self, texts: list, batch_size=32):
+        emd, used_tokens = self.mdl.encode(texts, batch_size)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+        return emd, used_tokens
+
+    def encode_queries(self, query: str):
+        emd, used_tokens = self.mdl.encode_queries(query)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+        return emd, used_tokens
+
+    def describe(self, image, max_tokens=300):
+        txt, used_tokens = self.mdl.describe(image, max_tokens)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
+        return txt
+
+    def chat(self, system, history, gen_conf):
+        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
+        return txt
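
How the new pieces are meant to be called, as a usage sketch grounded in this diff; it assumes a running deployment with the tenant and its models already configured (the tenant id and model name below are invented):

```python
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

# Embedding bundle: resolves the tenant's embedding model and meters tokens.
embd_mdl = LLMBundle("tenant-0001", LLMType.EMBEDDING)
vecs, used_tokens = embd_mdl.encode(["first chunk", "second chunk"])

# Chat bundle: llm_name overrides the tenant default, mirroring conversation_app.py.
chat_mdl = LLMBundle("tenant-0001", LLMType.CHAT, "gpt-3.5-turbo")
answer = chat_mdl.chat("You are a helpful assistant.",
                       [{"role": "user", "content": "Say hi."}],
                       {"temperature": 0.7})
```
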
api/utils/file_utils.py CHANGED
@@ -143,11 +143,11 @@ def filename_type(filename):
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value
 
-    if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
+    if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value
 
     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         return FileType.AURAL.value
 
     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
         return FileType.VISUAL
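
The practical effect of the regex change is that .pptx files are now routed to the document pipeline. A quick standalone check (the pattern is copied from the + line above):

```python
import re

DOC_PATT = r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$"

for name in ("slides.pptx", "slides.ppt", "photo.png"):
    print(name, bool(re.match(DOC_PATT, name)))
# slides.pptx True   <- newly matched after this change
# slides.ppt  True
# photo.png   False
```
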
rag/llm/chat_model.py CHANGED
@@ -37,7 +37,7 @@ class GptTurbo(Base):
             model=self.model_name,
             messages=history,
             **gen_conf)
-        return res.choices[0].message.content.strip()
+        return res.choices[0].message.content.strip(), res.usage.completion_tokens
 
 
 from dashscope import Generation
@@ -56,5 +56,5 @@ class QWenChat(Base):
             result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content']
-        return response.message
+            return response.output.choices[0]['message']['content'], response.usage.output_tokens
+        return response.message, 0
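
Both chat backends now return a (text, used_tokens) pair instead of a bare string, which is what lets LLMBundle.chat book usage. A minimal conforming backend, purely illustrative:

```python
class EchoChat:
    def chat(self, system, history, gen_conf):
        reply = "echo: " + history[-1]["content"]
        return reply, len(reply.split())  # (text, rough token count)

txt, used_tokens = EchoChat().chat("sys", [{"role": "user", "content": "hello there"}], {})
print(txt, used_tokens)  # echo: hello there 3
```
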
rag/llm/cv_model.py CHANGED
@@ -72,7 +72,7 @@ class GptV4(Base):
             messages=self.prompt(b64),
             max_tokens=max_tokens,
         )
-        return res.choices[0].message.content.strip()
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
 
 
 class QWenCV(Base):
@@ -87,5 +87,5 @@ class QWenCV(Base):
         response = MultiModalConversation.call(model=self.model_name,
                                                messages=self.prompt(self.image2base64(image)))
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content']
-        return response.message
+            return response.output.choices[0]['message']['content'], response.usage.output_tokens
+        return response.message, 0
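
The vision models follow the same convention: describe() now returns (caption, used_tokens) so image-to-text usage can be metered too. Illustrative stub only:

```python
class StubCV:
    def describe(self, image, max_tokens=300):
        caption = "a slide with a bar chart"
        return caption, 17  # (text, token count a real backend would report)

txt, used_tokens = StubCV().describe(b"<image bytes>")
print(txt, used_tokens)
```
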
rag/llm/embedding_model.py CHANGED
@@ -36,6 +36,9 @@ class Base(ABC):
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")
 
+    def encode_queries(self, text: str):
+        raise NotImplementedError("Please implement encode_queries method!")
+
 
 class HuEmbedding(Base):
     def __init__(self, key="", model_name=""):
@@ -68,15 +71,18 @@ class HuEmbedding(Base):
 
 class OpenAIEmbed(Base):
     def __init__(self, key, model_name="text-embedding-ada-002"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name
 
     def encode(self, texts: list, batch_size=32):
-        token_count = 0
-        for t in texts: token_count += num_tokens_from_string(t)
         res = self.client.embeddings.create(input=texts,
                                             model=self.model_name)
-        return [d["embedding"] for d in res["data"]], token_count
+        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=[text],
+                                            model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
 class QWenEmbed(Base):
@@ -84,16 +90,28 @@ class QWenEmbed(Base):
         dashscope.api_key = key
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32, text_type="document"):
+    def encode(self, texts: list, batch_size=10):
         import dashscope
         res = []
         token_count = 0
-        for txt in texts:
+        texts = [txt[:2048] for txt in texts]
+        for i in range(0, len(texts), batch_size):
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
-                input=txt[:2048],
-                text_type=text_type
+                input=texts[i:i+batch_size],
+                text_type="document"
             )
-            res.append(resp["output"]["embeddings"][0]["embedding"])
-            token_count += resp["usage"]["total_tokens"]
-        return res, token_count
+            embds = [[]] * len(resp["output"]["embeddings"])
+            for e in resp["output"]["embeddings"]:
+                embds[e["text_index"]] = e["embedding"]
+            res.extend(embds)
+            token_count += resp["usage"]["input_tokens"]
+        return np.array(res), token_count
+
+    def encode_queries(self, text):
+        resp = dashscope.TextEmbedding.call(
+            model=self.model_name,
+            input=text[:2048],
+            text_type="query"
+        )
+        return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
 
 
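The batched QWen encode slots each returned embedding back by its text_index, because a batch API may return items out of order. The core of that bookkeeping, standalone with fake response data (no dashscope call involved):

```python
resp = {"output": {"embeddings": [
    {"text_index": 1, "embedding": [0.1, 0.2]},
    {"text_index": 0, "embedding": [0.9, 0.8]},
]}}

embds = [[]] * len(resp["output"]["embeddings"])
for e in resp["output"]["embeddings"]:
    embds[e["text_index"]] = e["embedding"]

print(embds)  # [[0.9, 0.8], [0.1, 0.2]] -- restored to input order
```
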
rag/nlp/huchunk.py CHANGED
@@ -11,6 +11,11 @@ from io import BytesIO
 
 class HuChunker:
 
+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None
+
     def __init__(self):
         self.MAX_LVL = 12
         self.proj_patt = [
@@ -228,11 +233,6 @@ class HuChunker:
 
 class PdfChunker(HuChunker):
 
-    @dataclass
-    class Fields:
-        text_chunks: List = None
-        table_chunks: List = None
-
     def __init__(self, pdf_parser):
         self.pdf = pdf_parser
         super().__init__()
@@ -293,11 +293,6 @@ class PdfChunker(HuChunker):
 
 class DocxChunker(HuChunker):
 
-    @dataclass
-    class Fields:
-        text_chunks: List = None
-        table_chunks: List = None
-
     def __init__(self, doc_parser):
         self.doc = doc_parser
         super().__init__()
@@ -344,11 +339,6 @@ class DocxChunker(HuChunker):
 
 class ExcelChunker(HuChunker):
 
-    @dataclass
-    class Fields:
-        text_chunks: List = None
-        table_chunks: List = None
-
     def __init__(self, excel_parser):
         self.excel = excel_parser
         super().__init__()
@@ -370,18 +360,51 @@ class PptChunker(HuChunker):
     def __init__(self):
         super().__init__()
 
+    def __extract(self, shape):
+        if shape.shape_type == 19:
+            tb = shape.table
+            rows = []
+            for i in range(1, len(tb.rows)):
+                rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
+            return "\n".join(rows)
+
+        if shape.has_text_frame:
+            return shape.text_frame.text
+
+        if shape.shape_type == 6:
+            texts = []
+            for p in shape.shapes:
+                t = self.__extract(p)
+                if t: texts.append(t)
+            return "\n".join(texts)
+
     def __call__(self, fnm):
         from pptx import Presentation
         ppt = Presentation(fnm) if isinstance(
             fnm, str) else Presentation(
             BytesIO(fnm))
-        flds = self.Fields()
-        flds.text_chunks = []
+        txts = []
         for slide in ppt.slides:
+            texts = []
             for shape in slide.shapes:
-                if hasattr(shape, "text"):
-                    flds.text_chunks.append((shape.text, None))
+                txt = self.__extract(shape)
+                if txt: texts.append(txt)
+            txts.append("\n".join(texts))
+
+        import aspose.slides as slides
+        import aspose.pydrawing as drawing
+        imgs = []
+        with slides.Presentation(BytesIO(fnm)) as presentation:
+            for slide in presentation.slides:
+                buffered = BytesIO()
+                slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
+                imgs.append(buffered.getvalue())
+        assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
+
+        flds = self.Fields()
+        flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
         flds.table_chunks = []
+
        return flds
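The key refinement in PptChunker is that __extract recurses into group shapes (MSO shape type 6) and flattens tables (type 19) into "header: cell" rows, where the old code only read top-level shapes that happened to have a text attribute. The traversal logic, sketched standalone over a fake shape tree (python-pptx not required):

```python
class FakeShape:
    def __init__(self, shape_type=None, text=None, shapes=()):
        self.shape_type = shape_type
        self._text = text
        self.shapes = shapes  # children; only meaningful for group shapes

    @property
    def has_text_frame(self):
        return self._text is not None

    @property
    def text_frame(self):
        return type("TF", (), {"text": self._text})

def extract(shape):
    if shape.has_text_frame:
        return shape.text_frame.text
    if shape.shape_type == 6:  # group shape: recurse into children
        return "\n".join(t for t in (extract(s) for s in shape.shapes) if t)
    return ""

group = FakeShape(6, shapes=(FakeShape(text="title"), FakeShape(text="bullet")))
print(extract(group))  # prints "title" and "bullet" on separate lines
```
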
rag/nlp/search.py CHANGED
@@ -58,7 +58,8 @@ class Dealer:
         if req["available_int"] == 0:
             bqry.filter.append(Q("range", available_int={"lt": 1}))
         else:
-            bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1})))
+            bqry.filter.append(
+                Q("bool", must_not=Q("range", available_int={"lt": 1})))
         bqry.boost = 0.05
 
         s = Search()
@@ -87,9 +88,12 @@ class Dealer:
         q_vec = []
         if req.get("vector"):
             assert emb_mdl, "No embedding model selected"
-            s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps)
+            s["knn"] = self._vector(
+                qst, emb_mdl, req.get(
+                    "similarity", 0.4), ps)
             s["knn"]["filter"] = bqry.to_dict()
-            if "highlight" in s: del s["highlight"]
+            if "highlight" in s:
+                del s["highlight"]
             q_vec = s["knn"]["query_vector"]
         es_logger.info("【Q】: {}".format(json.dumps(s)))
         res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
@@ -175,7 +179,8 @@ class Dealer:
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]
 
-    def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7):
+    def insert_citations(self, answer, chunks, chunk_v,
+                         embd_mdl, tkweight=0.3, vtweight=0.7):
         pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         for i in range(1, len(pieces)):
             if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@@ -184,47 +189,57 @@ class Dealer:
         idx = []
         pieces_ = []
         for i, t in enumerate(pieces):
-            if len(t) < 5: continue
+            if len(t) < 5:
+                continue
             idx.append(i)
             pieces_.append(t)
         es_logger.info("{} => {}".format(answer, pieces_))
-        if not pieces_: return answer
+        if not pieces_:
+            return answer
 
-        ans_v, c = embd_mdl.encode(pieces_)
+        ans_v, _ = embd_mdl.encode(pieces_)
         assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
             len(ans_v[0]), len(chunk_v[0]))
 
         chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks]
         cites = {}
-        for i,a in enumerate(pieces_):
+        for i, a in enumerate(pieces_):
             sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
                                                             chunk_v,
-                                                            huqie.qie(pieces_[i]).split(" "),
+                                                            huqie.qie(
+                                                                pieces_[i]).split(" "),
                                                             chunks_tks,
                                                             tkweight, vtweight)
             mx = np.max(sim) * 0.99
-            if mx < 0.55: continue
-            cites[idx[i]] = list(set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
+            if mx < 0.55:
+                continue
+            cites[idx[i]] = list(
+                set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
 
         res = ""
-        for i,p in enumerate(pieces):
+        for i, p in enumerate(pieces):
             res += p
-            if i not in idx:continue
-            if i not in cites:continue
-            res += "##%s$$"%"$".join(cites[i])
+            if i not in idx:
+                continue
+            if i not in cites:
+                continue
+            res += "##%s$$" % "$".join(cites[i])
 
         return res
 
-    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
+    def rerank(self, sres, query, tkweight=0.3,
+               vtweight=0.7, cfield="content_ltks"):
         ins_embd = [
             Dealer.trans2floats(
-                sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
+                sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
         if not ins_embd:
             return [], [], []
-        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
+        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
+                  for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
-                                                        huqie.qie(query).split(" "),
+                                                        huqie.qie(
+                                                            query).split(" "),
                                                         ins_tw, tkweight, vtweight)
         return sim, tksim, vtsim
@@ -237,7 +252,8 @@ class Dealer:
     def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
-        if not question: return ranks
+        if not question:
+            return ranks
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
                "question": question, "vector": True,
                "similarity": similarity_threshold}
rag/svr/parse_user_docs.py CHANGED
@@ -49,7 +49,7 @@ from rag.nlp.huchunk import (
 )
 from api.db import LLMType
 from api.db.services.document_service import DocumentService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import TenantLLMService, LLMBundle
 from api.settings import database_logger
 from api.utils import get_format_time
 from api.utils.file_utils import get_project_base_directory
@@ -62,7 +62,7 @@ EXC = ExcelChunker(ExcelParser())
 PPT = PptChunker()
 
 
-def chuck_doc(name, binary, cvmdl=None):
+def chuck_doc(name, binary, tenant_id, cvmdl=None):
     suff = os.path.split(name)[-1].lower().split(".")[-1]
     if suff.find("pdf") >= 0:
         return PDF(binary)
@@ -127,7 +127,7 @@ def build(row, cvmdl):
             100., "Finished preparing! Start to slice file!", True)
     try:
         cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
-        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
+        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl)
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             set_progress(
@@ -236,12 +236,14 @@ def main(comm, mod):
 
     tmf = open(tm_fnm, "a+")
     for _, r in rows.iterrows():
-        embd_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.EMBEDDING)
-        if not embd_mdl:
-            set_progress(r["id"], -1, "Can't find embedding model!")
-            cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
+        try:
+            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+            cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
+            # TODO: sequence2text model
+        except Exception as e:
+            set_progress(r["id"], -1, str(e))
             continue
-        cv_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
+
         st_tm = timer()
         cks = build(r, cv_mdl)
         if not cks:
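
The worker loop now treats a missing model as a per-document failure rather than a crash: LLMBundle raises, the except records the reason via set_progress, and the loop moves on. The shape of that pattern, standalone with invented stand-ins:

```python
def make_bundle(tenant_id):
    # Stand-in for LLMBundle: raises when the tenant has no model configured.
    if tenant_id == "t-unconfigured":
        raise LookupError("Model not found for {}".format(tenant_id))
    return "mdl:" + tenant_id

def process(rows, set_progress):
    done = []
    for r in rows:
        try:
            mdl = make_bundle(r["tenant_id"])
        except Exception as e:
            set_progress(r["id"], -1, str(e))  # surface the reason, skip the doc
            continue
        done.append((r["id"], mdl))
    return done

rows = [{"id": 1, "tenant_id": "t-ok"}, {"id": 2, "tenant_id": "t-unconfigured"}]
print(process(rows, lambda rid, prog, msg: print("doc", rid, "failed:", msg)))
# doc 2 failed: Model not found for t-unconfigured
# [(1, 'mdl:t-ok')]
```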