Kevin Hu committed on
Commit
70153b9
·
1 Parent(s): e9078f4

add inputs to display to every component (#3242)

Browse files

### What problem does this PR solve?

#3240

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

agent/component/base.py CHANGED
@@ -37,6 +37,7 @@ class ComponentParamBase(ABC):
37
  self.output_var_name = "output"
38
  self.message_history_window_size = 22
39
  self.query = []
 
40
 
41
  def set_name(self, name: str):
42
  self._name = name
@@ -444,8 +445,13 @@ class ComponentBase(ABC):
444
  if self._param.query:
445
  outs = []
446
  for q in self._param.query:
447
- if q["value"]: outs.append(pd.DataFrame([{"content": q["value"]}]))
448
- if q["component_id"]: outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
 
 
 
 
 
449
  if outs:
450
  df = pd.concat(outs, ignore_index=True)
451
  if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
@@ -463,31 +469,38 @@ class ComponentBase(ABC):
463
  if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
464
  o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
465
  if o is not None:
 
466
  upstream_outs.append(o)
467
  continue
468
- if u not in self._canvas.get_component(self._id)["upstream"]: continue
469
  if self.component_name.lower().find("switch") < 0 \
470
  and self.get_component_name(u) in ["relevant", "categorize"]:
471
  continue
472
  if u.lower().find("answer") >= 0:
473
  for r, c in self._canvas.history[::-1]:
474
  if r == "user":
475
- upstream_outs.append(pd.DataFrame([{"content": c}]))
476
  break
477
  break
478
  if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
479
  continue
480
  o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
481
  if o is not None:
 
482
  upstream_outs.append(o)
483
  break
484
 
485
- if upstream_outs:
486
- df = pd.concat(upstream_outs, ignore_index=True)
487
- if "content" in df:
488
- df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
489
- return df
490
- return pd.DataFrame(self._canvas.get_history(3)[-1:])
 
 
 
 
 
491
 
492
  def get_stream_input(self):
493
  reversed_cpnts = []
 
37
  self.output_var_name = "output"
38
  self.message_history_window_size = 22
39
  self.query = []
40
+ self.inputs = []
41
 
42
  def set_name(self, name: str):
43
  self._name = name
 
445
  if self._param.query:
446
  outs = []
447
  for q in self._param.query:
448
+ if q["component_id"]:
449
+ outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
450
+ self._param.inputs.append({"component_id": q["component_id"],
451
+ "content": "\n".join([str(d["content"]) for d in outs[-1].to_dict('records')])})
452
+ elif q["value"]:
453
+ self._param.inputs.append({"component_id": None, "content": q["value"]})
454
+ outs.append(pd.DataFrame([{"content": q["value"]}]))
455
  if outs:
456
  df = pd.concat(outs, ignore_index=True)
457
  if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
 
469
  if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
470
  o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
471
  if o is not None:
472
+ o["component_id"] = u
473
  upstream_outs.append(o)
474
  continue
475
+ if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
476
  if self.component_name.lower().find("switch") < 0 \
477
  and self.get_component_name(u) in ["relevant", "categorize"]:
478
  continue
479
  if u.lower().find("answer") >= 0:
480
  for r, c in self._canvas.history[::-1]:
481
  if r == "user":
482
+ upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
483
  break
484
  break
485
  if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
486
  continue
487
  o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
488
  if o is not None:
489
+ o["component_id"] = u
490
  upstream_outs.append(o)
491
  break
492
 
493
+ assert upstream_outs, "Can't inference the where the component input is."
494
+
495
+ df = pd.concat(upstream_outs, ignore_index=True)
496
+ if "content" in df:
497
+ df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
498
+
499
+ self._param.inputs = []
500
+ for _,r in df.iterrows():
501
+ self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
502
+
503
+ return df
504
 
505
  def get_stream_input(self):
506
  reversed_cpnts = []
agent/component/cite.py DELETED
@@ -1,75 +0,0 @@
1
- #
2
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- from abc import ABC
17
-
18
- import pandas as pd
19
-
20
- from api.db import LLMType
21
- from api.db.services.knowledgebase_service import KnowledgebaseService
22
- from api.db.services.llm_service import LLMBundle
23
- from api.settings import retrievaler
24
- from agent.component.base import ComponentBase, ComponentParamBase
25
-
26
-
27
- class CiteParam(ComponentParamBase):
28
-
29
- """
30
- Define the Retrieval component parameters.
31
- """
32
- def __init__(self):
33
- super().__init__()
34
- self.cite_sources = []
35
-
36
- def check(self):
37
- self.check_empty(self.cite_source, "Please specify where you want to cite from.")
38
-
39
-
40
- class Cite(ComponentBase, ABC):
41
- component_name = "Cite"
42
-
43
- def _run(self, history, **kwargs):
44
- input = "\n- ".join(self.get_input()["content"])
45
- sources = [self._canvas.get_component(cpn_id).output()[1] for cpn_id in self._param.cite_source]
46
- query = []
47
- for role, cnt in history[::-1][:self._param.message_history_window_size]:
48
- if role != "user":continue
49
- query.append(cnt)
50
- query = "\n".join(query)
51
-
52
- kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
53
- if not kbs:
54
- raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids))
55
- embd_nms = list(set([kb.embd_id for kb in kbs]))
56
- assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
57
-
58
- embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, embd_nms[0])
59
-
60
- rerank_mdl = None
61
- if self._param.rerank_id:
62
- rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
63
-
64
- kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
65
- 1, self._param.top_n,
66
- self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
67
- aggs=False, rerank_mdl=rerank_mdl)
68
-
69
- if not kbinfos["chunks"]: return pd.DataFrame()
70
- df = pd.DataFrame(kbinfos["chunks"])
71
- df["content"] = df["content_with_weight"]
72
- del df["content_with_weight"]
73
- return df
74
-
75
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
agent/component/generate.py CHANGED
@@ -101,8 +101,8 @@ class Generate(ComponentBase):
101
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
102
  prompt = self._param.prompt
