KevinHuSh committed · Commit d0db329 · Parent: cdd9565

add llm API (#19)

* add llm API
* refine llm API
- python/conf/mapping.json +0 -1
- python/conf/sys.cnf +6 -7
- python/llm/__init__.py +21 -2
- python/llm/chat_model.py +12 -10
- python/llm/cv_model.py +66 -0
- python/llm/embedding_model.py +32 -3
- python/nlp/huchunk.py +7 -3
- python/nlp/search.py +100 -71
- python/parser/excel_parser.py +4 -2
- python/parser/pdf_parser.py +18 -17
- python/svr/dialog_svr.py +3 -2
- python/svr/parse_user_docs.py +15 -4
- python/util/__init__.py +9 -4
- python/util/config.py +16 -10
- python/util/db_conn.py +17 -10
- python/util/es_conn.py +6 -5
- python/util/minio_conn.py +17 -19
python/conf/mapping.json CHANGED
@@ -121,7 +121,6 @@
       "match": "*_vec",
       "mapping": {
         "type": "dense_vector",
-        "dims": 1024,
         "index": true,
         "similarity": "cosine"
       }
python/conf/sys.cnf CHANGED
@@ -1,10 +1,9 @@
 [infiniflow]
 es=http://es01:9200
-
-
-
-
+postgres_user=root
+postgres_password=infiniflow_docgpt
+postgres_host=postgres
+postgres_port=5432
 minio_host=minio:9000
-
-
-
+minio_user=infiniflow
+minio_password=infiniflow_docgpt
python/llm/__init__.py CHANGED
@@ -1,2 +1,21 @@
-
-from .
+import os
+from .embedding_model import *
+from .chat_model import *
+from .cv_model import *
+
+EmbeddingModel = None
+ChatModel = None
+CvModel = None
+
+
+if os.environ.get("OPENAI_API_KEY"):
+    EmbeddingModel = GptEmbed()
+    ChatModel = GptTurbo()
+    CvModel = GptV4()
+
+elif os.environ.get("DASHSCOPE_API_KEY"):
+    EmbeddingModel = QWenEmbd()
+    ChatModel = QWenChat()
+    CvModel = QWenCV()
+else:
+    EmbeddingModel = HuEmbedding()
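The module acts as a small factory: importing it yields model singletons selected by environment variable, with the local HuEmbedding as the fallback when no API key is set. A minimal consumption sketch (hypothetical caller; assumes OPENAI_API_KEY is exported so ChatModel is not None):

    from llm import ChatModel, EmbeddingModel

    vecs = EmbeddingModel.encode(["hello world"])        # list/array of embeddings
    reply = ChatModel.chat(
        "You are a helpful assistant.",                  # system prompt
        [{"role": "user", "content": "Say hi."}],        # chat history
        {"temperature": 0.7})                            # gen_conf forwarded to the API
    print(reply)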
python/llm/chat_model.py CHANGED
@@ -1,7 +1,8 @@
 from abc import ABC
-import
+from openai import OpenAI
 import os
 
+
 class Base(ABC):
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")
@@ -9,26 +10,27 @@ class Base(ABC):
 
 class GptTurbo(Base):
     def __init__(self):
-
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
 
     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
-        res =
-
-
+        res = self.client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=history,
+            **gen_conf)
         return res.choices[0].message.content.strip()
 
 
-class
+class QWenChat(Base):
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
         from dashscope import Generation
-        from dashscope.api_entities.dashscope_response import Role
         # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
-
-
-
+            Generation.Models.qwen_turbo,
+            messages=history,
+            result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content']
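Note that both chat() implementations mutate the caller's history by inserting the system message at index 0, so repeated calls with the same list keep prepending. A defensive copy on the caller's side avoids that (hypothetical caller code):

    from llm import ChatModel

    history = [{"role": "user", "content": "你好"}]
    # Pass a copy so the original list is not prepended with the system prompt.
    reply = ChatModel.chat("You are concise.", list(history), {"temperature": 0.2})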
python/llm/cv_model.py ADDED
@@ -0,0 +1,66 @@
+from abc import ABC
+from openai import OpenAI
+import os
+import base64
+from io import BytesIO
+
+
+class Base(ABC):
+    def describe(self, image, max_tokens=300):
+        raise NotImplementedError("Please implement encode method!")
+
+    def image2base64(self, image):
+        if isinstance(image, BytesIO):
+            return base64.b64encode(image.getvalue()).decode("utf-8")
+        buffered = BytesIO()
+        try:
+            image.save(buffered, format="JPEG")
+        except Exception as e:
+            image.save(buffered, format="PNG")
+        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+    def prompt(self, b64):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{b64}"
+                        },
+                    },
+                ],
+            }
+        ]
+
+
+class GptV4(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def describe(self, image, max_tokens=300):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model="gpt-4-vision-preview",
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip()
+
+
+class QWenCV(Base):
+    def describe(self, image, max_tokens=300):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
+                                               messages=self.prompt(self.image2base64(image)))
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content']
+        return response.message
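The prompt string asks, in Chinese, for a detailed description of the image: time, place, people, events, and mood. A minimal sketch of invoking the GPT-4V backend on a local file (hypothetical path; assumes Pillow and OPENAI_API_KEY):

    from PIL import Image
    from llm.cv_model import GptV4

    img = Image.open("photo.jpg")                 # hypothetical input image
    print(GptV4().describe(img, max_tokens=200))  # Chinese caption text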
python/llm/embedding_model.py CHANGED
@@ -1,8 +1,11 @@
 from abc import ABC
+from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
+import os
 import numpy as np
 
+
 class Base(ABC):
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")
@@ -22,11 +25,37 @@ class HuEmbedding(Base):
 
         """
         self.model = FlagModel("BAAI/bge-large-zh-v1.5",
-
-
+                               query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                               use_fp16=torch.cuda.is_available())
 
     def encode(self, texts: list, batch_size=32):
         res = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
+            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
         return np.array(res)
+
+
+class GptEmbed(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model="text-embedding-ada-002")
+        return [d.embedding for d in res.data]
+
+
+class QWenEmbd(Base):
+    def encode(self, texts: list, batch_size=32, text_type="document"):
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        import dashscope
+        from http import HTTPStatus
+        res = []
+        for txt in texts:
+            resp = dashscope.TextEmbedding.call(
+                model=dashscope.TextEmbedding.Models.text_embedding_v2,
+                input=txt[:2048],
+                text_type=text_type
+            )
+            res.append(resp["output"]["embeddings"][0]["embedding"])
+        return res
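All three backends share the encode(texts) surface, so retrieval code stays backend-agnostic. A quick sketch of scoring documents against a query with the local BGE model (assumes FlagEmbedding is installed and the weights can be downloaded):

    import numpy as np
    from llm.embedding_model import HuEmbedding

    mdl = HuEmbedding()
    doc_vecs = mdl.encode(["深度学习简介", "Postgres tuning guide"])
    q = mdl.encode(["什么是深度学习"])[0]
    # Cosine similarity: dot product divided by the product of L2 norms.
    sims = doc_vecs @ q / (np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(q))
    print(sims)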
python/nlp/huchunk.py CHANGED
@@ -372,7 +372,9 @@ class PptChunker(HuChunker):
 
     def __call__(self, fnm):
         from pptx import Presentation
-        ppt = Presentation(fnm) if isinstance(
+        ppt = Presentation(fnm) if isinstance(
+            fnm, str) else Presentation(
+            BytesIO(fnm))
         flds = self.Fields()
         flds.text_chunks = []
         for slide in ppt.slides:
@@ -398,7 +400,8 @@ class TextChunker(HuChunker):
         mime = magic.Magic(mime=True)
         if isinstance(file_path, str):
             file_type = mime.from_file(file_path)
-        else:
+        else:
+            file_type = mime.from_buffer(file_path)
         if 'text' in file_type:
             return False
         else:
@@ -406,7 +409,8 @@ class TextChunker(HuChunker):
 
     def __call__(self, fnm):
         flds = self.Fields()
-        if self.is_binary_file(fnm):
+        if self.is_binary_file(fnm):
+            return flds
         with open(fnm, "r") as f:
             txt = f.read()
         flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
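TextChunker now bails out early on binary input, returning empty Fields instead of trying to read the bytes as text. A usage sketch (hypothetical file path):

    from nlp.huchunk import TextChunker

    flds = TextChunker()("notes.txt")    # hypothetical plain-text file
    for chunk, _ in flds.text_chunks:
        print(len(chunk), chunk[:40])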
python/nlp/search.py CHANGED
@@ -1,6 +1,6 @@
 import re
-from elasticsearch_dsl import Q,Search,A
-from typing import List, Optional, Tuple,Dict, Union
+from elasticsearch_dsl import Q, Search, A
+from typing import List, Optional, Tuple, Dict, Union
 from dataclasses import dataclass
 from util import setup_logging, rmSpace
 from nlp import huqie, query
@@ -9,18 +9,24 @@ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
 import numpy as np
 from copy import deepcopy
 
-
+
+def index_name(uid): return f"docgpt_{uid}"
+
 
 class Dealer:
     def __init__(self, es, emb_mdl):
         self.qryr = query.EsQueryer(es)
-        self.qryr.flds = [
+        self.qryr.flds = [
+            "title_tks^10",
+            "title_sm_tks^5",
+            "content_ltks^2",
+            "content_sm_ltks"]
         self.es = es
         self.emb_mdl = emb_mdl
 
     @dataclass
     class SearchResult:
-        total:int
+        total: int
         ids: List[str]
         query_vector: List[float] = None
         field: Optional[Dict] = None
@@ -42,71 +48,78 @@ class Dealer:
         keywords = []
         qst = req.get("question", "")
 
-        bqry,keywords = self.qryr.question(qst)
-        if req.get("kb_ids"):
+        bqry, keywords = self.qryr.question(qst)
+        if req.get("kb_ids"):
+            bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
         bqry.filter.append(Q("exists", field="q_tks"))
         bqry.boost = 0.05
         print(bqry)
 
         s = Search()
-        pg = int(req.get("page", 1))-1
+        pg = int(req.get("page", 1)) - 1
         ps = int(req.get("size", 1000))
         src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
                                 "image_id", "doc_id", "q_vec"])
 
-        s = s.query(bqry)[pg*ps:(pg+1)*ps]
+        s = s.query(bqry)[pg * ps:(pg + 1) * ps]
         s = s.highlight("content_ltks")
         s = s.highlight("title_ltks")
-        if not qst:
+        if not qst:
+            s = s.sort(
+                {"create_time": {"order": "desc", "unmapped_type": "date"}})
 
         s = s.highlight_options(
-
-
-
-
-
-
+            fragment_size=120,
+            number_of_fragments=5,
+            boundary_scanner_locale="zh-CN",
+            boundary_scanner="SENTENCE",
+            boundary_chars=",./;:\\!(),。?:!……()——、"
+        )
         s = s.to_dict()
         q_vec = []
-        if req.get("vector"):
+        if req.get("vector"):
             s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
             s["knn"]["filter"] = bqry.to_dict()
             del s["highlight"]
             q_vec = s["knn"]["query_vector"]
-        res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
+        res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
         print("TOTAL: ", self.es.getTotal(res))
         if self.es.getTotal(res) == 0 and "knn" in s:
-            bqry,_ = self.qryr.question(qst, min_match="10%")
-            if req.get("kb_ids"):
+            bqry, _ = self.qryr.question(qst, min_match="10%")
+            if req.get("kb_ids"):
+                bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
            s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.7
-            res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
+            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
 
         kwds = set([])
         for k in keywords:
             kwds.add(k)
             for kk in huqie.qieqie(k).split(" "):
-                if len(kk) < 2:
-
+                if len(kk) < 2:
+                    continue
+                if kk in kwds:
+                    continue
                 kwds.add(kk)
 
         aggs = self.getAggregation(res, "docnm_kwd")
 
         return self.SearchResult(
-            total
-            ids
-            query_vector
-            aggregation
-            highlight
-            field
-
-            keywords
+            total=self.es.getTotal(res),
+            ids=self.es.getDocIds(res),
+            query_vector=q_vec,
+            aggregation=aggs,
+            highlight=self.getHighlight(res),
+            field=self.getFields(res, ["docnm_kwd", "content_ltks",
+                                       "kb_id", "image_id", "doc_id", "q_vec"]),
+            keywords=list(kwds)
         )
 
     def getAggregation(self, res, g):
-        if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:
-
+        if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
+            return
+        bkts = res["aggregations"]["aggs_" + g]["buckets"]
         return [(b["key"], b["doc_count"]) for b in bkts]
 
     def getHighlight(self, res):
@@ -114,8 +127,11 @@ class Dealer:
         eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
         r = []
         for t in line.split(" "):
-            if not t:
-
+            if not t:
+                continue
+            if len(r) > 0 and len(
+                    t) > 0 and r[-1][-1] in eng and t[0] in eng:
+                r.append(" ")
             r.append(t)
         r = "".join(r)
         return r
@@ -123,66 +139,76 @@ class Dealer:
         ans = {}
         for d in res["hits"]["hits"]:
             hlts = d.get("highlight")
-            if not hlts:
+            if not hlts:
+                continue
             ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
         return ans
 
     def getFields(self, sres, flds):
         res = {}
-        if not flds:
-
-
-        for n
-
+        if not flds:
+            return {}
+        for d in self.es.getSource(sres):
+            m = {n: d.get(n) for n in flds if d.get(n) is not None}
+            for n, v in m.items():
+                if isinstance(v, type([])):
                     m[n] = "\t".join([str(vv) for vv in v])
                     continue
-                if
+                if not isinstance(v, type("")):
+                    m[n] = str(m[n])
                 m[n] = rmSpace(m[n])
 
-            if m:
+            if m:
+                res[d["id"]] = m
         return res
 
-
     @staticmethod
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]
 
+    def insert_citations(self, ans, top_idx, sres,
+                         vfield="q_vec", cfield="content_ltks"):
 
-
-
-
-        ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
+        ins_embd = [Dealer.trans2floats(
+            sres.field[sres.ids[i]][vfield]) for i in top_idx]
+        ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
         s = 0
         e = 0
         res = ""
+
         def citeit():
             nonlocal s, e, ans, res
-            if not ins_embd:
+            if not ins_embd:
+                return
             embd = self.emb_mdl.encode(ans[s: e])
-            sim = self.qryr.hybrid_similarity(embd,
-                ins_embd,
+            sim = self.qryr.hybrid_similarity(embd,
+                                              ins_embd,
                                               huqie.qie(ans[s:e]).split(" "),
                                               ins_tw)
             print(ans[s: e], sim)
-            mx = np.max(sim)*0.99
-            if mx < 0.55:
-
-
+            mx = np.max(sim) * 0.99
+            if mx < 0.55:
+                return
+            cita = list(set([top_idx[i]
+                        for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
+            for i in cita:
+                res += f"@?{i}?@"
 
             return cita
 
         punct = set(";。?!!")
-        if not self.qryr.isChinese(ans):
+        if not self.qryr.isChinese(ans):
             punct.add("?")
             punct.add(".")
         while e < len(ans):
             if e - s < 12 or ans[e] not in punct:
                 e += 1
                 continue
-            if ans[e] == "." and e+
+            if ans[e] == "." and e + \
+                    1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
                 e += 1
                 continue
-            if ans[e] == "." and e-2>=0 and
+            if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
                 e += 1
                 continue
             res += ans[s: e]
@@ -191,33 +217,36 @@ class Dealer:
             e += 1
             s = e
 
-        if s< len(ans):
+        if s < len(ans):
             res += ans[s:]
             citeit()
 
         return res
 
-
-
-        ins_embd = [
-
-
-
-
-
+    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
+               vfield="q_vec", cfield="content_ltks"):
+        ins_embd = [
+            Dealer.trans2floats(
+                sres.field[i]["q_vec"]) for i in sres.ids]
+        if not ins_embd:
+            return []
+        ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
+        # return CosineSimilarity([sres.query_vector], ins_embd)[0]
+        sim = self.qryr.hybrid_similarity(sres.query_vector,
+                                          ins_embd,
                                           huqie.qie(query).split(" "),
                                           ins_tw, tkweight, vtweight)
         return sim
 
 
-
-if __name__ == "__main__":
+if __name__ == "__main__":
     from util import es_conn
     SE = Dealer(es_conn.HuEs("infiniflow"))
     qs = [
         "胡凯",
         ""
     ]
-    for q in qs:
+    for q in qs:
         print(">>>>>>>>>>>>>>>>>>>>", q)
-        print(SE.search(
+        print(SE.search(
+            {"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
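insert_citations walks the answer sentence by sentence, embeds each span, and appends @?i?@ markers for the source chunks whose hybrid similarity clears 0.99 × max (with an absolute floor of 0.55). rerank applies the same hybrid score to the whole result set. The blend itself is a weighted sum of token similarity and vector similarity; a plain sketch under the default weights (tkweight=0.3, vtweight=0.7; hypothetical helper name):

    import numpy as np

    def hybrid_similarity_sketch(tksim, vtsim, tkweight=0.3, vtweight=0.7):
        """Weighted blend of token-overlap and embedding cosine scores."""
        return tkweight * np.array(tksim) + vtweight * np.array(vtsim)

    print(hybrid_similarity_sketch([0.2, 0.8], [0.9, 0.4]))   # [0.69 0.52]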
python/parser/excel_parser.py CHANGED
@@ -5,8 +5,10 @@ from io import BytesIO
 
 class HuExcelParser:
     def __call__(self, fnm):
-        if isinstance(fnm, str):
-
+        if isinstance(fnm, str):
+            wb = load_workbook(fnm)
+        else:
+            wb = load_workbook(BytesIO(fnm))
         res = []
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
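The same str-or-bytes dispatch recurs in the PPT and PDF parsers: a path is opened directly, raw bytes are wrapped in BytesIO first. The convention in isolation (a sketch, not project code):

    from io import BytesIO

    def open_source(fnm, open_path, open_stream):
        """Open a path directly, or wrap raw bytes for stream-based loaders."""
        return open_path(fnm) if isinstance(fnm, str) else open_stream(BytesIO(fnm))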
python/parser/pdf_parser.py CHANGED
@@ -53,7 +53,7 @@ class HuParser:
     def _y_dis(
             self, a, b):
         return (
-
+            b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
 
     def _match_proj(self, b):
         proj_patt = [
@@ -76,9 +76,9 @@ class HuParser:
         tks_down = huqie.qie(down["text"][:LEN]).split(" ")
         tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
         tks_all = up["text"][-LEN:].strip() \
-
-
-
+            + (" " if re.match(r"[a-zA-Z0-9]+",
+                               up["text"][-1] + down["text"][0]) else "") \
+            + down["text"][:LEN].strip()
         tks_all = huqie.qie(tks_all).split(" ")
         fea = [
             up.get("R", -1) == down.get("R", -1),
@@ -100,7 +100,7 @@ class HuParser:
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[\((][^\))]+$", up["text"])
-
+            and re.search(r"[\))]", down["text"]) else False,
             self._match_proj(down),
             True if re.match(r"[A-Z]", down["text"]) else False,
             True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -217,7 +217,7 @@ class HuParser:
         assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
             tp, btm, x0, x1, b)
         ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
-
+            x0 != 0 and btm - tp != 0 else 0
         if ov > 0 and ratio:
             ov /= (x1 - x0) * (btm - tp)
         return ov
@@ -382,7 +382,7 @@ class HuParser:
                 continue
             for tb in tbls:  # for table
                 left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
-
+                    tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                 left *= ZM
                 top *= ZM
                 right *= ZM
@@ -899,7 +899,7 @@ class HuParser:
             lst_r = rows[-1]
             if lst_r[-1].get("R", "") != b.get("R", "") \
                     or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
-
+                        ):  # new row
                 btm = b["bottom"]
                 b["rn"] += 1
                 rows.append([b])
@@ -949,9 +949,9 @@ class HuParser:
                 j += 1
                 continue
             f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
-
+                 [j - 1][0].get("text")) or j == 0
             ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
-
+                  [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
             if f and ff:
                 j += 1
                 continue
@@ -1012,9 +1012,9 @@ class HuParser:
                 i += 1
                 continue
             f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
-
+                 [jj][0].get("text")) or i == 0
             ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
-
+                  [jj][0].get("text")) or i + 1 >= len(tbl)
             if f and ff:
                 i += 1
                 continue
@@ -1169,8 +1169,8 @@ class HuParser:
                     else "") + headers[j - 1][k]
                 else:
                     headers[j][k] = headers[j - 1][k] \
-
-
+                        + ("的" if headers[j - 1][k] else "") \
+                        + headers[j][k]
 
         logging.debug(
             f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
@@ -1247,7 +1247,7 @@ class HuParser:
                 i += 1
                 continue
             lout_no = str(self.boxes[i]["page_number"]) + \
-
+                "-" + str(self.boxes[i]["layoutno"])
             if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
                                                                                   "figure caption", "reference"]:
                 nomerge_lout_no.append(lst_lout_no)
@@ -1526,7 +1526,8 @@ class HuParser:
         return "\n\n".join(res)
 
     def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
-        self.pdf = pdfplumber.open(fnm) if isinstance(
+        self.pdf = pdfplumber.open(fnm) if isinstance(
+            fnm, str) else pdfplumber.open(BytesIO(fnm))
         self.lefted_chars = []
         self.mean_height = []
         self.mean_width = []
@@ -1601,7 +1602,7 @@ class HuParser:
             self.page_images[pns[0]].crop((left * ZM, top * ZM,
                                            right *
                                            ZM, min(
-
+                                               bottom, self.page_images[pns[0]].size[1])
                                            ))
             )
             bottom -= self.page_images[pns[0]].size[1]
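_y_dis is the distance between the two boxes' vertical midpoints, and the overlap helper zeroes out degenerate boxes before normalizing. A worked check of the midpoint distance:

    # ((b_top + b_bottom) - (a_top + a_bottom)) / 2
    a = {"top": 10, "bottom": 20}   # midpoint 15
    b = {"top": 40, "bottom": 60}   # midpoint 50
    print((b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2)  # 35.0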
python/svr/dialog_svr.py CHANGED
@@ -16,11 +16,12 @@ from io import BytesIO
 from util import config
 from timeit import default_timer as timer
 from collections import OrderedDict
+from llm import ChatModel, EmbeddingModel
 
 SE = None
 CFIELD="content_ltks"
-EMBEDDING =
-LLM =
+EMBEDDING = EmbeddingModel
+LLM = ChatModel
 
 def get_QA_pairs(hists):
     pa = []
python/svr/parse_user_docs.py CHANGED
@@ -1,4 +1,4 @@
-import json, os, sys, hashlib, copy, time, random, re
+import json, os, sys, hashlib, copy, time, random, re
 from os.path import dirname, realpath
 sys.path.append(dirname(realpath(__file__)) + "/../")
 from util.es_conn import HuEs
@@ -7,10 +7,10 @@ from util.minio_conn import HuMinio
 from util import rmSpace, findMaxDt
 from FlagEmbedding import FlagModel
 from nlp import huchunk, huqie, search
-import base64, hashlib
 from io import BytesIO
 import pandas as pd
 from elasticsearch_dsl import Q
+from PIL import Image
 from parser import (
     PdfParser,
     DocxParser,
@@ -40,6 +40,15 @@ def chuck_doc(name, binary):
     if suff.find("doc") >= 0: return DOC(binary)
     if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
     if suff.find("ppt") >= 0: return PPT(binary)
+    if os.environ.get("PARSE_IMAGE") \
+            and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
+                          name.lower()):
+        from llm import CvModel
+        txt = CvModel.describe(binary)
+        field = TextChunker.Fields()
+        field.text_chunks = [(txt, binary)]
+        field.table_chunks = []
+        return field
 
     return TextChunker()(binary)
@@ -119,7 +128,6 @@ def build(row):
         set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
         return []
 
-    print(row["doc_name"], obj)
     if not obj.text_chunks and not obj.table_chunks:
         set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
         return []
@@ -146,7 +154,10 @@ def build(row):
         if not img:
             docs.append(d)
             continue
+
+        if isinstance(img, Image.Image): img.save(output_buffer, format='JPEG')
+        else: output_buffer = BytesIO(img)
+
         MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
                   output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
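chuck_doc dispatches on the file suffix: pdf/docx/xlsx/ppt go to dedicated parsers, recognized image extensions are captioned by the vision model when PARSE_IMAGE is set, and everything else falls through to TextChunker. A sketch of exercising the image branch (hypothetical file; assumes a configured vision backend):

    import os
    os.environ["PARSE_IMAGE"] = "1"        # enable the image branch

    with open("receipt.jpg", "rb") as f:   # hypothetical image file
        binary = f.read()
    fields = chuck_doc("receipt.jpg", binary)
    print(fields.text_chunks[0][0])        # model-generated caption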
python/util/__init__.py CHANGED
@@ -1,19 +1,24 @@
 import re
 
+
 def rmSpace(txt):
     txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
     return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
 
+
 def findMaxDt(fnm):
     m = "1970-01-01 00:00:00"
     try:
         with open(fnm, "r") as f:
             while True:
                 l = f.readline()
-                if not l:
+                if not l:
+                    break
                 l = l.strip("\n")
-                if l == 'nan':
-
+                if l == 'nan':
+                    continue
+                if l > m:
+                    m = l
     except Exception as e:
-        print("WARNING: can't find "+ fnm)
+        print("WARNING: can't find " + fnm)
     return m
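findMaxDt scans a file of timestamp lines and keeps the lexicographic maximum, which for zero-padded "YYYY-MM-DD HH:MM:SS" strings is also the chronological maximum. A quick check:

    # Zero-padded timestamps sort correctly as plain strings.
    lines = ["2023-11-02 08:00:00", "2023-11-30 23:59:59", "2023-11-15 12:00:00"]
    print(max(lines))   # 2023-11-30 23:59:59, what findMaxDt would return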
python/util/config.py CHANGED
@@ -1,25 +1,31 @@
-from configparser
-import os
+from configparser import ConfigParser
+import os
+import inspect
 
 CF = ConfigParser()
 __fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
-if not os.path.exists(__fnm):
-
-
+if not os.path.exists(__fnm):
+    __fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
+    assert os.path.exists(
+        __fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
+    if not os.path.exists(__fnm):
+        __fnm = "./sys.cnf"
 
 CF.read(__fnm)
 
+
 class Config:
     def __init__(self, env):
         self.env = env
-        if env == "spark":
+        if env == "spark":
+            CF.read("./cv.cnf")
 
     def get(self, key, default=None):
         global CF
-        return os.environ.get(key.upper(),
-
-
+        return os.environ.get(key.upper(),
+                              CF[self.env].get(key, default)
+                              )
+
 
 def init(env):
     return Config(env)
-
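Config.get gives environment variables precedence over sys.cnf: the key is upper-cased for the env lookup and the file value is only the fallback. A sketch:

    import os
    from util import config

    cfg = config.init("infiniflow")
    print(cfg.get("minio_host"))            # from conf/sys.cnf: minio:9000
    os.environ["MINIO_HOST"] = "localhost:9000"
    print(cfg.get("minio_host"))            # the env var now wins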
python/util/db_conn.py CHANGED
@@ -3,6 +3,7 @@ import time
 from util import config
 import pandas as pd
 
+
 class Postgres(object):
     def __init__(self, env, dbnm):
         self.config = config.init(env)
@@ -13,36 +14,42 @@ class Postgres(object):
     def __open__(self):
         import psycopg2
         try:
-            if self.conn:
+            if self.conn:
+                self.__close__()
             del self.conn
         except Exception as e:
             pass
 
         try:
-            self.conn = psycopg2.connect(f"dbname={self.dbnm}
+            self.conn = psycopg2.connect(f"""dbname={self.dbnm}
+                user={self.config.get('postgres_user')}
+                password={self.config.get('postgres_password')}
+                host={self.config.get('postgres_host')}
+                port={self.config.get('postgres_port')}""")
         except Exception as e:
-            logging.error(
-
+            logging.error(
+                "Fail to connect %s " %
+                self.config.get("postgres_host") + str(e))
 
     def __close__(self):
         try:
             self.conn.close()
         except Exception as e:
-
-
+            logging.error(
+                "Fail to close %s " %
+                self.config.get("postgres_host") + str(e))
 
     def select(self, sql):
         for _ in range(10):
             try:
                 return pd.read_sql(sql, self.conn)
             except Exception as e:
-                logging.error(f"Fail to exec {sql} "+str(e))
+                logging.error(f"Fail to exec {sql} " + str(e))
                 self.__open__()
                 time.sleep(1)
 
         return pd.DataFrame()
 
-
     def update(self, sql):
         for _ in range(10):
             try:
@@ -53,11 +60,11 @@ class Postgres(object):
                 cur.close()
                 return updated_rows
             except Exception as e:
-                logging.error(f"Fail to exec {sql} "+str(e))
+                logging.error(f"Fail to exec {sql} " + str(e))
                 self.__open__()
                 time.sleep(1)
         return 0
 
+
 if __name__ == "__main__":
     Postgres("infiniflow", "docgpt")
-
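select and update retry up to ten times, reopening the connection between attempts, and degrade to an empty result instead of raising. A usage sketch (assumes the Postgres service from sys.cnf is reachable; the table name is hypothetical):

    from util.db_conn import Postgres

    pg = Postgres("infiniflow", "docgpt")
    df = pg.select("SELECT * FROM kb_doc LIMIT 5")   # hypothetical table
    print(df.shape)    # empty DataFrame if every retry fails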
python/util/es_conn.py CHANGED
@@ -228,7 +228,8 @@ class HuEs:
         return False
 
     def search(self, q, idxnm=None, src=False, timeout="2s"):
-        if not isinstance(q, dict):
+        if not isinstance(q, dict):
+            q = Search().query(q).to_dict()
         for i in range(3):
             try:
                 res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
@@ -274,9 +275,10 @@ class HuEs:
 
         return False
 
-
     def updateScriptByQuery(self, q, scripts, idxnm=None):
-        ubq = UpdateByQuery(
+        ubq = UpdateByQuery(
+            index=self.idxnm if not idxnm else idxnm).using(
+            self.es).query(q)
         ubq = ubq.script(source=scripts)
         ubq = ubq.params(refresh=True)
         ubq = ubq.params(slices=5)
@@ -294,7 +296,6 @@ class HuEs:
 
         return False
 
-
     def deleteByQuery(self, query, idxnm=""):
         for i in range(3):
             try:
@@ -392,7 +393,7 @@ class HuEs:
         return rr
 
     def scrollIter(self, pagesize=100, scroll_time='2m', q={
-
+            "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
         for _ in range(100):
             try:
                 page = self.es.search(
python/util/minio_conn.py CHANGED
@@ -4,6 +4,7 @@ from util import config
 from minio import Minio
 from io import BytesIO
 
+
 class HuMinio(object):
     def __init__(self, env):
         self.config = config.init(env)
@@ -12,64 +13,62 @@ class HuMinio(object):
 
     def __open__(self):
         try:
-            if self.conn:
+            if self.conn:
+                self.__close__()
         except Exception as e:
             pass
 
         try:
             self.conn = Minio(self.config.get("minio_host"),
-                              access_key=self.config.get("
-                              secret_key=self.config.get("
+                              access_key=self.config.get("minio_user"),
+                              secret_key=self.config.get("minio_password"),
                               secure=False
-
+                              )
         except Exception as e:
-            logging.error(
-
+            logging.error(
+                "Fail to connect %s " %
+                self.config.get("minio_host") + str(e))
 
     def __close__(self):
         del self.conn
         self.conn = None
 
-
     def put(self, bucket, fnm, binary):
         for _ in range(10):
             try:
                 if not self.conn.bucket_exists(bucket):
                     self.conn.make_bucket(bucket)
 
-                r = self.conn.put_object(bucket, fnm,
+                r = self.conn.put_object(bucket, fnm,
                                          BytesIO(binary),
                                          len(binary)
-
+                                         )
                 return r
             except Exception as e:
-                logging.error(f"Fail put {bucket}/{fnm}: "+str(e))
+                logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
 
-
     def get(self, bucket, fnm):
         for _ in range(10):
             try:
                 r = self.conn.get_object(bucket, fnm)
                 return r.read()
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: "+str(e))
+                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
-        return
-
+        return
 
     def get_presigned_url(self, bucket, fnm, expires):
         for _ in range(10):
             try:
                 return self.conn.get_presigned_url("GET", bucket, fnm, expires)
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: "+str(e))
+                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
-        return
-
+        return
 
 
 if __name__ == "__main__":
@@ -78,9 +77,8 @@ if __name__ == "__main__":
     from PIL import Image
     img = Image.open(fnm)
     buff = BytesIO()
-    img.save(buff, format='JPEG')
+    img.save(buff, format='JPEG')
     print(conn.put("test", "11-408.jpg", buff.getvalue()))
     bts = conn.get("test", "11-408.jpg")
     img = Image.open(BytesIO(bts))
     img.save("test.jpg")
-