Kevin Hu
committed on
Commit
·
70153b9
1
Parent(s):
e9078f4
add inputs to display to every components (#3242)
Browse files
### What problem does this PR solve?
#3240
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- agent/component/base.py +23 -10
- agent/component/cite.py +0 -75
- agent/component/generate.py +15 -3
- agent/component/keyword.py +3 -6
- api/db/services/dialog_service.py +1 -0
- conf/llm_factories.json +1 -1
- rag/llm/__init__.py +3 -3
- rag/llm/rerank_model.py +1 -0
agent/component/base.py
CHANGED
@@ -37,6 +37,7 @@ class ComponentParamBase(ABC):
|
|
37 |
self.output_var_name = "output"
|
38 |
self.message_history_window_size = 22
|
39 |
self.query = []
|
|
|
40 |
|
41 |
def set_name(self, name: str):
|
42 |
self._name = name
|
@@ -444,8 +445,13 @@ class ComponentBase(ABC):
|
|
444 |
if self._param.query:
|
445 |
outs = []
|
446 |
for q in self._param.query:
|
447 |
-
if q["
|
448 |
-
|
|
|
|
|
|
|
|
|
|
|
449 |
if outs:
|
450 |
df = pd.concat(outs, ignore_index=True)
|
451 |
if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
@@ -463,31 +469,38 @@ class ComponentBase(ABC):
|
|
463 |
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
464 |
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
465 |
if o is not None:
|
|
|
466 |
upstream_outs.append(o)
|
467 |
continue
|
468 |
-
if u not in self._canvas.get_component(self._id)["upstream"]: continue
|
469 |
if self.component_name.lower().find("switch") < 0 \
|
470 |
and self.get_component_name(u) in ["relevant", "categorize"]:
|
471 |
continue
|
472 |
if u.lower().find("answer") >= 0:
|
473 |
for r, c in self._canvas.history[::-1]:
|
474 |
if r == "user":
|
475 |
-
upstream_outs.append(pd.DataFrame([{"content": c}]))
|
476 |
break
|
477 |
break
|
478 |
if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
|
479 |
continue
|
480 |
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
481 |
if o is not None:
|
|
|
482 |
upstream_outs.append(o)
|
483 |
break
|
484 |
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
|
|
|
|
|
|
|
|
|
|
491 |
|
492 |
def get_stream_input(self):
|
493 |
reversed_cpnts = []
|
|
|
37 |
self.output_var_name = "output"
|
38 |
self.message_history_window_size = 22
|
39 |
self.query = []
|
40 |
+
self.inputs = []
|
41 |
|
42 |
def set_name(self, name: str):
|
43 |
self._name = name
|
|
|
445 |
if self._param.query:
|
446 |
outs = []
|
447 |
for q in self._param.query:
|
448 |
+
if q["component_id"]:
|
449 |
+
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
|
450 |
+
self._param.inputs.append({"component_id": q["component_id"],
|
451 |
+
"content": "\n".join([str(d["content"]) for d in outs[-1].to_dict('records')])})
|
452 |
+
elif q["value"]:
|
453 |
+
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
454 |
+
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
455 |
if outs:
|
456 |
df = pd.concat(outs, ignore_index=True)
|
457 |
if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
|
|
469 |
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
470 |
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
471 |
if o is not None:
|
472 |
+
o["component_id"] = u
|
473 |
upstream_outs.append(o)
|
474 |
continue
|
475 |
+
if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
|
476 |
if self.component_name.lower().find("switch") < 0 \
|
477 |
and self.get_component_name(u) in ["relevant", "categorize"]:
|
478 |
continue
|
479 |
if u.lower().find("answer") >= 0:
|
480 |
for r, c in self._canvas.history[::-1]:
|
481 |
if r == "user":
|
482 |
+
upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
|
483 |
break
|
484 |
break
|
485 |
if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
|
486 |
continue
|
487 |
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
488 |
if o is not None:
|
489 |
+
o["component_id"] = u
|
490 |
upstream_outs.append(o)
|
491 |
break
|
492 |
|
493 |
+
assert upstream_outs, "Can't inference the where the component input is."
|
494 |
+
|
495 |
+
df = pd.concat(upstream_outs, ignore_index=True)
|
496 |
+
if "content" in df:
|
497 |
+
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
498 |
+
|
499 |
+
self._param.inputs = []
|
500 |
+
for _,r in df.iterrows():
|
501 |
+
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
|
502 |
+
|
503 |
+
return df
|
504 |
|
505 |
def get_stream_input(self):
|
506 |
reversed_cpnts = []
|
agent/component/cite.py
DELETED
@@ -1,75 +0,0 @@
|
|
1 |
-
#
|
2 |
-
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
#
|
16 |
-
from abc import ABC
|
17 |
-
|
18 |
-
import pandas as pd
|
19 |
-
|
20 |
-
from api.db import LLMType
|
21 |
-
from api.db.services.knowledgebase_service import KnowledgebaseService
|
22 |
-
from api.db.services.llm_service import LLMBundle
|
23 |
-
from api.settings import retrievaler
|
24 |
-
from agent.component.base import ComponentBase, ComponentParamBase
|
25 |
-
|
26 |
-
|
27 |
-
class CiteParam(ComponentParamBase):
|
28 |
-
|
29 |
-
"""
|
30 |
-
Define the Retrieval component parameters.
|
31 |
-
"""
|
32 |
-
def __init__(self):
|
33 |
-
super().__init__()
|
34 |
-
self.cite_sources = []
|
35 |
-
|
36 |
-
def check(self):
|
37 |
-
self.check_empty(self.cite_source, "Please specify where you want to cite from.")
|
38 |
-
|
39 |
-
|
40 |
-
class Cite(ComponentBase, ABC):
|
41 |
-
component_name = "Cite"
|
42 |
-
|
43 |
-
def _run(self, history, **kwargs):
|
44 |
-
input = "\n- ".join(self.get_input()["content"])
|
45 |
-
sources = [self._canvas.get_component(cpn_id).output()[1] for cpn_id in self._param.cite_source]
|
46 |
-
query = []
|
47 |
-
for role, cnt in history[::-1][:self._param.message_history_window_size]:
|
48 |
-
if role != "user":continue
|
49 |
-
query.append(cnt)
|
50 |
-
query = "\n".join(query)
|
51 |
-
|
52 |
-
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
|
53 |
-
if not kbs:
|
54 |
-
raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids))
|
55 |
-
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
56 |
-
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
|
57 |
-
|
58 |
-
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
59 |
-
|
60 |
-
rerank_mdl = None
|
61 |
-
if self._param.rerank_id:
|
62 |
-
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
63 |
-
|
64 |
-
kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
65 |
-
1, self._param.top_n,
|
66 |
-
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
67 |
-
aggs=False, rerank_mdl=rerank_mdl)
|
68 |
-
|
69 |
-
if not kbinfos["chunks"]: return pd.DataFrame()
|
70 |
-
df = pd.DataFrame(kbinfos["chunks"])
|
71 |
-
df["content"] = df["content_with_weight"]
|
72 |
-
del df["content_with_weight"]
|
73 |
-
return df
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/component/generate.py
CHANGED
@@ -101,8 +101,8 @@ class Generate(ComponentBase):
|
|
101 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
102 |
prompt = self._param.prompt
|
103 |
|
104 |
-
retrieval_res =
|
105 |
-
|
106 |
for para in self._param.parameters:
|
107 |
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
108 |
if cpn.component_name.lower() == "answer":
|
@@ -112,12 +112,24 @@ class Generate(ComponentBase):
|
|
112 |
if "content" not in out.columns:
|
113 |
kwargs[para["key"]] = "Nothing"
|
114 |
else:
|
|
|
|
|
115 |
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
kwargs["input"] = input
|
118 |
for n, v in kwargs.items():
|
119 |
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
downstreams = self._canvas.get_component(self._id)["downstream"]
|
122 |
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
|
123 |
"obj"].component_name.lower() == "answer":
|
|
|
101 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
102 |
prompt = self._param.prompt
|
103 |
|
104 |
+
retrieval_res = []
|
105 |
+
self._param.inputs = []
|
106 |
for para in self._param.parameters:
|
107 |
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
108 |
if cpn.component_name.lower() == "answer":
|
|
|
112 |
if "content" not in out.columns:
|
113 |
kwargs[para["key"]] = "Nothing"
|
114 |
else:
|
115 |
+
if cpn.component_name.lower() == "retrieval":
|
116 |
+
retrieval_res.append(out)
|
117 |
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
|
118 |
+
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
119 |
+
|
120 |
+
if retrieval_res:
|
121 |
+
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
122 |
+
else: retrieval_res = pd.DataFrame([])
|
123 |
|
|
|
124 |
for n, v in kwargs.items():
|
125 |
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
126 |
|
127 |
+
if not self._param.inputs and prompt.find("{input}") >= 0:
|
128 |
+
retrieval_res = self.get_input()
|
129 |
+
input = (" - " + "\n - ".join(
|
130 |
+
[c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
|
131 |
+
prompt = re.sub(r"\{input\}", re.escape(input), prompt)
|
132 |
+
|
133 |
downstreams = self._canvas.get_component(self._id)["downstream"]
|
134 |
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
|
135 |
"obj"].component_name.lower() == "answer":
|
agent/component/keyword.py
CHANGED
@@ -50,14 +50,11 @@ class KeywordExtract(Generate, ABC):
|
|
50 |
component_name = "KeywordExtract"
|
51 |
|
52 |
def _run(self, history, **kwargs):
|
53 |
-
|
54 |
-
|
55 |
-
if r == "user":
|
56 |
-
q += c
|
57 |
-
break
|
58 |
|
59 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
60 |
-
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content":
|
61 |
self._param.gen_conf())
|
62 |
|
63 |
ans = re.sub(r".*keyword:", "", ans).strip()
|
|
|
50 |
component_name = "KeywordExtract"
|
51 |
|
52 |
def _run(self, history, **kwargs):
|
53 |
+
query = self.get_input()
|
54 |
+
query = str(query["content"][0]) if "content" in query else ""
|
|
|
|
|
|
|
55 |
|
56 |
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
57 |
+
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": query}],
|
58 |
self._param.gen_conf())
|
59 |
|
60 |
ans = re.sub(r".*keyword:", "", ans).strip()
|
api/db/services/dialog_service.py
CHANGED
@@ -396,6 +396,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
396 |
rows = ["|" +
|
397 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
398 |
"|" for r in tbl["rows"]]
|
|
|
399 |
if quota:
|
400 |
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
401 |
else:
|
|
|
396 |
rows = ["|" +
|
397 |
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
398 |
"|" for r in tbl["rows"]]
|
399 |
+
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
|
400 |
if quota:
|
401 |
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
402 |
else:
|
conf/llm_factories.json
CHANGED
@@ -1299,7 +1299,7 @@
|
|
1299 |
"llm": []
|
1300 |
},
|
1301 |
{
|
1302 |
-
"name": "
|
1303 |
"logo": "",
|
1304 |
"tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
|
1305 |
"status": "1",
|
|
|
1299 |
"llm": []
|
1300 |
},
|
1301 |
{
|
1302 |
+
"name": "Cohere",
|
1303 |
"logo": "",
|
1304 |
"tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
|
1305 |
"status": "1",
|
rag/llm/__init__.py
CHANGED
@@ -39,7 +39,7 @@ EmbeddingModel = {
|
|
39 |
"NVIDIA": NvidiaEmbed,
|
40 |
"LM-Studio": LmStudioEmbed,
|
41 |
"OpenAI-API-Compatible": OpenAI_APIEmbed,
|
42 |
-
"
|
43 |
"TogetherAI": TogetherAIEmbed,
|
44 |
"PerfXCloud": PerfXCloudEmbed,
|
45 |
"Upstage": UpstageEmbed,
|
@@ -92,7 +92,7 @@ ChatModel = {
|
|
92 |
"NVIDIA": NvidiaChat,
|
93 |
"LM-Studio": LmStudioChat,
|
94 |
"OpenAI-API-Compatible": OpenAI_APIChat,
|
95 |
-
"
|
96 |
"LeptonAI": LeptonAIChat,
|
97 |
"TogetherAI": TogetherAIChat,
|
98 |
"PerfXCloud": PerfXCloudChat,
|
@@ -117,7 +117,7 @@ RerankModel = {
|
|
117 |
"NVIDIA": NvidiaRerank,
|
118 |
"LM-Studio": LmStudioRerank,
|
119 |
"OpenAI-API-Compatible": OpenAI_APIRerank,
|
120 |
-
"
|
121 |
"TogetherAI": TogetherAIRerank,
|
122 |
"SILICONFLOW": SILICONFLOWRerank,
|
123 |
"BaiduYiyan": BaiduYiyanRerank,
|
|
|
39 |
"NVIDIA": NvidiaEmbed,
|
40 |
"LM-Studio": LmStudioEmbed,
|
41 |
"OpenAI-API-Compatible": OpenAI_APIEmbed,
|
42 |
+
"Cohere": CoHereEmbed,
|
43 |
"TogetherAI": TogetherAIEmbed,
|
44 |
"PerfXCloud": PerfXCloudEmbed,
|
45 |
"Upstage": UpstageEmbed,
|
|
|
92 |
"NVIDIA": NvidiaChat,
|
93 |
"LM-Studio": LmStudioChat,
|
94 |
"OpenAI-API-Compatible": OpenAI_APIChat,
|
95 |
+
"Cohere": CoHereChat,
|
96 |
"LeptonAI": LeptonAIChat,
|
97 |
"TogetherAI": TogetherAIChat,
|
98 |
"PerfXCloud": PerfXCloudChat,
|
|
|
117 |
"NVIDIA": NvidiaRerank,
|
118 |
"LM-Studio": LmStudioRerank,
|
119 |
"OpenAI-API-Compatible": OpenAI_APIRerank,
|
120 |
+
"Cohere": CoHereRerank,
|
121 |
"TogetherAI": TogetherAIRerank,
|
122 |
"SILICONFLOW": SILICONFLOWRerank,
|
123 |
"BaiduYiyan": BaiduYiyanRerank,
|
rag/llm/rerank_model.py
CHANGED
@@ -394,6 +394,7 @@ class VoyageRerank(Base):
|
|
394 |
rank[r.index] = r.relevance_score
|
395 |
return rank, res.total_tokens
|
396 |
|
|
|
397 |
class QWenRerank(Base):
|
398 |
def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
|
399 |
import dashscope
|
|
|
394 |
rank[r.index] = r.relevance_score
|
395 |
return rank, res.total_tokens
|
396 |
|
397 |
+
|
398 |
class QWenRerank(Base):
|
399 |
def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
|
400 |
import dashscope
|