103
 
104
- retrieval_res = self.get_input()
105
- input = (" - "+"\n - ".join([c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
106
  for para in self._param.parameters:
107
  cpn = self._canvas.get_component(para["component_id"])["obj"]
108
  if cpn.component_name.lower() == "answer":
@@ -112,12 +112,24 @@ class Generate(ComponentBase):
112
  if "content" not in out.columns:
113
  kwargs[para["key"]] = "Nothing"
114
  else:
 
 
115
  kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
 
 
 
 
 
116
 
117
- kwargs["input"] = input
118
  for n, v in kwargs.items():
119
  prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
120
 
 
 
 
 
 
 
121
  downstreams = self._canvas.get_component(self._id)["downstream"]
122
  if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
123
  "obj"].component_name.lower() == "answer":
 
101
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
102
  prompt = self._param.prompt
103
 
104
+ retrieval_res = []
105
+ self._param.inputs = []
106
  for para in self._param.parameters:
107
  cpn = self._canvas.get_component(para["component_id"])["obj"]
108
  if cpn.component_name.lower() == "answer":
 
112
  if "content" not in out.columns:
113
  kwargs[para["key"]] = "Nothing"
114
  else:
115
+ if cpn.component_name.lower() == "retrieval":
116
+ retrieval_res.append(out)
117
  kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
118
+ self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
119
+
120
+ if retrieval_res:
121
+ retrieval_res = pd.concat(retrieval_res, ignore_index=True)
122
+ else: retrieval_res = pd.DataFrame([])
123
 
 
124
  for n, v in kwargs.items():
125
  prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
126
 
127
+ if not self._param.inputs and prompt.find("{input}") >= 0:
128
+ retrieval_res = self.get_input()
129
+ input = (" - " + "\n - ".join(
130
+ [c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
131
+ prompt = re.sub(r"\{input\}", re.escape(input), prompt)
132
+
133
  downstreams = self._canvas.get_component(self._id)["downstream"]
134
  if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
135
  "obj"].component_name.lower() == "answer":
agent/component/keyword.py CHANGED
@@ -50,14 +50,11 @@ class KeywordExtract(Generate, ABC):
50
  component_name = "KeywordExtract"
51
 
52
  def _run(self, history, **kwargs):
53
- q = ""
54
- for r, c in self._canvas.history[::-1]:
55
- if r == "user":
56
- q += c
57
- break
58
 
59
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
60
- ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
61
  self._param.gen_conf())
62
 
63
  ans = re.sub(r".*keyword:", "", ans).strip()
 
50
  component_name = "KeywordExtract"
51
 
52
  def _run(self, history, **kwargs):
53
+ query = self.get_input()
54
+ query = str(query["content"][0]) if "content" in query else ""
 
 
 
55
 
56
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
57
+ ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": query}],
58
  self._param.gen_conf())
59
 
60
  ans = re.sub(r".*keyword:", "", ans).strip()
api/db/services/dialog_service.py CHANGED
@@ -396,6 +396,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
396
  rows = ["|" +
397
  "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
398
  "|" for r in tbl["rows"]]
 
399
  if quota:
400
  rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
401
  else:
 
396
  rows = ["|" +
397
  "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
398
  "|" for r in tbl["rows"]]
399
+ rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
400
  if quota:
401
  rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
402
  else:
conf/llm_factories.json CHANGED
@@ -1299,7 +1299,7 @@
1299
  "llm": []
1300
  },
1301
  {
1302
- "name": "cohere",
1303
  "logo": "",
1304
  "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
1305
  "status": "1",
 
1299
  "llm": []
1300
  },
1301
  {
1302
+ "name": "Cohere",
1303
  "logo": "",
1304
  "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
1305
  "status": "1",
rag/llm/__init__.py CHANGED
@@ -39,7 +39,7 @@ EmbeddingModel = {
39
  "NVIDIA": NvidiaEmbed,
40
  "LM-Studio": LmStudioEmbed,
41
  "OpenAI-API-Compatible": OpenAI_APIEmbed,
42
- "cohere": CoHereEmbed,
43
  "TogetherAI": TogetherAIEmbed,
44
  "PerfXCloud": PerfXCloudEmbed,
45
  "Upstage": UpstageEmbed,
@@ -92,7 +92,7 @@ ChatModel = {
92
  "NVIDIA": NvidiaChat,
93
  "LM-Studio": LmStudioChat,
94
  "OpenAI-API-Compatible": OpenAI_APIChat,
95
- "cohere": CoHereChat,
96
  "LeptonAI": LeptonAIChat,
97
  "TogetherAI": TogetherAIChat,
98
  "PerfXCloud": PerfXCloudChat,
@@ -117,7 +117,7 @@ RerankModel = {
117
  "NVIDIA": NvidiaRerank,
118
  "LM-Studio": LmStudioRerank,
119
  "OpenAI-API-Compatible": OpenAI_APIRerank,
120
- "cohere": CoHereRerank,
121
  "TogetherAI": TogetherAIRerank,
122
  "SILICONFLOW": SILICONFLOWRerank,
123
  "BaiduYiyan": BaiduYiyanRerank,
 
39
  "NVIDIA": NvidiaEmbed,
40
  "LM-Studio": LmStudioEmbed,
41
  "OpenAI-API-Compatible": OpenAI_APIEmbed,
42
+ "Cohere": CoHereEmbed,
43
  "TogetherAI": TogetherAIEmbed,
44
  "PerfXCloud": PerfXCloudEmbed,
45
  "Upstage": UpstageEmbed,
 
92
  "NVIDIA": NvidiaChat,
93
  "LM-Studio": LmStudioChat,
94
  "OpenAI-API-Compatible": OpenAI_APIChat,
95
+ "Cohere": CoHereChat,
96
  "LeptonAI": LeptonAIChat,
97
  "TogetherAI": TogetherAIChat,
98
  "PerfXCloud": PerfXCloudChat,
 
117
  "NVIDIA": NvidiaRerank,
118
  "LM-Studio": LmStudioRerank,
119
  "OpenAI-API-Compatible": OpenAI_APIRerank,
120
+ "Cohere": CoHereRerank,
121
  "TogetherAI": TogetherAIRerank,
122
  "SILICONFLOW": SILICONFLOWRerank,
123
  "BaiduYiyan": BaiduYiyanRerank,
rag/llm/rerank_model.py CHANGED
@@ -394,6 +394,7 @@ class VoyageRerank(Base):
394
  rank[r.index] = r.relevance_score
395
  return rank, res.total_tokens
396
 
 
397
  class QWenRerank(Base):
398
  def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
399
  import dashscope
 
394
  rank[r.index] = r.relevance_score
395
  return rank, res.total_tokens
396
 
397
+
398
  class QWenRerank(Base):
399
  def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
400
  import dashscope